1 /**
2 * TLS Handshake Serialization
3 * 
4 * Copyright:
5 * (C) 2012 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.handshake_io;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 package:
16 
17 import botan.tls.magic;
18 import botan.tls.version_;
19 import botan.utils.loadstor;
20 import botan.tls.messages;
21 import botan.tls.record;
22 import botan.tls.seq_numbers;
23 import botan.utils.exceptn;
24 import std.algorithm : count, min;
25 import botan.utils.types;
26 import memutils.hashmap;
27 import std.typecons : Tuple;
28 import std.datetime;
29 
/**
* One handshake message pulled out of the incoming record stream.
*/
struct NextRecord
{
    /// Handshake message type; HANDSHAKE_NONE when no message was available.
    HandshakeType type;

    /// Message body with the handshake header already stripped.
    Vector!ubyte data;
}
35 
/**
* Transport-specific framing of handshake messages.
*
* Implementations decide how handshake messages are wrapped, written and
* recovered from incoming records (this module provides a stream variant
* and a datagram variant).
*/
interface HandshakeIO
{
public:
    /// Protocol version to stamp on the first record of a handshake.
    TLSProtocolVersion initialRecordVersion() const;

    /**
    * Serialize and transmit a handshake message.
    * Returns: the framed bytes produced for the message (empty for a
    * ChangeCipherSpec in the implementations of this module).
    */
    Vector!ubyte send(in HandshakeMessage msg);

    /// Wrap an already-serialized handshake body with its handshake header.
    Vector!ubyte format(const ref Vector!ubyte handshake_msg, HandshakeType handshake_type) const;

    /// Give the implementation a chance to retransmit; true if anything was resent.
    bool timeoutCheck();

    /**
    * Feed one received record of the given type.
    * sequence_number is the record's sequence number (the datagram variant
    * reads the epoch from its upper 16 bits).
    */
    void addRecord(const ref Vector!ubyte record, RecordType type, ulong sequence_number);

    /**
    * Fetch the next complete handshake message, if any.
    * expecting_ccs tells the implementation the caller is waiting for a
    * ChangeCipherSpec.
    * Returns: a NextRecord whose type is HANDSHAKE_NONE and whose data is
    * empty when no message is currently available.
    */
    NextRecord getNextRecord(bool expecting_ccs);
}
57 
/**
* Handshake IO for stream-based handshakes.
*
* Outgoing messages get a 4-byte header (1-byte type, 24-bit big-endian
* length) and are handed to the writer delegate. Incoming handshake bytes
* are buffered until a complete message can be cut off the front.
*/
package final class StreamHandshakeIO : HandshakeIO
{
public:
    /// Receives (record type, payload) for every record to be written.
    alias InternalDataWriter = void delegate(ubyte, const ref Vector!ubyte);

    this(InternalDataWriter writer)
    {
        m_queue.reserve(128);
        m_send_hs = writer;
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return cast(TLSProtocolVersion)TLSProtocolVersion.TLS_V10;
    }

    /// A reliable stream never needs retransmission.
    override bool timeoutCheck()
    {
        return false;
    }

    override Vector!ubyte send(in HandshakeMessage msg)
    {
        const Vector!ubyte serialized = msg.serialize();
        HandshakeType kind = msg.type();

        if (kind != HANDSHAKE_CCS)
        {
            Vector!ubyte framed = format(serialized, kind);
            m_send_hs(HANDSHAKE, framed);
            return framed.move();
        }

        // ChangeCipherSpec goes out as its own record type and is
        // not included in the handshake hashes, hence the empty result.
        m_send_hs(CHANGE_CIPHER_SPEC, serialized);
        return Vector!ubyte();
    }

    override Vector!ubyte format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        const size_t payload_len = msg.length;
        Vector!ubyte framed = Vector!ubyte(4 + payload_len);

        framed[0] = type;                                 // message type
        storeBigEndian24(framed.ptr[1 .. 4], payload_len); // 24-bit body length
        copyMem(framed.ptr + 4, msg.ptr, payload_len);    // body

        return framed.move();
    }

    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong)
    {
        if (record_type == CHANGE_CIPHER_SPEC)
        {
            // The only valid ChangeCipherSpec body is the single byte 0x01.
            if (record.length != 1 || record[0] != 1)
                throw new DecodingError("Invalid ChangeCipherSpec");

            // Queue it as a zero-length pseudo handshake message so that
            // getNextRecord reports it in order with real messages.
            const(ubyte)[] pseudo_message = [ HANDSHAKE_CCS, 0, 0, 0 ];
            m_queue.insert(pseudo_message);
            return;
        }

        if (record_type != HANDSHAKE)
            throw new DecodingError("Unknown message type " ~ record_type.to!string ~ " in handshake processing");

        m_queue ~= record[];
    }

    override NextRecord getNextRecord(bool)
    {
        // The 4-byte header must be present before the length can be read.
        if (m_queue.length < 4)
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        const size_t body_len = make_uint(0, m_queue[1], m_queue[2], m_queue[3]);
        const size_t msg_end = 4 + body_len;

        // Body not fully buffered yet.
        if (m_queue.length < msg_end)
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        HandshakeType kind = cast(HandshakeType)(m_queue[0]);
        Vector!ubyte contents = Vector!ubyte(m_queue.ptr[4 .. msg_end]);
        Vector!ubyte remainder = Vector!ubyte(m_queue.ptr[msg_end .. m_queue.length]);
        m_queue = remainder;

        return NextRecord(kind, contents.move());
    }

private:
    Vector!ubyte m_queue;          // buffered, not yet consumed handshake bytes
    InternalDataWriter m_send_hs;  // sink for outgoing records
}
152 
/**
* Handshake IO for datagram-based (DTLS) handshakes.
*
* Outgoing handshake messages are grouped into flights and remembered
* together with their epoch, so a whole flight can be retransmitted by
* timeoutCheck. Incoming handshake records may carry several, possibly
* fragmented, messages; they are reassembled per message sequence number
* and handed out strictly in sequence order by getNextRecord.
*/
package final class DatagramHandshakeIO : HandshakeIO
{
    alias InternalDataWriter = void delegate(ushort, ubyte, const ref Vector!ubyte);
private:
    // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
    enum ulong INITIAL_TIMEOUT = 1*1000;
    enum ulong MAXIMUM_TIMEOUT = 60*1000;

    // DTLS handshake header: type(1) + length(3) + message_seq(2)
    // + fragment_offset(3) + fragment_length(3).
    enum size_t DTLS_HANDSHAKE_HEADER_LEN = 12;

    /**
    * Milliseconds read from a monotonic clock.
    *
    * A wall-clock source can jump backwards, and `now - m_last_write` is an
    * unsigned subtraction, so such a jump would wrap around and trigger
    * spurious retransmissions.
    */
    static ulong steadyClockMs()
    {
        import core.time : MonoTime;
        return cast(ulong)(MonoTime.currTime - MonoTime.zero).total!"msecs";
    }

public:
    /**
    * Params:
    *   seq    = sequence number state; supplies the current write epoch
    *   writer = delegate receiving (epoch, record type, payload) per record
    *   mtu    = largest datagram the writer may emit
    */
    this(ConnectionSequenceNumbers seq, InternalDataWriter writer, ushort mtu)
    {
        m_seqs = seq;
        m_flights.length = 1;
        m_send_hs = writer;
        m_mtu = mtu;
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return TLSProtocolVersion(TLSProtocolVersion.DTLS_V10);
    }

    /**
    * Retransmit our last complete flight if the retransmission timer expired.
    * Returns: true if a flight was resent.
    */
    override bool timeoutCheck()
    {
        import std.range : empty;

        /*
        If we haven't written anything yet obviously no timeout.
        Also no timeout possible if we are mid-flight, i.e. the flight we are
        currently building (the last one) already holds messages.
        */
        if (m_last_write == 0 || (m_flights.length > 1 && !m_flights[$-1].empty))
            return false;

        const ulong ms_since_write = steadyClockMs() - m_last_write;
        if (ms_since_write < m_next_timeout)
            return false;

        Vector!ushort flight;
        if (m_flights.length == 1)
            flight = m_flights[0]; // lost initial client hello
        else
            flight = m_flights[m_flights.length - 2];

        assert(flight.length > 0, "Nonempty flight to retransmit");

        ushort epoch = m_flight_data[flight[0]].epoch;
        foreach (msg_seq; flight)
        {
            auto msg = m_flight_data[msg_seq];
            if (msg.epoch != epoch)
            {
                // Epoch gap: insert the CCS
                Vector!ubyte ccs;
                ccs ~= 1;
                m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
            }
            sendMessage(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
            epoch = msg.epoch;
        }

        // RFC 6347 4.2.4: retransmitting restarts the timer, with the
        // timeout doubled (capped at MAXIMUM_TIMEOUT).
        m_last_write = steadyClockMs();
        m_next_timeout = min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
        return true;
    }

    /**
    * Send a handshake message and record it in the current flight.
    * Returns: the unfragmented encoding of the message, or an empty vector
    * for a ChangeCipherSpec (not included in handshake hashes).
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        Vector!ubyte msg_bits = msg.serialize();
        ushort epoch = m_seqs.currentWriteEpoch();
        HandshakeType msg_type = msg.type();

        if (msg_type == HANDSHAKE_CCS)
        {
            m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        // Note: not saving CCS, instead we know it was there due to change in epoch
        m_flights[$-1].pushBack(m_out_message_seq);
        m_flight_data[m_out_message_seq] = MessageInfo(epoch, msg_type, msg_bits);

        m_out_message_seq += 1;
        m_last_write = steadyClockMs();
        m_next_timeout = INITIAL_TIMEOUT;

        return sendMessage(cast(ushort)(m_out_message_seq - 1), epoch, msg_type, msg_bits);
    }

    /**
    * Write one handshake message, fragmenting it if it does not fit m_mtu.
    * Returns: the unfragmented encoding of the message.
    * Throws: InvalidArgument if m_mtu cannot even hold the record headers.
    */
    Vector!ubyte sendMessage(ushort msg_seq, ushort epoch, HandshakeType msg_type, const ref Vector!ubyte msg_bits)
    {
        import std.algorithm : min;

        Vector!ubyte no_fragment = formatWSeq(msg_bits, msg_type, msg_seq);

        if (no_fragment.length + DTLS_HEADER_SIZE <= m_mtu)
        {
            m_send_hs(epoch, HANDSHAKE, no_fragment);
            return no_fragment.move();
        }

        // Conservative allowance for protected records (IV, MAC, padding)
        // once a cipher is in use; epoch 0 records are sent unprotected.
        const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
        const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;

        // Without room for payload the loop below would never progress.
        if (m_mtu <= header_overhead + ciphersuite_overhead)
            throw new InvalidArgument("DTLS MTU is too small to send headers");

        // Largest fragment body that keeps each datagram within the MTU.
        const size_t max_frag_len = m_mtu - (header_overhead + ciphersuite_overhead);

        size_t frag_offset = 0;
        while (frag_offset != msg_bits.length)
        {
            const size_t frag_len = min(msg_bits.length - frag_offset, max_frag_len);
            auto frag = formatFragment(msg_bits.ptr + frag_offset,
                                       frag_len,
                                       frag_offset,
                                       msg_bits.length,
                                       msg_type,
                                       msg_seq);

            m_send_hs(epoch, HANDSHAKE, frag);

            frag_offset += frag_len;
        }

        return no_fragment.move();
    }

    override Vector!ubyte format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        return formatWSeq(msg, type, cast(ushort) (m_in_message_seq - 1));
    }

    /**
    * Feed a received record. A ChangeCipherSpec only records its epoch.
    * Handshake records are split into their (possibly fragmented) messages,
    * and each fragment is added to the reassembly for its sequence number.
    * Throws: DecodingError on inconsistent lengths.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong record_sequence)
    {
        // The epoch lives in the upper 16 bits of the record sequence number.
        const ushort epoch = cast(ushort)(record_sequence >> 48);

        if (record_type == CHANGE_CIPHER_SPEC)
        {
            // TODO: check this is otherwise empty
            m_ccs_epochs ~= epoch;
            return;
        }

        const(ubyte)* record_bits = record.ptr;
        size_t record_size = record.length;

        while (record_size)
        {
            if (record_size < DTLS_HANDSHAKE_HEADER_LEN)
                return; // completely bogus? at least degenerate/weird

            const ubyte msg_type = record_bits[0];
            const size_t msg_len = loadBigEndian24((&record_bits[1])[0 .. 3]);
            const ushort message_seq = loadBigEndian!ushort(&record_bits[4], 0);
            const size_t fragment_offset = loadBigEndian24((&record_bits[6])[0 .. 3]);
            const size_t fragment_length = loadBigEndian24((&record_bits[9])[0 .. 3]);

            const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;

            if (record_size < total_size)
                throw new DecodingError("Bad lengths in DTLS header");

            if (message_seq >= m_in_message_seq)
            {
                // Create reassembly state only for a new sequence number;
                // resetting an existing entry would throw away fragments
                // (or a complete message) received earlier.
                if (m_messages.get(message_seq, HandshakeReassembly.init) is HandshakeReassembly.init)
                    m_messages[message_seq] = HandshakeReassembly.init;

                m_messages[message_seq].addFragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
                                                    fragment_length,
                                                    fragment_offset,
                                                    epoch,
                                                    msg_type,
                                                    msg_len);
            }
            else
            {
                // TODO: detect retransmitted flight
            }

            record_bits += total_size;
            record_size -= total_size;
        }
    }

    /**
    * Return the next in-order complete message. When expecting_ccs is set,
    * report HANDSHAKE_CCS if a ChangeCipherSpec was seen for the epoch of
    * the entry stored under sequence number 0. Otherwise return
    * HANDSHAKE_NONE with empty data.
    */
    override NextRecord getNextRecord(bool expecting_ccs)
    {
        import std.algorithm : canFind;

        // Expecting a message means the last flight is concluded
        if (!m_flights[$-1].empty)
            m_flights.pushBack(Vector!ushort());

        if (expecting_ccs)
        {
            if (m_messages.length > 0)
            {
                // NOTE(review): reads the entry for sequence number 0 specifically;
                // confirm that entry always exists whenever m_messages is non-empty.
                const ushort current_epoch = m_messages[cast(ushort)0].epoch();

                if (canFind(m_ccs_epochs, current_epoch))
                    return NextRecord(HANDSHAKE_CCS, Vector!ubyte());
            }

            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
        }

        auto rec = m_messages.get(m_in_message_seq, HandshakeReassembly.init);

        if (rec is HandshakeReassembly.init || !rec.complete())
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        m_in_message_seq += 1;

        return rec.message();
    }

private:

    /**
    * Encode one DTLS handshake fragment: 12-byte header followed by
    * frag_len bytes read from fragment. Offsets and lengths are written as
    * 24-bit fields, so they are taken as size_t rather than truncated to 16 bits.
    */
    Vector!ubyte formatFragment(const(ubyte)* fragment,
                                size_t frag_len,
                                size_t frag_offset,
                                size_t msg_len,
                                HandshakeType type,
                                ushort msg_sequence) const
    {
        Vector!ubyte send_buf = Vector!ubyte(DTLS_HANDSHAKE_HEADER_LEN + frag_len);

        send_buf[0] = type;

        storeBigEndian24((&send_buf[1])[0 .. 3], msg_len);

        storeBigEndian(msg_sequence, &send_buf[4]);

        storeBigEndian24((&send_buf[6])[0 .. 3], frag_offset);
        storeBigEndian24((&send_buf[9])[0 .. 3], frag_len);

        copyMem(&send_buf[12], fragment, frag_len);

        return send_buf;
    }

    /// Encode a whole message as a single fragment (offset 0, full length).
    Vector!ubyte formatWSeq(const ref Vector!ubyte msg,
                            HandshakeType type,
                            ushort msg_sequence) const
    {
        return formatFragment(msg.ptr, msg.length, 0, msg.length, type, msg_sequence);
    }


    /**
    * Collects the fragments of one handshake message until all bytes
    * 0 .. msg_length are present.
    */
    struct HandshakeReassembly
    {
    public:
        /**
        * Add one fragment. The first fragment fixes type, length and epoch;
        * later fragments must agree with them.
        * Throws: DecodingError on inconsistent headers or out-of-range fragments.
        */
        void addFragment(const(ubyte)* fragment,
                         size_t fragment_length,
                         size_t fragment_offset,
                         ushort epoch,
                         ubyte msg_type,
                         size_t msg_length)
        {
            if (complete())
                return; // already have entire message, ignore this

            if (m_msg_type == HANDSHAKE_NONE)
            {
                m_epoch = epoch;
                m_msg_type = msg_type;
                m_msg_length = msg_length;
            }

            if (msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
                throw new DecodingError("Inconsistent values in fragmented  DTLS handshake header");

            if (fragment_offset > m_msg_length)
                throw new DecodingError("Fragment offset past end of message");

            if (fragment_offset + fragment_length > m_msg_length)
                throw new DecodingError("Fragment overlaps past end of message");

            if (fragment_offset == 0 && fragment_length == m_msg_length)
            {
                // Whole message in one fragment. m_message is still empty here
                // (complete() was false), so size it before copying.
                m_fragments.clear();
                m_message.resize(fragment_length);
                foreach (size_t i; 0 .. fragment_length)
                    m_message[i] = fragment[i];
            }
            else
            {
                /*
                * FIXME. This is a pretty lame way to do defragmentation, huge
                * overhead with a tree node per ubyte.
                *
                * Also should confirm that all overlaps have no changes,
                * otherwise we expose ourselves to the classic fingerprinting
                * and IDS evasion attacks on IP fragmentation.
                */
                foreach (size_t i; 0 .. fragment_length)
                    m_fragments[fragment_offset+i] = cast(ubyte)fragment[i];

                if (m_fragments.length == m_msg_length)
                {
                    m_message.resize(m_msg_length);
                    foreach (size_t i; 0 .. m_msg_length)
                        m_message[i] = m_fragments[i];
                    m_fragments.clear();
                }
            }
        }

        /// True once a type is known and every byte of the message is present.
        bool complete() const
        {
            return (m_msg_type != HANDSHAKE_NONE && m_message.length == m_msg_length);
        }

        ushort epoch() const { return m_epoch; }

        /// Throws: InternalError if the message is not yet complete.
        NextRecord message() const
        {
            if (!complete())
                throw new InternalError("DatagramHandshakeIO - message not complete");

            return NextRecord(cast(HandshakeType)(m_msg_type), m_message.dup);
        }

        private:
            ubyte m_msg_type = HANDSHAKE_NONE;
            size_t m_msg_length = 0;
            ushort m_epoch = 0;

            HashMapRef!(size_t, ubyte) m_fragments; // offset -> byte for partial messages
            Array!ubyte m_message;                  // filled once the message is complete
    }

    /// A sent message kept for retransmission.
    struct MessageInfo
    {
        this(ushort e, HandshakeType mt, const ref Vector!ubyte msg)
        {
            epoch = e;
            msg_type = mt;
            msg_bits = msg.dupr;
        }

        ushort epoch = 0xFFFF;
        HandshakeType msg_type = HANDSHAKE_NONE;
        Array!ubyte msg_bits;
    };


    ConnectionSequenceNumbers m_seqs;
    HashMap!(ushort, HandshakeReassembly) m_messages; // incoming, by message_seq
    ushort[] m_ccs_epochs;                            // epochs in which a CCS arrived
    Vector!( Array!ushort ) m_flights;                // outgoing message_seqs per flight
    HashMap!(ushort, MessageInfo ) m_flight_data;     // outgoing messages, by message_seq

    ulong m_last_write = 0;   // steadyClockMs() of the last (re)transmission; 0 = never
    ulong m_next_timeout = 0; // current retransmission timeout in ms

    ushort m_in_message_seq = 0;
    ushort m_out_message_seq = 0;
    InternalDataWriter m_send_hs;
    ushort m_mtu;
}
503 
504 
505 private:
506 
/// Decode three bytes as an unsigned 24-bit big-endian integer.
size_t loadBigEndian24(in ubyte[3] q)
{
    size_t value = 0;
    foreach (octet; q)
        value = (value << 8) | octet;
    return value;
}
511 
/// Write the low 24 bits of val into output, most significant byte first.
void storeBigEndian24(ref ubyte[3] output, size_t val)
{
    const uint word = cast(uint) val;
    foreach (idx, ref octet; output)
        octet = cast(ubyte)(word >> (8 * (2 - idx)));
}
518 
/**
* Number of fragments to cut a msg_size-byte handshake message into so each
* fragment fits in an mtu-byte datagram.
*
* Fragments are assumed to be sized (msg_size + parts) / parts bytes, which
* equals msg_size / parts + 1. Each one also carries DTLS_HEADERS_SIZE bytes
* of record + handshake headers. The result is therefore the smallest parts
* with msg_size / parts + 1 + DTLS_HEADERS_SIZE <= mtu, which works out to
* msg_size / payload + 1 with payload = mtu - DTLS_HEADERS_SIZE.
*
* (The previous version compared the *number* of parts plus the header size
* against mtu, so fragments could exceed the MTU.)
*
* If mtu cannot even hold the headers, no split fits. In that case one
* payload byte per fragment is assumed so callers still make progress.
*/
size_t splitForMtu(size_t mtu, size_t msg_size)
{
    enum size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

    const size_t payload = (mtu > DTLS_HEADERS_SIZE) ? mtu - DTLS_HEADERS_SIZE : 1;

    return msg_size / payload + 1;
}