/**
* TLS Handshake Serialization
* 
* Copyright:
* (C) 2012 Jack Lloyd
* (C) 2014-2015 Etienne Cimon
*
* License:
* Botan is released under the Simplified BSD License (see LICENSE.md)
*/
module botan.tls.handshake_io;

import botan.constants;
static if (BOTAN_HAS_TLS):
package:

import botan.tls.magic;
import botan.tls.version_;
import botan.utils.loadstor;
import botan.tls.messages;
import botan.tls.record;
import botan.tls.seq_numbers;
import botan.utils.exceptn;
import std.algorithm : count, min;
import botan.utils.types;
import memutils.hashmap;
import std.typecons : Tuple;
import std.datetime;

/**
* A handshake message extracted from the incoming record stream: its
* handshake type and its body (the handshake header is stripped).
*/
struct NextRecord
{
    HandshakeType type;
    Vector!ubyte data;
}

/**
* Handshake IO Interface
*
* Abstracts how handshake messages are framed, written and reassembled,
* with one implementation for stream transports and one for datagrams.
*/
interface HandshakeIO
{
public:
    /// Protocol version to use for the record layer of the initial record.
    abstract TLSProtocolVersion initialRecordVersion() const;

    /**
    * Serialize and write msg. The implementations in this module return the
    * formatted handshake bytes, or an empty vector for ChangeCipherSpec
    * (which is not included in handshake hashes).
    */
    abstract Vector!ubyte send(in HandshakeMessage msg);

    /// Prefix handshake_msg with the transport-specific handshake header.
    abstract Vector!ubyte format(const ref Vector!ubyte handshake_msg, HandshakeType handshake_type) const;

    /**
    * Retransmission hook. The datagram implementation returns true only after
    * it has retransmitted a flight; the stream implementation always returns false.
    */
    abstract bool timeoutCheck();

    /// Feed one received record of the given type into the handshake reassembler.
    abstract void addRecord(const ref Vector!ubyte record, RecordType type, ulong sequence_number);

    /**
    * Returns NextRecord(HANDSHAKE_NONE, Vector!ubyte()) if no message is currently available
    */
    abstract NextRecord getNextRecord(bool expecting_ccs);
}

/**
* Handshake IO for stream-based handshakes
*
* Incoming handshake bytes are accumulated in a queue and split into
* messages using the 4-byte header (1 byte type, 3 byte big-endian length).
*/
package final class StreamHandshakeIO : HandshakeIO
{
public:
    alias InternalDataWriter = void delegate(ubyte, const ref Vector!ubyte);

    this(InternalDataWriter writer)
    {
        m_send_hs = writer;
        m_queue.reserve(128);
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return cast(TLSProtocolVersion)TLSProtocolVersion.TLS_V10;
    }

    /// Stream transports never retransmit, so this never times out.
    override bool timeoutCheck() { return false; }

    override Vector!ubyte send(in HandshakeMessage msg)
    {
        const Vector!ubyte msg_bits = msg.serialize();

        if (msg.type() == HANDSHAKE_CCS)
        {
            m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        Vector!ubyte buf = format(msg_bits, msg.type());
        m_send_hs(HANDSHAKE, buf);
        return buf.move();
    }

    /// Returns type (1 byte) || length (3 bytes, big-endian) || msg.
    override Vector!ubyte format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        Vector!ubyte send_buf = Vector!ubyte(4 + msg.length);

        const size_t buf_size = msg.length;

        send_buf[0] = type;

        storeBigEndian24(send_buf.ptr[1 .. 4], buf_size);

        copyMem(send_buf.ptr+4, msg.ptr, msg.length);

        return send_buf.move();
    }

    /**
    * Append a handshake record to the queue. A ChangeCipherSpec record must be
    * the single byte 1 and is queued as a zero-length HANDSHAKE_CCS message.
    * Throws DecodingError for a malformed CCS or any other record type.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong)
    {
        if (record_type == HANDSHAKE)
        {
            m_queue ~= record[];
        }
        else if (record_type == CHANGE_CIPHER_SPEC)
        {
            if (record.length != 1 || record[0] != 1)
                throw new DecodingError("Invalid ChangeCipherSpec");

            // Pretend it's a regular handshake message of zero length
            const(ubyte)[] ccs_hs = [ HANDSHAKE_CCS, 0, 0, 0 ];
            m_queue.insert(ccs_hs);
        }
        else
            throw new DecodingError("Unknown message type " ~ record_type.to!string ~ " in handshake processing");
    }

    /// Pops the next complete message from the queue, or HANDSHAKE_NONE if none is complete yet.
    override NextRecord getNextRecord(bool)
    {
        if (m_queue.length >= 4)
        {
            const size_t length = make_uint(0, m_queue[1], m_queue[2], m_queue[3]);
            if (m_queue.length >= length + 4)
            {
                HandshakeType type = cast(HandshakeType)(m_queue[0]);

                Vector!ubyte contents = Vector!ubyte(m_queue.ptr[4 .. 4 + length]);
                // Keep only the bytes following this message.
                Vector!ubyte ret = Vector!ubyte(m_queue.ptr[4 + length .. m_queue.length]);
                m_queue = ret;

                return NextRecord(type, contents.move());
            }
        }

        return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
    }

private:
    Vector!ubyte m_queue;
    InternalDataWriter m_send_hs;
}

/**
* Handshake IO for datagram-based handshakes
*
* Outgoing messages are remembered per flight so that the most recent
* completed flight can be retransmitted on timeout; incoming messages are
* reassembled from (possibly fragmented) DTLS handshake records.
*/
package final class DatagramHandshakeIO : HandshakeIO
{
    alias InternalDataWriter = void delegate(ushort, ubyte, const ref Vector!ubyte);
private:
    // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
    const ulong INITIAL_TIMEOUT = 1*1000;
    const ulong MAXIMUM_TIMEOUT = 60*1000;

    /**
    * Milliseconds read from a monotonic clock. A wall clock may jump
    * backwards, which would underflow the elapsed-time subtraction in
    * timeoutCheck and trigger spurious retransmissions.
    */
    static ulong steadyClockMs() {
        return cast(ulong) (MonoTime.currTime - MonoTime.zero).total!"msecs";
    }

public:
    this(ConnectionSequenceNumbers seq, InternalDataWriter writer, ushort mtu)
    {
        m_seqs = seq;
        m_flights.length = 1;
        m_send_hs = writer;
        m_mtu = mtu;
    }

    override TLSProtocolVersion initialRecordVersion() const
    {
        return TLSProtocolVersion(TLSProtocolVersion.DTLS_V10);
    }

    /**
    * Retransmit the last completed flight if the timeout has expired.
    * Returns true if a retransmission was performed. The timeout doubles
    * after each retransmission, capped at MAXIMUM_TIMEOUT.
    */
    override bool timeoutCheck() {
        import std.range : empty;
        if (m_last_write == 0 || (m_flights.length > 1 && !m_flights[$-1].empty))
        {
            /*
            If we haven't written anything yet obviously no timeout.
            Also no timeout possible if we are mid-flight: send() appends to
            the last entry of m_flights, so a non-empty last flight is the
            one currently being written.
            */
            return false;
        }
        const ulong ms_since_write = steadyClockMs() - m_last_write;
        if (ms_since_write < m_next_timeout)
            return false;
        Vector!ushort flight;
        if (m_flights.length == 1)
            flight = m_flights[0]; // lost initial client hello
        else
            flight = m_flights[m_flights.length - 2];
        assert(flight.length > 0, "Nonempty flight to retransmit");
        ushort epoch = m_flight_data[flight[0]].epoch;
        foreach(msg_seq; flight)
        {
            auto msg = m_flight_data[msg_seq];
            if (msg.epoch != epoch)
            {
                // Epoch gap: insert the CCS
                Vector!ubyte ccs;
                ccs ~= 1;
                m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
            }
            sendMessage(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
            epoch = msg.epoch;
        }
        m_next_timeout = min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
        return true;
    }

    /**
    * Write msg in the current write epoch. Non-CCS messages are recorded in
    * the current flight (for retransmission) and consume one outgoing
    * message sequence number; the retransmission timer is reset.
    */
    override Vector!ubyte send(in HandshakeMessage msg)
    {
        Vector!ubyte msg_bits = msg.serialize();
        ushort epoch = m_seqs.currentWriteEpoch();
        HandshakeType msg_type = msg.type();

        if (msg_type == HANDSHAKE_CCS)
        {
            m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
            return Vector!ubyte(); // not included in handshake hashes
        }

        // Note: not saving CCS, instead we know it was there due to change in epoch
        m_flights[$-1].pushBack(m_out_message_seq);
        m_flight_data[m_out_message_seq] = MessageInfo(epoch, msg_type, msg_bits);

        m_out_message_seq += 1;
        m_last_write = steadyClockMs();
        m_next_timeout = INITIAL_TIMEOUT;

        return sendMessage(cast(ushort)(m_out_message_seq - 1), epoch, msg_type, msg_bits);
    }

    /**
    * Write one handshake message, splitting it into fragments if the
    * unfragmented form plus DTLS_HEADER_SIZE exceeds the MTU. Always returns
    * the unfragmented encoding (the form used for the handshake hash).
    */
    Vector!ubyte sendMessage(ushort msg_seq, ushort epoch, HandshakeType msg_type, const ref Vector!ubyte msg_bits)
    {
        const Vector!ubyte no_fragment = formatWSeq(msg_bits, msg_type, msg_seq);

        if (no_fragment.length + DTLS_HEADER_SIZE <= m_mtu)
            m_send_hs(epoch, HANDSHAKE, no_fragment);
        else
        {
            const size_t parts = splitForMtu(m_mtu, msg_bits.length);

            const size_t parts_size = (msg_bits.length + parts) / parts;

            size_t frag_offset = 0;

            while (frag_offset != msg_bits.length)
            {
                import std.algorithm : min;
                const size_t frag_len = min(msg_bits.length - frag_offset, parts_size);
                // Offsets and lengths are 24-bit fields on the wire; pass them
                // through as size_t rather than truncating to 16 bits.
                auto frag = formatFragment(msg_bits.ptr + frag_offset,
                                           frag_len,
                                           frag_offset,
                                           msg_bits.length,
                                           msg_type,
                                           msg_seq);

                m_send_hs(epoch, HANDSHAKE, frag);

                frag_offset += frag_len;
            }
        }

        return (cast()no_fragment).move;
    }

    /// Formats msg unfragmented, tagged with the sequence number of the last message consumed by getNextRecord.
    override Vector!ubyte format(const ref Vector!ubyte msg, HandshakeType type) const
    {
        return formatWSeq(msg, type, cast(ushort) (m_in_message_seq - 1));
    }

    /**
    * Parse a received record. CCS records only note their epoch. Handshake
    * records may hold several fragments; each fragment for a sequence number
    * not yet consumed is added to that message's reassembly buffer.
    * Throws DecodingError if a fragment extends past the end of the record.
    */
    override void addRecord(const ref Vector!ubyte record, RecordType record_type, ulong record_sequence)
    {
        // The upper 16 bits of the DTLS record sequence number are the epoch.
        const ushort epoch = cast(ushort)(record_sequence >> 48);

        if (record_type == CHANGE_CIPHER_SPEC)
        {
            // TODO: check this is otherwise empty
            m_ccs_epochs ~= epoch;
            return;
        }

        __gshared immutable size_t DTLS_HANDSHAKE_HEADER_LEN = 12;

        const(ubyte)* record_bits = record.ptr;
        size_t record_size = record.length;

        while (record_size)
        {
            if (record_size < DTLS_HANDSHAKE_HEADER_LEN)
                return; // completely bogus? at least degenerate/weird

            const ubyte msg_type = record_bits[0];
            const size_t msg_len = loadBigEndian24((&record_bits[1])[0 .. 3]);
            const ushort message_seq = loadBigEndian!ushort(&record_bits[4], 0);
            const size_t fragment_offset = loadBigEndian24((&record_bits[6])[0 .. 3]);
            const size_t fragment_length = loadBigEndian24((&record_bits[9])[0 .. 3]);

            const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;

            if (record_size < total_size)
                throw new DecodingError("Bad lengths in DTLS header");

            if (message_seq >= m_in_message_seq)
            {
                // Create the reassembly buffer only the first time this
                // sequence number is seen; resetting it on every fragment would
                // throw away previously received fragments, so a fragmented
                // message could never complete.
                if (m_messages.get(message_seq, HandshakeReassembly.init) is HandshakeReassembly.init)
                    m_messages[message_seq] = HandshakeReassembly.init;
                m_messages[message_seq].addFragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
                                                    fragment_length,
                                                    fragment_offset,
                                                    epoch,
                                                    msg_type,
                                                    msg_len);
            }
            else {
                // TODO: detect retransmitted flight
            }
            record_bits += total_size;
            record_size -= total_size;
        }
    }

    /**
    * Returns the next in-order complete message, or HANDSHAKE_NONE.
    * When expecting_ccs is set, returns HANDSHAKE_CCS if a CCS was received
    * in the epoch of the reassembly entry for sequence number 0.
    */
    override NextRecord getNextRecord(bool expecting_ccs)
    {
        // Expecting a message means the last flight is concluded

        if (!m_flights[$-1].empty)
            m_flights.pushBack(Vector!ushort());

        if (expecting_ccs)
        {
            if (m_messages.length > 0)
            {
                const ushort current_epoch = m_messages[cast(ushort)0].epoch();

                if (m_ccs_epochs.canFind(current_epoch))
                    return NextRecord(HANDSHAKE_CCS, Vector!ubyte());
            }

            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());
        }

        auto rec = m_messages.get(m_in_message_seq, HandshakeReassembly.init);

        if (rec is HandshakeReassembly.init || !rec.complete())
            return NextRecord(HANDSHAKE_NONE, Vector!ubyte());

        m_in_message_seq += 1;

        return rec.message();
    }

private:

    /**
    * Build a 12-byte DTLS handshake header followed by frag_len bytes of
    * fragment: type(1) || msg_len(3) || msg_sequence(2) || frag_offset(3) || frag_len(3).
    * frag_offset and msg_len are size_t because they are 24-bit wire fields.
    */
    Vector!ubyte formatFragment(const(ubyte)* fragment,
                                size_t frag_len,
                                size_t frag_offset,
                                size_t msg_len,
                                HandshakeType type,
                                ushort msg_sequence) const
    {
        Vector!ubyte send_buf = Vector!ubyte(12 + frag_len);

        send_buf[0] = type;

        storeBigEndian24((&send_buf[1])[0 .. 3], msg_len);

        storeBigEndian(msg_sequence, &send_buf[4]);

        storeBigEndian24((&send_buf[6])[0 .. 3], frag_offset);
        storeBigEndian24((&send_buf[9])[0 .. 3], frag_len);

        // Use pointer arithmetic: &send_buf[12] is out of range when the body
        // is empty (frag_len == 0).
        if (frag_len > 0)
            copyMem(send_buf.ptr + 12, fragment, frag_len);

        return send_buf;
    }

    /// Formats msg as a single fragment covering the whole message.
    Vector!ubyte formatWSeq(const ref Vector!ubyte msg,
                            HandshakeType type,
                            ushort msg_sequence) const
    {
        return formatFragment(msg.ptr, msg.length, 0, msg.length, type, msg_sequence);
    }


    /**
    * Reassembly state for one incoming handshake message sequence number.
    */
    struct HandshakeReassembly
    {
    public:
        /**
        * Add a fragment. Ignored once the message is complete. Throws
        * DecodingError if type, length or epoch disagree with earlier
        * fragments, or if the fragment extends past the message length.
        */
        void addFragment(const(ubyte)* fragment,
                         size_t fragment_length,
                         size_t fragment_offset,
                         ushort epoch,
                         ubyte msg_type,
                         size_t msg_length)
        {
            if (complete())
                return; // already have entire message, ignore this

            if (m_msg_type == HANDSHAKE_NONE)
            {
                m_epoch = epoch;
                m_msg_type = msg_type;
                m_msg_length = msg_length;
            }

            if (msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
                throw new DecodingError("Inconsistent values in fragmented DTLS handshake header");

            if (fragment_offset > m_msg_length)
                throw new DecodingError("Fragment offset past end of message");

            if (fragment_offset + fragment_length > m_msg_length)
                throw new DecodingError("Fragment overlaps past end of message");

            if (fragment_offset == 0 && fragment_length == m_msg_length)
            {
                // Whole message in a single fragment: size the buffer explicitly
                // before copying, since it may still be empty here.
                m_fragments.clear();
                m_message.resize(fragment_length);
                foreach (size_t i; 0 .. fragment_length)
                    m_message[i] = fragment[i];
            }
            else
            {
                /*
                * FIXME. This is a pretty lame way to do defragmentation, huge
                * overhead with a tree node per ubyte.
                *
                * Also should confirm that all overlaps have no changes,
                * otherwise we expose ourselves to the classic fingerprinting
                * and IDS evasion attacks on IP fragmentation.
                */
                foreach (size_t i; 0 .. fragment_length)
                    m_fragments[fragment_offset+i] = cast(ubyte)fragment[i];

                // Every byte position has been received: flatten into m_message.
                if (m_fragments.length == m_msg_length)
                {
                    m_message.resize(m_msg_length);
                    foreach (size_t i; 0 .. m_msg_length)
                        m_message[i] = m_fragments[i];
                    m_fragments.clear();
                }
            }
        }

        bool complete() const
        {
            return (m_msg_type != HANDSHAKE_NONE && m_message.length == m_msg_length);
        }

        ushort epoch() const { return m_epoch; }

        /// Returns a copy of the completed message; throws InternalError if incomplete.
        NextRecord message() const
        {
            if (!complete())
                throw new InternalError("DatagramHandshakeIO - message not complete");

            return NextRecord(cast(HandshakeType)(m_msg_type), m_message.dup);
        }

    private:
        ubyte m_msg_type = HANDSHAKE_NONE;
        size_t m_msg_length = 0;
        ushort m_epoch = 0;

        HashMapRef!(size_t, ubyte) m_fragments;
        Array!ubyte m_message;
    }

    /// A sent message kept for retransmission.
    struct MessageInfo
    {
        this(ushort e, HandshakeType mt, const ref Vector!ubyte msg)
        {
            epoch = e;
            msg_type = mt;
            msg_bits = msg.dupr;
        }

        ushort epoch = 0xFFFF;
        HandshakeType msg_type = HANDSHAKE_NONE;
        Array!ubyte msg_bits;
    }


    ConnectionSequenceNumbers m_seqs;
    HashMap!(ushort, HandshakeReassembly) m_messages;
    ushort[] m_ccs_epochs;
    // Outgoing message sequence numbers grouped by flight; the last entry is the current flight.
    Vector!( Array!ushort ) m_flights;
    HashMap!(ushort, MessageInfo ) m_flight_data;

    ulong m_last_write = 0;
    ulong m_next_timeout = 0;

    ushort m_in_message_seq = 0;
    ushort m_out_message_seq = 0;
    InternalDataWriter m_send_hs;
    ushort m_mtu;
}


private:

/// Reads a 3-byte big-endian unsigned integer.
size_t loadBigEndian24(in ubyte[3] q)
{
    return make_uint(0, q[0], q[1], q[2]);
}

/// Writes the low 24 bits of val big-endian into output.
void storeBigEndian24(ref ubyte[3] output, size_t val)
{
    output[0] = get_byte!uint(1, cast(uint) val);
    output[1] = get_byte!uint(2, cast(uint) val);
    output[2] = get_byte!uint(3, cast(uint) val);
}

/// Number of fragments used to send a msg_size-byte message over the given MTU.
size_t splitForMtu(size_t mtu, size_t msg_size)
{
    __gshared immutable size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

    const size_t parts = (msg_size + mtu) / mtu;

    if (parts + DTLS_HEADERS_SIZE > mtu)
        return parts + 1;

    return parts;
}