1 /**
2 * TLS Handshake Serialization
3 * 
4 * Copyright:
5 * (C) 2012 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.handshake_io;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 package:
16 
17 import botan.tls.magic;
18 import botan.tls.version_;
19 import botan.utils.loadstor;
20 import botan.tls.messages;
21 import botan.tls.record;
22 import botan.tls.seq_numbers;
23 import botan.utils.exceptn;
24 import std.algorithm : count, min;
25 import botan.utils.types;
26 import memutils.hashmap;
27 import std.typecons : Tuple;
28 import std.datetime;
29 
/**
* One handshake message as handed out by HandshakeIO.getNextRecord.
*
* The implementations in this module return a value with type
* HANDSHAKE_NONE and an empty data vector when no complete message is
* currently available.
*/
struct NextRecord
{
    /// Handshake message type (HANDSHAKE_NONE when nothing is available).
    HandshakeType type;
    /// Message body, without the handshake header.
    Vector!ubyte data;
}
35 
/**
* Handshake IO Interface
*
* Common contract for the stream and datagram handshake layers: outgoing
* handshake messages are framed and handed to a record writer, incoming
* records are queued and later extracted as whole handshake messages.
*/
interface HandshakeIO
{
public:
    /// Protocol version to place in the first record sent on a connection.
    TLSProtocolVersion initialRecordVersion() const;

    /**
    * Serialize and transmit a handshake message.
    * Returns: the framed bytes that were sent, or an empty vector for a
    *          ChangeCipherSpec (which is not included in handshake hashes).
    */
    Vector!ubyte send(in HandshakeMessage msg);

    /**
    * Wrap an already serialized message body in the handshake header
    * used by this transport.
    */
    const(Vector!ubyte) format(const ref Vector!ubyte handshake_msg, HandshakeType handshake_type) const;

    /**
    * Give the implementation a chance to act on a retransmission timeout.
    * Returns: true if a timeout was detected and handled, false otherwise.
    */
    bool timeoutCheck();

    /**
    * Feed the contents of one received record into the handshake layer.
    * Params:
    *  record = record payload
    *  type = record content type
    *  sequence_number = record sequence number
    */
    void addRecord(const ref Vector!ubyte record, RecordType type, ulong sequence_number);

    /**
    * Extract the next complete handshake message, if any.
    * Returns: a NextRecord of type HANDSHAKE_NONE with an empty data
    *          vector if no message is currently available.
    */
    NextRecord getNextRecord(bool expecting_ccs);
}
57 
/**
* Handshake IO for stream-based handshakes
*
* Incoming handshake record payloads are appended to a single byte queue;
* messages are cut from the front of that queue once their 4-byte header
* (1 byte type, 3 byte big-endian length) and full body have arrived.
*/
package final class StreamHandshakeIO : HandshakeIO
{
public:
    /// Sink for outgoing records: receives (record type, record payload).
    alias InternalDataWriter = void delegate(ubyte, const ref Vector!ubyte);

    /// Params: writer = delegate invoked for every record this object emits
    this(InternalDataWriter writer)
    {
        m_send_hs = writer;
    }

    /// The first record of a stream handshake is stamped TLS v1.0.
    override TLSProtocolVersion initialRecordVersion() const
    {
        return cast(TLSProtocolVersion)TLSProtocolVersion.TLS_V10;
    }

    /// This implementation never reports a timeout.
    override bool timeoutCheck() { return false; }

    /**
    * Serialize msg and pass it to the record writer.
    * Returns: the framed handshake bytes, or an empty vector for a
    *          ChangeCipherSpec.
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        const Vector!ubyte payload = msg.serialize();

        if (msg.type() != HANDSHAKE_CCS)
        {
            Vector!ubyte framed = format(payload, msg.type()).dup;
            m_send_hs(HANDSHAKE, framed);
            return framed.move();
        }

        // ChangeCipherSpec travels under its own record type, unframed
        m_send_hs(CHANGE_CIPHER_SPEC, payload);
        return Vector!ubyte(); // not included in handshake hashes
    }

    /**
    * Prefix msg with the 4-byte handshake header: type byte followed by
    * the 24-bit big-endian body length.
    */
    override const(Vector!ubyte) format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        const size_t body_len = msg.length;

        Vector!ubyte framed = Vector!ubyte(4 + msg.length);

        framed[0] = type;
        storeBigEndian24(framed.ptr[1 .. 4], body_len);
        copyMem(framed.ptr+4, msg.ptr, msg.length);

        return framed;
    }

    /**
    * Queue the payload of a received record.
    *
    * Handshake payloads are appended verbatim. A ChangeCipherSpec must be
    * the single byte 1 and is queued as a zero-length pseudo handshake
    * message of type HANDSHAKE_CCS.
    *
    * Throws: DecodingError on a malformed ChangeCipherSpec or on any other
    *         record type.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong)
    {
        if (record_type == HANDSHAKE)
        {
            m_queue ~= record[];
            return;
        }

        if (record_type != CHANGE_CIPHER_SPEC)
            throw new DecodingError("Unknown message type " ~ record_type.to!string ~ " in handshake processing");

        if (record.length != 1 || record[0] != 1)
            throw new DecodingError("Invalid ChangeCipherSpec");

        // Pretend it's a regular handshake message of zero length
        const(ubyte)[] ccs_hs = [ HANDSHAKE_CCS, 0, 0, 0 ];
        m_queue.insert(ccs_hs);
    }

    /**
    * Remove and return the message at the front of the queue.
    * Returns: (HANDSHAKE_NONE, empty) while the header or the body is
    *          still incomplete.
    */
    override NextRecord getNextRecord(bool)
    {
        // Need the full 4-byte header before the length can be read
        if (m_queue.length < 4)
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        const size_t length = make_uint(0, m_queue[1], m_queue[2], m_queue[3]);

        // Body not fully received yet
        if (m_queue.length < length + 4)
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        HandshakeType kind = cast(HandshakeType)(m_queue[0]);

        Vector!ubyte contents = Vector!ubyte(m_queue.ptr[4 .. 4 + length]);

        // Keep whatever follows the extracted message for the next call
        Vector!ubyte remainder = Vector!ubyte(m_queue.ptr[4 + length .. m_queue.length]);
        m_queue = remainder;

        return NextRecord(kind, contents.move());
    }

private:
    /// Bytes received but not yet returned as a message.
    Vector!ubyte m_queue;
    /// Record writer supplied at construction.
    InternalDataWriter m_send_hs;
}
151 
/**
* Handshake IO for datagram-based handshakes
*
* Outgoing messages are recorded per flight so that a flight can be sent
* again from timeoutCheck, and are split into fragments when the
* unfragmented encoding does not fit the configured MTU. Incoming
* fragments are collected per handshake message sequence number until the
* message is complete.
*/
package final class DatagramHandshakeIO : HandshakeIO
{
    /// Sink for outgoing records: receives (epoch, record type, record payload).
    alias InternalDataWriter = void delegate(ushort, ubyte, const ref Vector!ubyte);
private:
    // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
    const ulong INITIAL_TIMEOUT = 1*1000;
    const ulong MAXIMUM_TIMEOUT = 60*1000;

    /**
    * Current UTC time in milliseconds since 1970-01-01.
    *
    * NOTE(review): despite the name this reads the system wall clock
    * (Clock.currTime), so a clock adjustment can shift the measured
    * interval in timeoutCheck - confirm whether a monotonic source is
    * wanted here.
    */
    static ulong steadyClockMs() {
        return (Clock.currTime(UTC()).stdTime - SysTime(DateTime(1970, 1, 1, 0, 0, 0), UTC()).stdTime)/10_000;
    }

public:
    /**
    * Params:
    *  seq = source of the current write epoch
    *  writer = delegate invoked for every record this object emits
    *  mtu = size limit used to decide when a message must be fragmented
    */
    this(ConnectionSequenceNumbers seq, InternalDataWriter writer, ushort mtu)
    {
        m_seqs = seq;
        m_flights.length = 1; // start with one (empty) outgoing flight
        m_send_hs = writer;
        m_mtu = mtu;
    }

    /// The first record of a datagram handshake is stamped DTLS v1.0.
    override TLSProtocolVersion initialRecordVersion() const
    {
        return TLSProtocolVersion(TLSProtocolVersion.DTLS_V10);
    }

    /**
    * Retransmit the previous flight if the retransmission timer expired.
    *
    * After a retransmission the timeout is doubled, capped at
    * MAXIMUM_TIMEOUT.
    *
    * Returns: true if a flight was retransmitted, false otherwise.
    */
    override bool timeoutCheck() {
        import std.range : empty;
        // The flight currently being assembled is the LAST entry of m_flights;
        // if it already holds messages we are in the middle of sending it.
        if (m_last_write == 0 || (m_flights.length > 1 && !m_flights[$-1].empty))
        {
            /*
            If we haven't written anything yet obviously no timeout.
            Also no timeout possible if we are mid-flight,
            */
            return false;
        }
        const ulong ms_since_write = steadyClockMs() - m_last_write;
        if (ms_since_write < m_next_timeout)
            return false;
        Vector!ushort flight;
        if (m_flights.length == 1)
            flight = m_flights[0]; // lost initial client hello
        else
            flight = m_flights[m_flights.length - 2];
        assert(flight.length > 0, "Nonempty flight to retransmit");
        ushort epoch = m_flight_data[flight[0]].epoch;
        foreach(msg_seq; flight)
        {
            auto msg = m_flight_data[msg_seq];
            if (msg.epoch != epoch)
            {
                // Epoch gap: insert the CCS
                Vector!ubyte ccs;
                ccs ~= 1;
                m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
            }
            sendMessage(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
            epoch = msg.epoch;
        }
        m_next_timeout = min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
        return true;
    }

    /**
    * Serialize and transmit a handshake message, remembering it in the
    * current flight so it can be retransmitted later.
    *
    * Sending also restarts the retransmission timer at INITIAL_TIMEOUT.
    *
    * Returns: the unfragmented encoding of the message, or an empty
    *          vector for a ChangeCipherSpec.
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        Vector!ubyte msg_bits = msg.serialize();
        ushort epoch = m_seqs.currentWriteEpoch();
        HandshakeType msg_type = msg.type();

        if (msg_type == HANDSHAKE_CCS)
        {
            m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        // Note: not saving CCS, instead we know it was there due to change in epoch
        m_flights[$-1].pushBack(m_out_message_seq);
        m_flight_data[m_out_message_seq] = MessageInfo(epoch, msg_type, msg_bits);

        m_out_message_seq += 1;
        m_last_write = steadyClockMs();
        m_next_timeout = INITIAL_TIMEOUT;

        return sendMessage(cast(ushort)(m_out_message_seq - 1), epoch, msg_type, msg_bits);
    }

    /**
    * Hand one handshake message to the record writer, as a single record
    * if it fits m_mtu together with the record header, otherwise as a
    * series of fragments.
    *
    * Params:
    *  msg_seq = handshake message sequence number
    *  epoch = epoch passed through to the record writer
    *  msg_type = handshake message type
    *  msg_bits = serialized message body
    *
    * Returns: the unfragmented encoding, whether or not the message was
    *          actually sent in fragments.
    */
    Vector!ubyte sendMessage(ushort msg_seq, ushort epoch, HandshakeType msg_type, const ref Vector!ubyte msg_bits)
    {
        const Vector!ubyte no_fragment = formatWSeq(msg_bits, msg_type, msg_seq);

        if (no_fragment.length + DTLS_HEADER_SIZE <= m_mtu)
            m_send_hs(epoch, HANDSHAKE, no_fragment);
        else
        {
            const size_t parts = splitForMtu(m_mtu, msg_bits.length);

            const size_t parts_size = (msg_bits.length + parts) / parts;

            size_t frag_offset = 0;

            while (frag_offset != msg_bits.length)
            {
                const size_t frag_len = min(msg_bits.length - frag_offset, parts_size);
                // Offset and total length are 24-bit fields on the wire, so
                // they are passed at full width rather than cut to 16 bits.
                auto frag = formatFragment(cast(const(ubyte)*)&msg_bits[frag_offset],
                    frag_len,
                    frag_offset,
                    msg_bits.length,
                    msg_type,
                    msg_seq);

                m_send_hs(epoch, HANDSHAKE, frag);

                frag_offset += frag_len;
            }
        }

        return (cast()no_fragment).move;
    }

    /**
    * Encode msg as a single unfragmented handshake message carrying the
    * sequence number of the most recently extracted incoming message
    * (m_in_message_seq - 1).
    */
    override const(Vector!ubyte) format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        return formatWSeq(msg, type, cast(ushort) (m_in_message_seq - 1));
    }

    /**
    * Process the payload of a received record.
    *
    * A ChangeCipherSpec only records the epoch it arrived in. A handshake
    * record may hold several fragments back to back; each is added to the
    * reassembly state of its message sequence number. Fragments for
    * sequence numbers already consumed are ignored.
    *
    * Throws: DecodingError if a fragment claims more bytes than the
    *         record contains, or if addFragment rejects the fragment.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong record_sequence)
    {
        // The epoch is the top 16 bits of the 64-bit record sequence number
        const ushort epoch = cast(ushort)(record_sequence >> 48);

        if (record_type == CHANGE_CIPHER_SPEC)
        {
            // TODO: check this is otherwise empty
            m_ccs_epochs ~= epoch;
            return;
        }

        __gshared immutable size_t DTLS_HANDSHAKE_HEADER_LEN = 12;

        const(ubyte)* record_bits = record.ptr;
        size_t record_size = record.length;

        while (record_size)
        {
            if (record_size < DTLS_HANDSHAKE_HEADER_LEN)
                return; // completely bogus? at least degenerate/weird

            // Header layout: type(1) length(3) message_seq(2)
            //                fragment_offset(3) fragment_length(3)
            const ubyte msg_type = record_bits[0];
            const size_t msg_len = loadBigEndian24((&record_bits[1])[0 .. 3]);
            const ushort message_seq = loadBigEndian!ushort(&record_bits[4], 0);
            const size_t fragment_offset = loadBigEndian24((&record_bits[6])[0 .. 3]);
            const size_t fragment_length = loadBigEndian24((&record_bits[9])[0 .. 3]);

            const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;

            if (record_size < total_size)
                throw new DecodingError("Bad lengths in DTLS header");

            if (message_seq >= m_in_message_seq)
            {
                // Create the reassembly slot only if none holding data exists
                // yet. Resetting it for every fragment would throw away the
                // fragments collected so far and a fragmented message could
                // never complete.
                auto existing = m_messages.get(message_seq, HandshakeReassembly.init);
                if (existing is HandshakeReassembly.init)
                    m_messages[message_seq] = HandshakeReassembly.init;

                m_messages[message_seq].addFragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
                                                    fragment_length,
                                                    fragment_offset,
                                                    epoch,
                                                    msg_type,
                                                    msg_len);
            }
            else {
                // TODO: detect retransmitted flight
            }
            record_bits += total_size;
            record_size -= total_size;
        }
    }

    /**
    * Return the next in-order handshake message if it is complete.
    *
    * Calling this also closes the current outgoing flight when it is
    * non-empty, by opening a new empty one.
    *
    * Params:
    *  expecting_ccs = when true, only report whether a ChangeCipherSpec
    *                  was seen (type HANDSHAKE_CCS) and never return a
    *                  handshake message
    *
    * Returns: (HANDSHAKE_NONE, empty) if nothing is available.
    */
    override NextRecord getNextRecord(bool expecting_ccs)
    {
        // Expecting a message means the last flight is concluded

        if (!m_flights[$-1].empty)
            m_flights.pushBack(Vector!ushort());

        if (expecting_ccs)
        {
            if (m_messages.length > 0)
            {
                // NOTE(review): this indexes sequence number 0 specifically;
                // confirm the behaviour when messages are stored but none of
                // them has sequence number 0.
                const ushort current_epoch = m_messages[cast(ushort)0].epoch();

                if (m_ccs_epochs.canFind(current_epoch))
                    return NextRecord(HANDSHAKE_CCS, Vector!ubyte());
            }

            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
        }

        auto rec = m_messages.get(m_in_message_seq, HandshakeReassembly.init);

        if (rec is HandshakeReassembly.init || !rec.complete())
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        m_in_message_seq += 1;

        return rec.message();
    }

private:

    /**
    * Build one handshake fragment: the 12-byte header followed by
    * frag_len bytes copied from fragment.
    *
    * Params:
    *  fragment = first byte of the fragment data
    *  frag_len = number of bytes in this fragment
    *  frag_offset = offset of this fragment within the whole message
    *  msg_len = length of the whole message
    *  type = handshake message type
    *  msg_sequence = handshake message sequence number
    *
    * The three length/offset values are written as 24-bit big-endian
    * fields; only their low 24 bits are stored.
    */
    Vector!ubyte formatFragment(const(ubyte)* fragment,
                                 size_t frag_len,
                                 size_t frag_offset,
                                 size_t msg_len,
                                 HandshakeType type,
                                 ushort msg_sequence) const
    {
        Vector!ubyte send_buf = Vector!ubyte(12 + frag_len);

        send_buf[0] = type;

        storeBigEndian24((&send_buf[1])[0 .. 3], msg_len);

        storeBigEndian(msg_sequence, &send_buf[4]);

        storeBigEndian24((&send_buf[6])[0 .. 3], frag_offset);
        storeBigEndian24((&send_buf[9])[0 .. 3], frag_len);

        copyMem(&send_buf[12], fragment, frag_len);

        return send_buf;
    }

    /// Encode msg as a single fragment covering the whole message (offset 0).
    Vector!ubyte formatWSeq(const ref Vector!ubyte msg,
                     HandshakeType type,
                     ushort msg_sequence) const
    {
        return formatFragment(msg.ptr, msg.length, 0, msg.length, type, msg_sequence);
    }


    /**
    * Collects the fragments of one incoming handshake message until the
    * whole body is available.
    */
    struct HandshakeReassembly
    {
    public:
        /**
        * Add one received fragment.
        *
        * The first fragment fixes the epoch, type and total length of the
        * message; later fragments must repeat the same values. Fragments
        * arriving after the message is complete are ignored.
        *
        * Throws: DecodingError if the header values disagree with earlier
        *         fragments or the fragment extends past the message end.
        */
        void addFragment(const(ubyte)* fragment,
                            size_t fragment_length,
                            size_t fragment_offset,
                            ushort epoch,
                            ubyte msg_type,
                            size_t msg_length)
        {
            if (complete())
                return; // already have entire message, ignore this

            // First fragment seen for this message: remember its header values
            if (m_msg_type == HANDSHAKE_NONE)
            {
                m_epoch = epoch;
                m_msg_type = msg_type;
                m_msg_length = msg_length;
            }

            if (msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
                throw new DecodingError("Inconsistent values in fragmented  DTLS handshake header");

            if (fragment_offset > m_msg_length)
                throw new DecodingError("Fragment offset past end of message");

            if (fragment_offset + fragment_length > m_msg_length)
                throw new DecodingError("Fragment overlaps past end of message");

            if (fragment_offset == 0 && fragment_length == m_msg_length)
            {
                // Fragment covers the whole message: take it as is
                m_fragments.clear();
                m_message[] = fragment[0 .. fragment_length];
            }
            else
            {
                /*
                * FIXME. This is a pretty lame way to do defragmentation, huge
                * overhead with a tree node per ubyte.
                *
                * Also should confirm that all overlaps have no changes,
                * otherwise we expose ourselves to the classic fingerprinting
                * and IDS evasion attacks on IP fragmentation.
                */
                foreach (size_t i; 0 .. fragment_length)
                    m_fragments[fragment_offset+i] = cast(ubyte)fragment[i];

                // Every byte offset has been seen once the map holds
                // m_msg_length entries; assemble the final message.
                if (m_fragments.length == m_msg_length)
                {
                    m_message.resize(m_msg_length);
                    foreach (size_t i; 0 .. m_msg_length)
                        m_message[i] = m_fragments[i];
                    m_fragments.clear();
                }
            }
        }

        /// True once a fragment has been seen and the full body is assembled.
        bool complete() const
        {
            return (m_msg_type != HANDSHAKE_NONE && m_message.length == m_msg_length);
        }

        /// Epoch recorded from the first fragment.
        ushort epoch() const { return m_epoch; }

        /**
        * Returns: the message type together with a copy of the body.
        * Throws: InternalError if the message is not complete.
        */
        NextRecord message() const
        {
            if (!complete())
                throw new InternalError("DatagramHandshakeIO - message not complete");

            return NextRecord(cast(HandshakeType)(m_msg_type), m_message.dup);
        }

        private:
            ubyte m_msg_type = HANDSHAKE_NONE; // HANDSHAKE_NONE until the first fragment
            size_t m_msg_length = 0;           // total body length from the header
            ushort m_epoch = 0;                // epoch of the first fragment

            HashMapRef!(size_t, ubyte) m_fragments; // byte offset -> byte, partial data
            Array!ubyte m_message;                  // assembled body
    }

    /// An outgoing message retained for possible retransmission.
    struct MessageInfo
    {
        this(ushort e, HandshakeType mt, const ref Vector!ubyte msg)
        {
            epoch = e;
            msg_type = mt;
            msg_bits = msg.dupr;
        }

        ushort epoch = 0xFFFF;
        HandshakeType msg_type = HANDSHAKE_NONE;
        Array!ubyte msg_bits;
    }


    ConnectionSequenceNumbers m_seqs;                 // provides the current write epoch
    HashMap!(ushort, HandshakeReassembly) m_messages; // incoming, by message sequence
    ushort[] m_ccs_epochs;                            // epochs in which a CCS was received
    Vector!( Array!ushort ) m_flights;                // outgoing message sequences per flight
    HashMap!(ushort, MessageInfo ) m_flight_data;     // outgoing messages by sequence

    ulong m_last_write = 0;   // steadyClockMs() at the last send(); 0 = nothing sent yet
    ulong m_next_timeout = 0; // current retransmission timeout in milliseconds

    ushort m_in_message_seq = 0;  // next incoming message sequence to hand out
    ushort m_out_message_seq = 0; // next outgoing message sequence to assign
    InternalDataWriter m_send_hs; // record writer supplied at construction
    ushort m_mtu;                 // size limit supplied at construction
}
501 
502 
503 private:
504 
/**
* Read a 24-bit big-endian unsigned integer.
* Params: q = the three bytes, most significant first
* Returns: the decoded value
*/
size_t loadBigEndian24(in ubyte[3] q)
{
    // Widen to 32 bits by supplying a zero top byte
    const size_t decoded = make_uint(0, q[0], q[1], q[2]);
    return decoded;
}
509 
/**
* Write the low 24 bits of val as a big-endian integer.
* Params:
*  output = destination, most significant byte first
*  val = value to store; it is narrowed to 32 bits and byte 0 of that
*        word is not written
*/
void storeBigEndian24(ref ubyte[3] output, size_t val)
{
    const uint word = cast(uint) val;

    // Bytes 1..3 of the 32-bit word are the three low-order bytes
    output[0] = get_byte!uint(1, word);
    output[1] = get_byte!uint(2, word);
    output[2] = get_byte!uint(3, word);
}
516 
/**
* Number of parts a handshake message body should be divided into so that
* every part, sized as (msg_size + parts) / parts, fits into one datagram
* of at most mtu bytes together with the DTLS record and handshake headers.
*
* Params:
*  mtu = maximum datagram size in bytes
*  msg_size = length of the handshake message body in bytes
*
* Returns: the part count, always at least 1. If mtu leaves no room for
*          payload next to the headers, one payload byte per part is
*          assumed so the result stays finite.
*/
size_t splitForMtu(size_t mtu, size_t msg_size)
{
    __gshared immutable size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

    // Room left for message bytes in each datagram once the headers are
    // accounted for; clamped to 1 to avoid dividing by zero on a tiny MTU.
    const size_t payload = (mtu > DTLS_HEADERS_SIZE) ? mtu - DTLS_HEADERS_SIZE : 1;

    // One more part than floor(msg_size / payload) guarantees
    // parts * payload > msg_size, hence (msg_size + parts) / parts <= payload.
    return msg_size / payload + 1;
}