/**
* TLS Handshake State
*
* Copyright:
* (C) 2004-2006,2011,2012,2015 Jack Lloyd
* (C) 2014-2015 Etienne Cimon
*
* License:
* Botan is released under the Simplified BSD License (see LICENSE.md)
*/
module botan.tls.handshake_state;

import botan.constants;
static if (BOTAN_HAS_TLS):
package:

import botan.tls.handshake_hash;
import botan.tls.handshake_io;
import botan.tls.session_key;
import botan.tls.ciphersuite;
import botan.tls.exceptn;
import botan.tls.messages;
import botan.pubkey.pk_keys;
import botan.pubkey.pubkey;
import botan.kdf.kdf;
import botan.tls.record;

/**
* TLS Handshake State
*
* Holds the state of one handshake: the bitmasks of message types received
* and expected next, the negotiated version and ciphersuite, the session keys,
* the running handshake hash, and (via Unique) each handshake message object.
*/
class HandshakeState
{
public:
    /**
    * Initialize the TLS Handshake State
    *
    * Params:
    *  io = the handshake I/O layer; stored in a Unique owned by this state
    *  msg_callback = optional delegate invoked by noteMessage for every message
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null)
    {
        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_msg_callback = msg_callback;
        m_handshake_io = io;
        // The starting version comes from the I/O layer until setVersion is called
        m_version = m_handshake_io.initialRecordVersion();
        m_handshake_hash.reset();
    }

    /// Returns: the handshake I/O layer owned by this state
    HandshakeIO handshakeIo() { return *m_handshake_io; }

    /**
    * Return true iff we have received a particular message already
    * Params:
    *  handshake_msg = the message type
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        const uint mask = bitmaskForHandshakeType(handshake_msg);

        return cast(bool)(m_hand_received_mask & mask);
    }

    /**
    * Confirm that we were expecting this message type
    *
    * The type is recorded as received even when the check fails.
    * On success the expected set is cleared, so setExpectedNext must be
    * called again before the next transition.
    *
    * Params:
    *  handshake_msg = the message type
    *
    * Throws: TLSUnexpectedMessage if handshake_msg was not in the expected set
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        import std.conv : to;

        const uint mask = bitmaskForHandshakeType(handshake_msg);

        m_hand_received_mask |= mask;

        const bool ok = cast(bool)(m_hand_expecting_mask & mask); // overlap?

        if (!ok)
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                           to!string(handshake_msg) ~
                                           " expected " ~ to!string(m_hand_expecting_mask) ~
                                           " received " ~ to!string(m_hand_received_mask));

        /* We don't know what to expect next, so force a call to
           set_expected_next; if it doesn't happen, the next transition
           check will always fail which is what we want.
        */
        m_hand_expecting_mask = 0;
    }

    /**
    * Record that we are expecting a particular message type next
    * (added to the existing expected set; several types may be expected at once)
    * Params:
    *  handshake_msg = the message type
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        m_hand_expecting_mask |= bitmaskForHandshakeType(handshake_msg);
    }

    /**
    * Fetch the next handshake record from the I/O layer.
    * Returns: the record returned by HandshakeIO.getNextRecord, which is told
    *          whether a ChangeCipherSpec (HANDSHAKE_CCS) is currently expected
    */
    NextRecord getNextHandshakeMsg()
    {
        const bool expecting_ccs = cast(bool)(bitmaskForHandshakeType(HANDSHAKE_CCS) & m_hand_expecting_mask);

        return m_handshake_io.getNextRecord(expecting_ccs);
    }

    /**
    * Returns: a copy of the ticket from the NewSessionTicket message if one was
    *          received and is non-empty, otherwise the ClientHello's session ticket
    */
    Vector!ubyte sessionTicket() const
    {
        auto new_ticket = newSessionTicket();
        if (new_ticket && !new_ticket.ticket().empty())
            return new_ticket.ticket().dup;

        return clientHello().sessionTicket();
    }

    /**
    * Determine the padding scheme and signature format for verifying a
    * counterparty's signature.
    *
    * Params:
    *  key = the counterparty's public key
    *  hash_algo = hash named by the counterparty ("" for pre-1.2 versions)
    *  sig_algo = signature algorithm named by the counterparty ("" for pre-1.2 versions)
    *  for_client_auth = currently unused (see FIXME below)
    *
    * Returns: (padding string, signature format) pair
    *
    * Throws: DecodingError if the hash/sig IDs are missing, unexpected for the
    *         version, or inconsistent with the key type; InvalidArgument if the
    *         algorithm is not usable for TLS signatures
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        string algo_name = key.algoName;

        /*
        FIXME: This should check what was sent against the client hello
        preferences, or the certificate request, to ensure it was allowed
        by those restrictions.

        Or not?
        */

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != algo_name)
            {
                // The only mismatch allowed is an RSA key used to produce an RSA-PSS
                // signature; any other pairing means the signature scheme does not
                // match the certified key type.
                if (!(algo_name == "RSA" && sig_algo == "RSA-PSS"))
                    throw new DecodingError("Counterparty sent inconsistent key and sig types");

                algo_name = sig_algo;
            }
        }
        else
        {
            if (hash_algo != "" || sig_algo != "")
                throw new DecodingError("Counterparty sent hash/sig IDs with old version");

            // Versions without negotiable signature algorithms use fixed hashes
            if (algo_name == "RSA")
                hash_algo = "Parallel(MD5,SHA-160)";
            else if (algo_name == "DSA" || algo_name == "ECDSA")
                hash_algo = "SHA-1";
        }

        return sigFormatFor(algo_name, hash_algo);
    }

    /**
    * Choose the hash and padding scheme for producing our own signature.
    *
    * Params:
    *  key = our private key; its algoName is used as the signature algorithm
    *  hash_algo_out = set to the chosen hash when the version negotiates signature algorithms
    *  sig_algo_out = set to the signature algorithm when the version negotiates signature algorithms
    *  for_client_auth = true when signing a CertificateVerify (uses the CertificateReq preferences)
    *  policy = TLS policy providing the allowed signature hashes
    *
    * Returns: (padding string, signature format) pair
    *
    * Throws: InvalidArgument if the key's algorithm is not usable for TLS signatures;
    *         InternalError (from chooseHash) for an unknown pre-1.2 algorithm
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                        ref string hash_algo_out,
                        ref string sig_algo_out,
                        bool for_client_auth,
                        in TLSPolicy policy) const
    {
        const string sig_algo = key.algoName;

        const string hash_algo = chooseHash(sig_algo,
                                            this.Version(),
                                            policy,
                                            for_client_auth,
                                            clientHello(),
                                            certReq());

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = hash_algo;
            sig_algo_out = sig_algo;
        }

        return sigFormatFor(sig_algo, hash_algo);
    }

    /// Returns: the ClientHello's SRP identifier when the negotiated kex is SRP_SHA, otherwise ""
    const(string) srpIdentifier() const
    {
        if (ciphersuite().valid() && ciphersuite().kexAlgo() == "SRP_SHA")
            return clientHello().srpIdentifier();

        return "";
    }

    /**
    * Returns: the PRF for the current version. Versions with ciphersuite-specific
    *          PRFs use TLS-12-PRF over the suite's PRF hash (SHA-256 when that hash
    *          is MD5 or SHA-1); older versions use "TLS-PRF".
    */
    KDF protocolSpecificPrf() const
    {
        if (!Version().supportsCiphersuiteSpecificPrf())
        {
            // TLS v1.0, v1.1 and DTLS v1.0
            return getKdf("TLS-PRF");
        }

        const string prf_algo = ciphersuite().prfAlgo();

        if (prf_algo == "MD5" || prf_algo == "SHA-1")
            return getKdf("TLS-12-PRF(SHA-256)");

        return getKdf("TLS-12-PRF(" ~ prf_algo ~ ")");
    }

    /// Returns: the protocol version currently associated with this handshake
    const(TLSProtocolVersion) Version() const { return m_version; }

    /// Store the handshake hash of the original handshake (used for ChannelID resumption)
    void setOriginalHandshakeHash(SecureVector!ubyte orig_hs_hash)
    {
        m_orig_hs_hash = orig_hs_hash.move();
    }

    /// Set the protocol version for this handshake
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * Process a HelloVerifyRequest: update the stored ClientHello's cookie,
    * reset the handshake hash, re-send the ClientHello and hash the bytes the
    * I/O layer returns for it.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);
        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));
        noteMessage(*m_client_hello);
    }

    /// Take ownership of the ClientHello and report it to the message callback
    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }

    /// Take ownership of the ServerHello, look up its ciphersuite by id, and report it
    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }

    /// Take ownership of the server's Certificate message and report it
    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }

    /// Take ownership of the ServerKeyExchange message and report it
    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }

    /// Take ownership of the CertificateRequest message and report it
    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    /// Take ownership of the ServerHelloDone message and report it
    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }

    /// Take ownership of the client's Certificate message and report it
    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }

    /// Take ownership of the ClientKeyExchange message and report it
    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }

    /// Take ownership of the CertificateVerify message and report it
    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }

    /// Take ownership of the ChannelID message and report it
    void channelID(ChannelID channel_id)
    {
        m_channel_id = channel_id;
        noteMessage(*m_channel_id);
    }

    /// Take ownership of the NewSessionTicket message and report it
    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }

    /// Take ownership of the server's Finished message and report it
    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }

    /// Take ownership of the client's Finished message and report it
    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    // Accessors for the stored handshake messages. Each dereferences the
    // corresponding Unique; NOTE(review): presumably yields null when the
    // message has not been set (sessionTicket relies on this) — confirm.

    const(ClientHello) clientHello() const
    { return *m_client_hello; }

    const(ServerHello) serverHello() const
    { return *m_server_hello; }

    const(Certificate) serverCerts() const
    { return *m_server_certs; }

    const(ServerKeyExchange) serverKex() const
    { return *m_server_kex; }

    const(CertificateReq) certReq() const
    { return *m_cert_req; }

    const(ServerHelloDone) serverHelloDone() const
    { return *m_server_hello_done; }

    const(Certificate) clientCerts() const
    { return *m_client_certs; }

    const(ClientKeyExchange) clientKex() const
    { return *m_client_kex; }

    const(CertificateVerify) clientVerify() const
    { return *m_client_verify; }

    const(ChannelID) channelID() const
    { return *m_channel_id; }

    const(NewSessionTicket) newSessionTicket() const
    { return *m_new_session_ticket; }

    const(Finished) serverFinished() const
    { return *m_server_finished; }

    const(Finished) clientFinished() const
    { return *m_client_finished; }

    /// Returns: the ciphersuite selected by the ServerHello (TLSCiphersuite.init before that)
    ref const(TLSCiphersuite) ciphersuite() const { return m_ciphersuite; }

    /// Returns: the session keys computed by computeSessionKeys
    ref const(TLSSessionKeys) sessionKeys() const { return m_session_keys; }

    /// Derive session keys from a copy of the ClientKeyExchange pre-master secret
    void computeSessionKeys()
    {
        // NOTE(review): the trailing bool presumably flags resumption — confirm in TLSSessionKeys
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /// Derive session keys from the master secret of a resumed session
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// Returns: the handshake hash stored by setOriginalHandshakeHash
    ref const(SecureVector!ubyte) originalHandshakeHash() const { return m_orig_hs_hash; }

    /// Returns: the running handshake hash
    ref HandshakeHash hash() { return m_handshake_hash; }

    /// ditto
    ref const(HandshakeHash) hash() const { return m_handshake_hash; }

    /// Pass msg to the message callback, if one was supplied to the constructor
    void noteMessage(in HandshakeMessage msg)
    {
        if (m_msg_callback)
            m_msg_callback(msg);
    }


private:

    void delegate(in HandshakeMessage) m_msg_callback;

    Unique!HandshakeIO m_handshake_io;

    // Bitmasks built by bitmaskForHandshakeType
    uint m_hand_expecting_mask = 0;
    uint m_hand_received_mask = 0;
    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;
    // Used to save the original handshake hash in the session for ChannelID Resumption
    SecureVector!ubyte m_orig_hs_hash;

    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!ChannelID m_channel_id;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}


private:

/*
* Map a handshake message type to a distinct single bit.
* HANDSHAKE_NONE maps to 0 so it can never match an expected/received mask.
* Throws InternalError for any other, unlisted type.
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    import std.conv : to;

    switch(type)
    {
        case HELLO_VERIFY_REQUEST:
            return (1 << 0);

        case HELLO_REQUEST:
            return (1 << 1);

        /*
        * Same code point for both client hello styles
        */
        case CLIENT_HELLO:
            return (1 << 2);

        case SERVER_HELLO:
            return (1 << 3);

        case CERTIFICATE:
            return (1 << 4);

        case CERTIFICATE_URL:
            return (1 << 5);

        case CERTIFICATE_STATUS:
            return (1 << 6);

        case SERVER_KEX:
            return (1 << 7);

        case CERTIFICATE_REQUEST:
            return (1 << 8);

        case SERVER_HELLO_DONE:
            return (1 << 9);

        case CERTIFICATE_VERIFY:
            return (1 << 10);

        case CLIENT_KEX:
            return (1 << 11);

        case NEW_SESSION_TICKET:
            return (1 << 12);

        case HANDSHAKE_CCS:
            return (1 << 13);

        case FINISHED:
            return (1 << 14);

        // allow explicitly disabling new handshakes
        case HANDSHAKE_NONE:
            return 0;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }
}

/*
* Map a TLS signature algorithm and hash to the padding scheme and signature
* encoding: RSA -> EMSA3/IEEE_1363, RSA-PSS -> PSSR/IEEE_1363,
* DSA and ECDSA -> EMSA1/DER_SEQUENCE. Throws InvalidArgument otherwise.
* Shared by HandshakeState.understandSigFormat and HandshakeState.chooseSigFormat.
*/
Pair!(string, SignatureFormat) sigFormatFor(string sig_algo, string hash_algo)
{
    if (sig_algo == "RSA")
    {
        const string padding = "EMSA3(" ~ hash_algo ~ ")";
        return makePair(padding, IEEE_1363);
    }

    if (sig_algo == "RSA-PSS")
    {
        const string padding = "PSSR(" ~ hash_algo ~ ")";
        return makePair(padding, IEEE_1363);
    }

    if (sig_algo == "DSA" || sig_algo == "ECDSA")
    {
        const string padding = "EMSA1(" ~ hash_algo ~ ")";
        return makePair(padding, DER_SEQUENCE);
    }

    throw new InvalidArgument(sig_algo ~ " is invalid/unknown for TLS signatures");
}

/*
* Choose the hash to use with sig_algo.
*
* Without negotiable signature algorithms the hash is fixed
* (RSA -> Parallel(MD5,SHA-160), DSA/ECDSA -> SHA-1; anything else throws
* InternalError). Otherwise return the first hash in policy.allowedSignatureHashes()
* that the counterparty (CertificateReq when for_client_auth, else ClientHello)
* advertised paired with sig_algo, falling back to "SHA-1".
*/
string chooseHash(in string sig_algo,
                  TLSProtocolVersion negotiated_version,
                  in TLSPolicy policy,
                  bool for_client_auth,
                  in ClientHello client_hello,
                  in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        if (sig_algo == "RSA")
            return "Parallel(MD5,SHA-160)";

        if (sig_algo == "DSA")
            return "SHA-1";

        if (sig_algo == "ECDSA")
            return "SHA-1";

        throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
    }

    Vector!(Pair!(string, string)) supported_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    if (!supported_algos.empty())
    {
        const Vector!string hashes = policy.allowedSignatureHashes();

        /*
        * Choose our most preferred hash that the counterparty supports
        * in pairing with the signature algorithm we want to use.
        */
        foreach (hash; hashes[])
        {
            foreach (algo; supported_algos[])
            {
                if (algo.first == hash && algo.second == sig_algo)
                    return hash;
            }
        }
    }

    // TLS v1.2 default hash if the counterparty sent nothing
    return "SHA-1";
}