/**
* TLS Handshake Serialization
*
* Copyright:
* (C) 2012 Jack Lloyd
* (C) 2014-2015 Etienne Cimon
*
* License:
* Botan is released under the Simplified BSD License (see LICENSE.md)
*/
module botan.tls.handshake_io;

import botan.constants;
static if (BOTAN_HAS_TLS):
package:

import botan.tls.magic;
import botan.tls.version_;
import botan.utils.loadstor;
import botan.tls.messages;
import botan.tls.record;
import botan.tls.seq_numbers;
import botan.utils.exceptn;
import std.algorithm : count, min;
import botan.utils.types;
import memutils.hashmap;
import std.typecons : Tuple;
import std.datetime;

/**
* One handshake message returned by `HandshakeIO.getNextRecord`:
* its type and its body, with the transport-specific handshake header removed.
*/
struct NextRecord
{
    HandshakeType type;
    Vector!ubyte data;
}

/**
* Handshake IO Interface
*
* Frames outgoing handshake messages for the record layer and turns incoming
* handshake records back into whole messages. Implemented in this module for
* stream (TLS) and datagram (DTLS) transports.
*/
interface HandshakeIO
{
public:
    /// Protocol version this transport reports as its initial record version.
    abstract TLSProtocolVersion initialRecordVersion() const;

    /**
    * Serializes `msg` and hands it to the record layer.
    *
    * Returns: the framed message bytes, or an empty vector for
    *          ChangeCipherSpec, which is not included in handshake hashes.
    */
    abstract Vector!ubyte send(in HandshakeMessage msg);

    /// Prefixes `handshake_msg` with this transport's handshake header.
    abstract const(Vector!ubyte) format(const ref Vector!ubyte handshake_msg, HandshakeType handshake_type) const;

    /**
    * Returns: true if the implementation retransmitted data because its
    *          retransmission timer expired, false otherwise.
    */
    abstract bool timeoutCheck();

    /**
    * Feeds one record received from the peer.
    *
    * Params:
    *  record = record payload
    *  type = record content type (the implementations in this module handle
    *         HANDSHAKE and CHANGE_CIPHER_SPEC)
    *  sequence_number = record sequence number
    */
    abstract void addRecord(const ref Vector!ubyte record, RecordType type, ulong sequence_number);

    /**
    * Returns the next complete handshake message, if any.
    *
    * Returns (HANDSHAKE_NONE, Vector!( )()) if no message currently available
    */
    abstract NextRecord getNextRecord(bool expecting_ccs);
}

/**
* Handshake IO for stream-based handshakes
*
* Uses the 4-byte TLS handshake header (1 byte type, 24-bit body length).
* The transport is reliable, so nothing is ever retransmitted.
*/
package final class StreamHandshakeIO : HandshakeIO
{
public:
    /// Record-layer callback: (record content type, payload).
    alias InternalDataWriter = void delegate(ubyte, const ref Vector!ubyte);

    this(InternalDataWriter writer)
    {
        m_send_hs = writer;
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return cast(TLSProtocolVersion)TLSProtocolVersion.TLS_V10;
    }

    /// Stream transports never retransmit, so this always returns false.
    override bool timeoutCheck() { return false; }

    /**
    * Sends `msg`. ChangeCipherSpec goes out as its own record type with no
    * handshake header; everything else is framed by `format` and sent as
    * a HANDSHAKE record.
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        const Vector!ubyte msg_bits = msg.serialize();

        if (msg.type() == HANDSHAKE_CCS)
        {
            m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        Vector!ubyte buf = format(msg_bits, msg.type()).dup;
        m_send_hs(HANDSHAKE, buf);
        return buf.move();
    }

    /// Frames `msg` as [type (1 byte)][body length (24-bit big-endian)][body].
    override const(Vector!ubyte) format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        Vector!ubyte send_buf = Vector!ubyte(4 + msg.length);

        const size_t buf_size = msg.length;

        send_buf[0] = type;

        storeBigEndian24(send_buf.ptr[1 .. 4], buf_size);

        copyMem(send_buf.ptr+4, msg.ptr, msg.length);

        return send_buf;
    }

    /**
    * Appends a HANDSHAKE record's bytes to the input queue.
    *
    * A ChangeCipherSpec record (which must be exactly the single byte 1) is
    * queued as a zero-length HANDSHAKE_CCS pseudo-message, so it is delivered
    * in order with the surrounding handshake messages.
    *
    * Throws: DecodingError for a malformed ChangeCipherSpec or any other
    *         record type.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong)
    {
        if (record_type == HANDSHAKE)
        {
            m_queue ~= record[];
        }
        else if (record_type == CHANGE_CIPHER_SPEC)
        {
            if (record.length != 1 || record[0] != 1)
                throw new DecodingError("Invalid ChangeCipherSpec");

            // Pretend it's a regular handshake message of zero length
            const(ubyte)[] ccs_hs = [ HANDSHAKE_CCS, 0, 0, 0 ];
            m_queue.insert(ccs_hs);
        }
        else
            throw new DecodingError("Unknown message type " ~ record_type.to!string ~ " in handshake processing");
    }

    /**
    * Removes and returns the first queued message once its header and its
    * whole body have arrived; otherwise returns (HANDSHAKE_NONE, empty).
    */
    override NextRecord getNextRecord(bool)
    {
        if (m_queue.length >= 4)
        {
            const size_t length = make_uint(0, m_queue[1], m_queue[2], m_queue[3]);
            if (m_queue.length >= length + 4)
            {
                HandshakeType type = cast(HandshakeType)(m_queue[0]);

                Vector!ubyte contents = Vector!ubyte(m_queue.ptr[4 .. 4 + length]);
                Vector!ubyte ret = Vector!ubyte(m_queue.ptr[4 + length .. m_queue.length]);
                m_queue = ret;

                return NextRecord(type, contents.move());
            }
        }

        return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
    }

private:
    // Received handshake bytes not yet returned by getNextRecord.
    Vector!ubyte m_queue;
    InternalDataWriter m_send_hs;
}

/**
* Handshake IO for datagram-based handshakes
*
* Uses the 12-byte DTLS handshake header (type, length, message_seq,
* fragment_offset, fragment_length). It fragments outgoing messages that do
* not fit in the MTU and reassembles incoming fragments. It also retransmits
* the last outgoing flight when its timer expires.
*/
package final class DatagramHandshakeIO : HandshakeIO
{
    /// Record-layer callback: (epoch, record content type, payload).
    alias InternalDataWriter = void delegate(ushort, ubyte, const ref Vector!ubyte);
private:
    // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
    const ulong INITIAL_TIMEOUT = 1*1000;
    const ulong MAXIMUM_TIMEOUT = 60*1000;

    /**
    * Milliseconds read from a monotonic clock. Only differences between two
    * readings are meaningful. Unlike wall-clock time, this does not jump
    * backwards when the system clock is adjusted, which would otherwise
    * underflow `steadyClockMs() - m_last_write`.
    */
    static ulong steadyClockMs()
    {
        import core.time : MonoTime;
        return cast(ulong) (MonoTime.currTime - MonoTime.zero).total!"msecs";
    }

public:
    this(ConnectionSequenceNumbers seq, InternalDataWriter writer, ushort mtu)
    {
        m_seqs = seq;
        m_flights.length = 1;
        m_send_hs = writer;
        m_mtu = mtu;
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return TLSProtocolVersion(TLSProtocolVersion.DTLS_V10);
    }

    /**
    * If the retransmission timer has expired, resends the most recent
    * non-empty outgoing flight and doubles the timer (capped at
    * MAXIMUM_TIMEOUT). A ChangeCipherSpec is re-inserted wherever the epoch
    * changes between stored messages.
    *
    * Returns: true if a flight was retransmitted.
    */
    override bool timeoutCheck()
    {
        import std.range : empty;

        // The newest flight is m_flights[$-1]. getNextRecord() appends a fresh
        // empty flight once we start waiting for the peer. A non-empty newest
        // flight therefore means we are still writing it.
        if (m_last_write == 0 || (m_flights.length > 1 && !m_flights[$-1].empty))
        {
            /*
            If we haven't written anything yet obviously no timeout.
            Also no timeout possible if we are mid-flight,
            */
            return false;
        }

        const ulong ms_since_write = steadyClockMs() - m_last_write;

        if (ms_since_write < m_next_timeout)
            return false;

        Vector!ushort flight;
        if (m_flights.length == 1)
            flight = m_flights[0]; // lost initial client hello
        else
            flight = m_flights[m_flights.length - 2];

        assert(flight.length > 0, "Nonempty flight to retransmit");

        ushort epoch = m_flight_data[flight[0]].epoch;

        foreach (msg_seq; flight)
        {
            auto msg = m_flight_data[msg_seq];

            if (msg.epoch != epoch)
            {
                // Epoch gap: insert the CCS
                Vector!ubyte ccs;
                ccs ~= 1;
                m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
            }

            sendMessage(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
            epoch = msg.epoch;
        }

        m_next_timeout = min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
        return true;
    }

    /**
    * Sends `msg` in the current write epoch and records it in the current
    * flight for retransmission. ChangeCipherSpec is sent directly and not
    * recorded; on retransmit, its position is implied by the epoch change
    * between recorded messages.
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        Vector!ubyte msg_bits = msg.serialize();
        ushort epoch = m_seqs.currentWriteEpoch();
        HandshakeType msg_type = msg.type();

        if (msg_type == HANDSHAKE_CCS)
        {
            m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        // Note: not saving CCS, instead we know it was there due to change in epoch
        m_flights[$-1].pushBack(m_out_message_seq);
        m_flight_data[m_out_message_seq] = MessageInfo(epoch, msg_type, msg_bits);

        m_out_message_seq += 1;
        m_last_write = steadyClockMs();
        m_next_timeout = INITIAL_TIMEOUT;

        return sendMessage(cast(ushort)(m_out_message_seq - 1), epoch, msg_type, msg_bits);
    }

    /**
    * Writes handshake message `msg_seq` in `epoch`. If the unfragmented
    * message plus record header exceeds the MTU, it is sent as several
    * fragments sized via splitForMtu.
    *
    * Returns: the unfragmented framing of the message, regardless of how
    *          it was actually put on the wire.
    */
    Vector!ubyte sendMessage(ushort msg_seq, ushort epoch, HandshakeType msg_type, const ref Vector!ubyte msg_bits)
    {
        const Vector!ubyte no_fragment = formatWSeq(msg_bits, msg_type, msg_seq);

        if (no_fragment.length + DTLS_HEADER_SIZE <= m_mtu)
            m_send_hs(epoch, HANDSHAKE, no_fragment);
        else
        {
            const size_t parts = splitForMtu(m_mtu, msg_bits.length);

            // Equals floor(len / parts) + 1; splitForMtu chooses `parts` so this fits the MTU.
            const size_t parts_size = (msg_bits.length + parts) / parts;

            size_t frag_offset = 0;

            while (frag_offset != msg_bits.length)
            {
                const size_t frag_len = min(msg_bits.length - frag_offset, parts_size);
                auto frag = formatFragment(cast(const(ubyte)*)&msg_bits[frag_offset],
                                           frag_len,
                                           frag_offset,
                                           msg_bits.length,
                                           msg_type,
                                           msg_seq);

                m_send_hs(epoch, HANDSHAKE, frag);

                frag_offset += frag_len;
            }
        }

        return (cast()no_fragment).move;
    }

    /**
    * Frames `msg` as one unfragmented DTLS handshake message. It uses the
    * sequence number of the last message returned by getNextRecord
    * (m_in_message_seq - 1).
    */
    override const(Vector!ubyte) format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        return formatWSeq(msg, type, cast(ushort) (m_in_message_seq - 1));
    }

    /**
    * Parses every DTLS handshake fragment in `record` and feeds it to the
    * reassembly buffer for its message_seq. Fragments of messages already
    * delivered (message_seq < m_in_message_seq) are ignored.
    * ChangeCipherSpec records only note the epoch they arrived in.
    *
    * Params:
    *  record_sequence = record sequence number; its top 16 bits are the epoch.
    *
    * Throws: DecodingError if a fragment header claims more bytes than remain
    *         in the record, or if fragments of one message disagree.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong record_sequence)
    {
        const ushort epoch = cast(ushort)(record_sequence >> 48);

        if (record_type == CHANGE_CIPHER_SPEC)
        {
            // TODO: check this is otherwise empty
            m_ccs_epochs ~= epoch;
            return;
        }

        __gshared immutable size_t DTLS_HANDSHAKE_HEADER_LEN = 12;

        const(ubyte)* record_bits = record.ptr;
        size_t record_size = record.length;

        while (record_size)
        {
            if (record_size < DTLS_HANDSHAKE_HEADER_LEN)
                return; // completely bogus? at least degenerate/weird

            const ubyte msg_type = record_bits[0];
            const size_t msg_len = loadBigEndian24((&record_bits[1])[0 .. 3]);
            const ushort message_seq = loadBigEndian!ushort(&record_bits[4], 0);
            const size_t fragment_offset = loadBigEndian24((&record_bits[6])[0 .. 3]);
            const size_t fragment_length = loadBigEndian24((&record_bits[9])[0 .. 3]);

            const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;

            if (record_size < total_size)
                throw new DecodingError("Bad lengths in DTLS header");

            if (message_seq >= m_in_message_seq)
            {
                // Create the reassembly buffer only the first time this message_seq
                // is seen. Resetting it on every fragment would throw away the
                // fragments already collected, so a fragmented message could never
                // complete.
                if (message_seq !in m_messages)
                    m_messages[message_seq] = HandshakeReassembly.init;

                m_messages[message_seq].addFragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
                                                    fragment_length,
                                                    fragment_offset,
                                                    epoch,
                                                    msg_type,
                                                    msg_len);
            }
            else
            {
                // TODO: detect retransmitted flight
            }

            record_bits += total_size;
            record_size -= total_size;
        }
    }

    /**
    * Returns the next in-order handshake message once it is fully reassembled,
    * or (HANDSHAKE_NONE, empty) if it is not available yet.
    *
    * Calling this concludes the current outgoing flight. If that flight is
    * non-empty, a new empty flight is started.
    *
    * Params:
    *  expecting_ccs = when true, only report (HANDSHAKE_CCS, empty) if a
    *                  ChangeCipherSpec was received in the epoch of the
    *                  stored message_seq 0 entry.
    */
    override NextRecord getNextRecord(bool expecting_ccs)
    {
        // Expecting a message means the last flight is concluded
        if (!m_flights[$-1].empty)
            m_flights.pushBack(Vector!ushort());

        if (expecting_ccs)
        {
            if (m_messages.length > 0)
            {
                // NOTE(review): this assumes an entry for message_seq 0 exists;
                // indexing a missing key may fail — confirm this always holds here.
                const ushort current_epoch = m_messages[cast(ushort)0].epoch();

                if (m_ccs_epochs.canFind(current_epoch))
                    return NextRecord(HANDSHAKE_CCS, Vector!ubyte());
            }

            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
        }

        auto rec = m_messages.get(m_in_message_seq, HandshakeReassembly.init);

        if (rec is HandshakeReassembly.init || !rec.complete())
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        m_in_message_seq += 1;

        return rec.message();
    }

private:

    /**
    * Builds one DTLS handshake fragment: the 12-byte header followed by
    * `frag_len` bytes copied from `fragment`. The length and offset fields are
    * 24 bits wide in the header, so they are taken as size_t rather than ushort.
    * This keeps messages above 64 KiB from getting truncated header values.
    */
    Vector!ubyte formatFragment(const(ubyte)* fragment,
                                size_t frag_len,
                                size_t frag_offset,
                                size_t msg_len,
                                HandshakeType type,
                                ushort msg_sequence) const
    {
        Vector!ubyte send_buf = Vector!ubyte(12 + frag_len);

        send_buf[0] = type;

        storeBigEndian24((&send_buf[1])[0 .. 3], msg_len);

        storeBigEndian(msg_sequence, &send_buf[4]);

        storeBigEndian24((&send_buf[6])[0 .. 3], frag_offset);
        storeBigEndian24((&send_buf[9])[0 .. 3], frag_len);

        copyMem(&send_buf[12], fragment, frag_len);

        return send_buf;
    }

    /// Frames the whole of `msg` as a single fragment (offset 0, length = msg.length).
    Vector!ubyte formatWSeq(const ref Vector!ubyte msg,
                            HandshakeType type,
                            ushort msg_sequence) const
    {
        return formatFragment(msg.ptr, msg.length, 0, msg.length, type, msg_sequence);
    }


    /**
    * Collects the fragments of one incoming DTLS handshake message until all
    * `msg_length` body bytes have been received.
    */
    struct HandshakeReassembly
    {
    public:
        /**
        * Adds one fragment. It is ignored if the message is already complete.
        * The first fragment fixes the message type, length and epoch.
        *
        * Throws: DecodingError if the fragment disagrees with the type, length
        *         or epoch already recorded, or reaches past the message end.
        */
        void addFragment(const(ubyte)* fragment,
                         size_t fragment_length,
                         size_t fragment_offset,
                         ushort epoch,
                         ubyte msg_type,
                         size_t msg_length)
        {
            if (complete())
                return; // already have entire message, ignore this

            if (m_msg_type == HANDSHAKE_NONE)
            {
                m_epoch = epoch;
                m_msg_type = msg_type;
                m_msg_length = msg_length;
            }

            if (msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
                throw new DecodingError("Inconsistent values in fragmented DTLS handshake header");

            if (fragment_offset > m_msg_length)
                throw new DecodingError("Fragment offset past end of message");

            if (fragment_offset + fragment_length > m_msg_length)
                throw new DecodingError("Fragment overlaps past end of message");

            if (fragment_offset == 0 && fragment_length == m_msg_length)
            {
                // Whole message in one fragment. Size the destination first so
                // the copy does not target an empty buffer.
                m_fragments.clear();
                m_message.resize(fragment_length);
                foreach (size_t i; 0 .. fragment_length)
                    m_message[i] = fragment[i];
            }
            else
            {
                /*
                * FIXME. This is a pretty lame way to do defragmentation, huge
                * overhead with a tree node per ubyte.
                *
                * Also should confirm that all overlaps have no changes,
                * otherwise we expose ourselves to the classic fingerprinting
                * and IDS evasion attacks on IP fragmentation.
                */
                foreach (size_t i; 0 .. fragment_length)
                    m_fragments[fragment_offset+i] = cast(ubyte)fragment[i];

                if (m_fragments.length == m_msg_length)
                {
                    m_message.resize(m_msg_length);
                    foreach (size_t i; 0 .. m_msg_length)
                        m_message[i] = m_fragments[i];
                    m_fragments.clear();
                }
            }
        }

        /// True once a first fragment was seen and all body bytes are present.
        bool complete() const
        {
            return (m_msg_type != HANDSHAKE_NONE && m_message.length == m_msg_length);
        }

        /// Epoch in which this message's fragments arrived.
        ushort epoch() const { return m_epoch; }

        /**
        * Returns: the reassembled message.
        * Throws: InternalError if the message is not complete yet.
        */
        NextRecord message() const
        {
            if (!complete())
                throw new InternalError("DatagramHandshakeIO - message not complete");

            return NextRecord(cast(HandshakeType)(m_msg_type), m_message.dup);
        }

    private:
        ubyte m_msg_type = HANDSHAKE_NONE;
        size_t m_msg_length = 0;
        ushort m_epoch = 0;

        // Body bytes received so far, keyed by offset (multi-fragment path only).
        HashMapRef!(size_t, ubyte) m_fragments;
        Array!ubyte m_message;
    }

    /// A sent handshake message, kept so its flight can be retransmitted.
    struct MessageInfo
    {
        this(ushort e, HandshakeType mt, const ref Vector!ubyte msg)
        {
            epoch = e;
            msg_type = mt;
            msg_bits = msg.dupr;
        }

        ushort epoch = 0xFFFF;
        HandshakeType msg_type = HANDSHAKE_NONE;
        Array!ubyte msg_bits;
    }


    ConnectionSequenceNumbers m_seqs;
    // Incoming messages being reassembled, keyed by DTLS message_seq.
    HashMap!(ushort, HandshakeReassembly) m_messages;
    // Epochs in which a ChangeCipherSpec record was received.
    ushort[] m_ccs_epochs;
    // Outgoing flights; each lists the message_seq values sent in that flight.
    Vector!( Array!ushort ) m_flights;
    // Sent messages keyed by message_seq, kept for retransmission.
    HashMap!(ushort, MessageInfo ) m_flight_data;

    // steadyClockMs() reading at the last send(); 0 means nothing was sent yet.
    ulong m_last_write = 0;
    // Current retransmission timeout, in milliseconds.
    ulong m_next_timeout = 0;

    // message_seq of the next incoming message to deliver.
    ushort m_in_message_seq = 0;
    // message_seq assigned to the next outgoing message.
    ushort m_out_message_seq = 0;
    InternalDataWriter m_send_hs;
    ushort m_mtu;
}


private:

/// Reads a 24-bit big-endian unsigned integer from `q`.
size_t loadBigEndian24(in ubyte[3] q)
{
    return make_uint(0, q[0], q[1], q[2]);
}

/**
* Inverse of loadBigEndian24: writes bytes 1..3 of `val` (as a uint) into
* `output`. This assumes get_byte index 0 is the most significant byte, so this
* is the low 24 bits in big-endian order.
*/
void storeBigEndian24(ref ubyte[3] output, size_t val)
{
    output[0] = get_byte!uint(1, cast(uint) val);
    output[1] = get_byte!uint(2, cast(uint) val);
    output[2] = get_byte!uint(3, cast(uint) val);
}

/**
* Returns how many parts a `msg_size`-byte handshake body is split into, so
* that each fragment plus DTLS_HEADERS_SIZE bytes of headers fits in `mtu`.
*
* sendMessage uses fragments of `(msg_size + parts) / parts` bytes, which is
* `floor(msg_size / parts) + 1`. Let `payload = mtu - DTLS_HEADERS_SIZE` and
* choose `parts = floor(msg_size / payload) + 1`. Then
* `msg_size / parts < payload`, so each fragment is at most `payload` bytes.
*
* If `mtu` cannot even hold the headers, one-byte fragments are used.
*/
size_t splitForMtu(size_t mtu, size_t msg_size)
{
    __gshared immutable size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

    // Largest fragment body that still fits in one MTU-sized record.
    const size_t payload = (mtu > DTLS_HEADERS_SIZE) ? mtu - DTLS_HEADERS_SIZE : 1;

    return (msg_size + payload) / payload;
}