1 /**
2 * TLS Blocking API
3 * 
4 * Copyright:
5 * (C) 2013,2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.blocking;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 
16 import std.exception : enforce;
17 import botan.tls.client;
18 import botan.tls.server;
19 import botan.rng.rng;
20 import botan.tls.channel;
21 import botan.tls.session_manager;
22 import botan.tls.version_;
23 import botan.utils.mem_ops;
24 import memutils.circularbuffer;
25 import memutils.utils;
26 import std.algorithm;
27 
/// Pull callback for raw transport bytes: receives a scratch buffer and returns the
/// bytes that were read. An empty result is treated as connection close by
/// TLSBlockingChannel. NOTE(review): the result is presumably a prefix of the buffer.
alias DataReader = ubyte[] delegate(ubyte[] buffer);
29 
30 /**
31 * Blocking TLS Channel
32 */
/**
* Blocking TLS Channel
*
* Wraps a TLSClient or TLSServer behind a synchronous, pull-style interface.
* Raw transport bytes are obtained by calling a DataReader and fed to the
* underlying channel. Decrypted application data delivered by the channel is
* accumulated in an internal plaintext buffer until the caller reads it.
*/
struct TLSBlockingChannel
{
public:
    @disable this(this);
    @disable this();

    /**
    * Client constructor.
    *
    * Params:
    *   read_fn         = called with a scratch buffer to obtain raw bytes from the
    *                     transport; an empty result is treated as connection close
    *   write_fn        = passed to the underlying TLSClient for outgoing data
    *   alert_cb        = optional; invoked for every alert the channel reports
    *   hs_cb           = optional; invoked on handshake completion, its return value
    *                     is handed back to the channel (true is used when unset)
    *   session_manager = passed to the underlying TLSClient
    *   creds           = passed to the underlying TLSClient
    *   policy          = passed to the underlying TLSClient
    *   rng             = passed to the underlying TLSClient
    *   server_info     = passed to the underlying TLSClient
    *   offer_version   = passed to the underlying TLSClient
    *   next_protocols  = moved into the underlying TLSClient
    */
    this(DataReader read_fn,
         DataWriter write_fn,
         OnAlert alert_cb,
         OnHandshakeComplete hs_cb,
         TLSSessionManager session_manager,
         TLSCredentialsManager creds,
         TLSPolicy policy,
         RandomNumberGenerator rng,
         in TLSServerInformation server_info = TLSServerInformation(),
         in TLSProtocolVersion offer_version = TLSProtocolVersion.latestTlsVersion(),
         Vector!string next_protocols = Vector!string())
    {
        m_is_client = true;
        m_read_fn = read_fn;
        m_alert_cb = alert_cb;
        m_handshake_complete = hs_cb;
        m_readbuf = Vector!ubyte(TLS_DEFAULT_BUFFERSIZE);
        // Release the read buffer if constructing the TLS client throws.
        scope(failure) m_readbuf.destroy();
        m_impl.client = new TLSClient(write_fn, &dataCb, &alertCb, &handshakeCb, session_manager, creds,
            policy, rng, server_info, offer_version, next_protocols.move);
    }

    /**
    * Server constructor.
    *
    * Params:
    *   read_fn         = called with a scratch buffer to obtain raw bytes from the
    *                     transport; an empty result is treated as connection close
    *   write_fn        = passed to the underlying TLSServer for outgoing data
    *   alert_cb        = optional; invoked for every alert the channel reports
    *   hs_cb           = optional; invoked on handshake completion, its return value
    *                     is handed back to the channel (true is used when unset)
    *   session_manager = passed to the underlying TLSServer
    *   creds           = passed to the underlying TLSServer
    *   policy          = passed to the underlying TLSServer
    *   rng             = passed to the underlying TLSServer
    *   next_proto      = passed to the underlying TLSServer
    *   sni_handler     = passed to the underlying TLSServer
    *   is_datagram     = passed to the underlying TLSServer
    *   io_buf_sz       = passed to the underlying TLSServer
    */
    this(DataReader read_fn,
         DataWriter write_fn,
         OnAlert alert_cb,
         OnHandshakeComplete hs_cb,
         TLSSessionManager session_manager,
         TLSCredentialsManager creds,
         TLSPolicy policy,
         RandomNumberGenerator rng,
         NextProtocolHandler next_proto = null,
         SNIHandler sni_handler = null,
         bool is_datagram = false,
         size_t io_buf_sz = 16*1024)
    {
        m_is_client = false;
        m_read_fn = read_fn;
        m_alert_cb = alert_cb;
        m_handshake_complete = hs_cb;
        m_readbuf = Vector!ubyte(TLS_DEFAULT_BUFFERSIZE);
        // Release the read buffer if constructing the TLS server throws.
        scope(failure) m_readbuf.destroy();
        m_impl.server = new TLSServer(write_fn, &dataCb, &alertCb, &handshakeCb, session_manager, creds,
            policy, rng, next_proto, sni_handler, is_datagram, io_buf_sz);
    }

    /**
    * Blocks until the full handshake is complete.
    *
    * Throws if the transport reports close (an empty read) before the channel
    * becomes active. Returns without throwing if the channel is marked closed
    * (for example by a fatal alert) first; check isClosed() afterwards.
    */
    void doHandshake()
    {
        while (!m_closed && channel !is null && !channel.isActive())
        {
            ubyte[] readref = m_readbuf.ptr[0 .. m_readbuf.length];
            const ubyte[] from_socket = m_read_fn(readref);
            // An empty read signals a closed transport (same convention as readBuf).
            // Without this check the loop would keep feeding 0 bytes to the channel
            // and spin forever, because the channel can never become active.
            enforce(from_socket.length > 0, "Connection closed during handshake");
            enforce(channel !is null, "Connection closed during handshake");
            channel.receivedData(cast(const(ubyte)*)from_socket.ptr, from_socket.length);
        }
    }

    /**
    * Number of bytes pending read in the plaintext buffer (bytes
    * readable without blocking)
    */
    size_t pending() const { return m_plaintext.length; }

    /// Returns a view of the pending plaintext without consuming it, or null if none.
    const(ubyte)[] peek() {
        return m_plaintext.length > 0 ? m_plaintext.peek : null;
    }

    /**
    * Reads until the destination array is full, blocking as needed.
    *
    * Throws if `dest` is empty or if the connection closes before `dest` is filled.
    * Each readBuf call either fills at least one more byte or throws, so the loop
    * always terminates; no artificial iteration limit is needed.
    */
    void read(ubyte[] dest)
    {
        enforce(dest.length > 0, "Empty destination array");
        ubyte[] remaining = dest;
        while (remaining.length > 0) {
            // readBuf writes into the front of `remaining` and returns that prefix.
            ubyte[] chunk = readBuf(remaining);
            enforce(chunk.length > 0, "readBuf returned 0 length (connection closed)");
            remaining = remaining[chunk.length .. $];
        }
    }

    /**
    * Blocking ( if !pending() ) read, will return at least 1 ubyte or 0 on connection close
    *  supports replacement of internal read buffer when called until buf.length != returned buffer length
    *
    * Returns a prefix of `buf` holding the bytes copied from the plaintext buffer,
    * or null when the transport read returns nothing or `buf` is empty.
    */
    ubyte[] readBuf(ubyte[] buf)
    {
        m_reading = true;
        scope(exit) m_reading = false;

        // Serve already-decrypted data first, without touching the transport.
        if (m_plaintext.length != 0) {
            size_t len = min(m_plaintext.length, buf.length);
            m_plaintext.read(buf[0 .. len]);
            return buf[0 .. len];
        }

        // If there's nothing in the buffers, read some packets and process them.
        while (m_plaintext.empty)
        {
            ubyte[] slice;
            if (m_readbuf.length > 0) {
                slice = m_readbuf.ptr[0 .. m_readbuf.length];
            }
            const ubyte[] from_socket = m_read_fn(slice);
            if (from_socket.length == 0)
                return null;

            enforce(channel !is null, "Connection closed while reading from TLS Channel");
            channel.receivedData(cast(const(ubyte)*)from_socket.ptr, from_socket.length);

            // The read filled the whole scratch buffer: double it (up to 256 KiB) so
            // later reads can take more at once. from_socket is not used after this.
            if (from_socket.length == slice.length && m_readbuf.length < 256*1024) {
                size_t next_len = m_readbuf.length * 2;
                m_readbuf.destroy();
                m_readbuf = Vector!ubyte(next_len);
            }
        }

        if (buf.length == 0) return null;

        const size_t returned = std.algorithm.min(buf.length, m_plaintext.length);
        if (returned == 0) {
            return null;
        }
        m_plaintext.read(buf[0 .. returned]);
        return buf[0 .. returned];
    }

    /// Sends `buf` as application data through the underlying channel.
    void write(in ubyte[] buf) {
        m_writing = true;
        scope(exit) m_writing = false;

        enforce(channel !is null, "Connection closed when attempting to write to channel");
        channel.send(cast(const(ubyte)*)buf.ptr, buf.length);
    }

    /// The wrapped TLSClient or TLSServer.
    inout(TLSChannel) underlyingChannel() inout { return channel; }

    /// Marks this wrapper closed and closes the underlying channel.
    void close() { enforce(channel); m_closed = true; channel.close(); }

    /// True once close() ran, a fatal alert arrived, or no channel exists.
    bool isClosed() const { return m_closed || channel is null; }

    /// True while a readBuf or write call is in progress.
    @property bool isBusy() const { return m_reading || m_writing; }

    /// The peer's certificate chain as reported by the underlying channel.
    const(Vector!X509Certificate) peerCertChain() const { enforce(channel); return channel.peerCertChain(); }

    /// Destroys the underlying channel unless a read or write is in progress.
    ~this()
    {
        // NOTE(review): presumably skipped when busy to avoid tearing the channel
        // down underneath an active readBuf/write; the channel is not destroyed then.
        if (isBusy) return;
        if (m_is_client)
            m_impl.client.destroy();
        else m_impl.server.destroy();
    }

    /**
    * Sets the callback invoked when the handshake completes.
    */
    @property void onHandshakeComplete(OnHandshakeComplete handshake_complete)
    { m_handshake_complete = handshake_complete; }

    /**
    * Sets the callback invoked for every TLS alert.
    */
    @property void onAlertNotification(OnAlert alert_cb)
    {
        m_alert_cb = alert_cb;
    }

private:

    // Handshake completion hook given to the channel; forwards to the user callback.
    bool handshakeCb(in TLSSession session)
    {
        if (m_handshake_complete)
            return m_handshake_complete(session);
        return true;
    }

    // Application-data hook given to the channel; appends decrypted bytes to m_plaintext.
    void dataCb(in ubyte[] data)
    {
        if (m_plaintext.freeSpace < data.length) {
            // Grow capacity by at least max(8192, data.length) so `data` fits.
            m_plaintext.capacity = std.algorithm.max(8192, data.length + data.length % 8192) + m_plaintext.capacity;
        }
        m_plaintext.put(data);
    }

    // Alert hook given to the channel; fatal alerts mark this wrapper closed.
    void alertCb(in TLSAlert alert, in ubyte[] ub)
    {
        if (alert.isFatal)
            m_closed = true;
        if (m_alert_cb)
            m_alert_cb(alert, ub);
    }

    // Exactly one member is live, selected by m_is_client.
    union TLSImpl {
        TLSClient client;
        TLSServer server;
    }

    // The live union member viewed as the common TLSChannel base.
    @property inout(TLSChannel) channel() inout {
        return (m_is_client ? cast(inout(TLSChannel)) m_impl.client : cast(inout(TLSChannel)) m_impl.server);
    }

    bool m_reading;
    bool m_writing;
    bool m_is_client;
    bool m_closed;
    DataReader m_read_fn;
    TLSImpl m_impl;
    OnAlert m_alert_cb;
    OnHandshakeComplete m_handshake_complete;

    // Decrypted application data not yet returned to the caller.
    CircularBuffer!(ubyte, 0, SecureMem) m_plaintext;

    // Scratch buffer handed to m_read_fn; grown by readBuf when filled.
    Vector!ubyte m_readbuf;
}
271