/**
* TLS Handshake State
*
* Copyright:
* (C) 2004-2006,2011,2012,2015 Jack Lloyd
* (C) 2014-2015 Etienne Cimon
*
* License:
* Botan is released under the Simplified BSD License (see LICENSE.md)
*/
module botan.tls.handshake_state;

import botan.constants;
static if (BOTAN_HAS_TLS):
package:

import botan.tls.handshake_hash;
import botan.tls.handshake_io;
import botan.tls.session_key;
import botan.tls.ciphersuite;
import botan.tls.exceptn;
import botan.tls.messages;
import botan.pubkey.pk_keys;
import botan.pubkey.pubkey;
import botan.kdf.kdf;
import botan.tls.record;

package:
/**
* TLS Handshake State
*
* Holds the handshake messages seen so far (one Unique holder per message
* kind), the negotiated version and ciphersuite, the running handshake hash,
* and two bitmasks that track which message types have been received and
* which are currently expected.
*/
class HandshakeState
{
public:
    /**
    * Initialize the TLS Handshake State
    *
    * Params:
    *  io = the handshake I/O layer; it is stored in a Unique holder and its
    *       initialRecordVersion() becomes the initial protocol version
    *  msg_callback = optional delegate invoked by noteMessage() for every
    *                 handshake message recorded in this state
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null)
    {
        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_msg_callback = msg_callback;
        m_handshake_io = io;
        m_version = m_handshake_io.initialRecordVersion();
        m_handshake_hash.reset();
    }

    /// Returns the handshake I/O layer held by this state.
    HandshakeIO handshakeIo() { return *m_handshake_io; }

    /**
    * Return true iff we have received a particular message already
    * Params:
    *  handshake_msg = the message type
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        const uint mask = bitmaskForHandshakeType(handshake_msg);

        return cast(bool)(m_hand_received_mask & mask);
    }

    /**
    * Confirm that we were expecting this message type.
    *
    * The message type is recorded as received before validation. On success
    * the expectation mask is cleared, so setExpectedNext() must be called
    * again before the next transition.
    *
    * Params:
    *  handshake_msg = the message type
    *
    * Throws: TLSUnexpectedMessage if handshake_msg was not among the
    *         expected message types.
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        // Explicit import instead of relying on a transitive import of std.conv.
        import std.conv : to;

        const uint mask = bitmaskForHandshakeType(handshake_msg);

        m_hand_received_mask |= mask;

        const bool ok = cast(bool)(m_hand_expecting_mask & mask); // overlap?

        if (!ok)
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                           to!string(handshake_msg) ~
                                           " expected " ~ to!string(m_hand_expecting_mask) ~
                                           " received " ~ to!string(m_hand_received_mask));

        /* We don't know what to expect next, so force a call to
           set_expected_next; if it doesn't happen, the next transition
           check will always fail which is what we want.
        */
        m_hand_expecting_mask = 0;
    }

    /**
    * Record that we are expecting a particular message type next.
    * Several types may be expected at once; each call adds to the mask.
    * Params:
    *  handshake_msg = the message type
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        m_hand_expecting_mask |= bitmaskForHandshakeType(handshake_msg);
    }

    /**
    * Fetch the next handshake record from the I/O layer.
    * The I/O layer is told whether a ChangeCipherSpec (HANDSHAKE_CCS) is
    * currently among the expected message types.
    */
    NextRecord getNextHandshakeMsg()
    {
        const bool expecting_ccs = cast(bool)(bitmaskForHandshakeType(HANDSHAKE_CCS) & m_hand_expecting_mask);

        return m_handshake_io.getNextRecord(expecting_ccs);
    }

    /**
    * Returns a copy of the ticket from the NewSessionTicket message if one
    * was received and is non-empty; otherwise the session ticket carried in
    * the ClientHello.
    */
    Vector!ubyte sessionTicket() const
    {
        if (newSessionTicket() && !newSessionTicket().ticket().empty())
            return newSessionTicket().ticket().dup;

        return clientHello().sessionTicket();
    }

    /**
    * Determine the padding scheme and signature encoding to use when
    * verifying a signature received from the counterparty.
    *
    * Params:
    *  key = the counterparty's public key; its algoName selects the scheme
    *  hash_algo = hash algorithm id sent by the counterparty ("" if none)
    *  sig_algo = signature algorithm id sent by the counterparty ("" if none)
    *  for_client_auth = currently unused
    *
    * Throws: DecodingError if the sent ids are missing, inconsistent with the
    *         key, or present under a version that does not negotiate them;
    *         InvalidArgument if the key algorithm is not RSA, DSA or ECDSA.
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        const string algo_name = key.algoName;
        const bool negotiable = this.Version().supportsNegotiableSignatureAlgorithms();

        /*
        FIXME: This should check what was sent against the client hello
        preferences, or the certificate request, to ensure it was allowed
        by those restrictions.

        Or not?
        */

        if (negotiable)
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != algo_name)
                throw new DecodingError("Counterparty sent inconsistent key and sig types");
        }
        else
        {
            if (hash_algo != "" || sig_algo != "")
                throw new DecodingError("Counterparty sent hash/sig IDs with old version");

            // Versions without negotiable signature algorithms use a fixed hash per algorithm.
            if (algo_name == "RSA")
                hash_algo = "Parallel(MD5,SHA-160)";
            else if (algo_name == "DSA" || algo_name == "ECDSA")
                hash_algo = "SHA-1";
        }

        return sigFormatFor(algo_name, hash_algo);
    }

    /**
    * Choose the padding scheme and signature encoding for a signature we
    * are about to create.
    *
    * Params:
    *  key = our private key; its algoName selects the scheme
    *  hash_algo_out = receives the chosen hash id when the version
    *                  negotiates signature algorithms (left untouched otherwise)
    *  sig_algo_out = receives the signature algorithm id under the same condition
    *  for_client_auth = true when signing a CertificateVerify, which makes the
    *                    certificate request the source of the peer's preferences
    *  policy = local policy supplying the allowed signature hashes
    *
    * Throws: InvalidArgument if the key algorithm is not RSA, DSA or ECDSA;
    *         InternalError from the hash selection (see chooseHash).
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                        ref string hash_algo_out,
                        ref string sig_algo_out,
                        bool for_client_auth,
                        in TLSPolicy policy) const
    {
        const string sig_algo = key.algoName;

        const string hash_algo = chooseHash(sig_algo,
                                            this.Version(),
                                            policy,
                                            for_client_auth,
                                            clientHello(),
                                            certReq());

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = hash_algo;
            sig_algo_out = sig_algo;
        }

        return sigFormatFor(sig_algo, hash_algo);
    }

    /**
    * Returns the SRP identifier from the ClientHello when a valid
    * ciphersuite with key exchange "SRP_SHA" was negotiated, "" otherwise.
    */
    const(string) srpIdentifier() const
    {
        if (ciphersuite().valid() && ciphersuite().kexAlgo() == "SRP_SHA")
            return clientHello().srpIdentifier();

        return "";
    }

    /**
    * Returns the PRF for the negotiated version. Versions with
    * ciphersuite-specific PRFs use TLS-12-PRF over the ciphersuite's PRF hash,
    * upgraded to SHA-256 when that hash is MD5 or SHA-1; older versions use
    * "TLS-PRF".
    */
    KDF protocolSpecificPrf() const
    {
        if (Version().supportsCiphersuiteSpecificPrf())
        {
            const string prf_algo = ciphersuite().prfAlgo();

            if (prf_algo == "MD5" || prf_algo == "SHA-1")
                return getKdf("TLS-12-PRF(SHA-256)");

            return getKdf("TLS-12-PRF(" ~ prf_algo ~ ")");
        }
        else
        {
            // TLS v1.0, v1.1 and DTLS v1.0
            return getKdf("TLS-PRF");
        }
    }

    /// Returns the current protocol version of this handshake.
    const(TLSProtocolVersion) Version() const { return m_version; }

    /**
    * Stores the original handshake hash (kept for ChannelID resumption).
    * The argument's contents are moved into this state.
    */
    void setOriginalHandshakeHash(SecureVector!ubyte orig_hs_hash)
    {
        m_orig_hs_hash = orig_hs_hash.move();
    }

    /// Overrides the protocol version of this handshake.
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * Process a HelloVerifyRequest: the stored ClientHello gets the new cookie,
    * the handshake hash is reset, and the ClientHello is re-sent through the
    * I/O layer with the sent bytes hashed. Both messages are reported to the
    * message callback.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);
        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));
        noteMessage(*m_client_hello);
    }

    // Message setters: each stores the message in its Unique holder and
    // reports it to the message callback through noteMessage().

    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }

    /// Also selects the ciphersuite identified in the ServerHello.
    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }

    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }

    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }

    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }

    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }

    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }

    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }

    void channelID(ChannelID channel_id)
    {
        m_channel_id = channel_id;
        noteMessage(*m_channel_id);
    }

    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }

    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }

    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    // Message getters: each returns the object held by the corresponding
    // Unique holder. sessionTicket() tests the NewSessionTicket result for
    // null, so callers should expect a getter may yield a null reference
    // when that message has not been recorded (NOTE(review): depends on
    // Unique's dereference semantics — confirm).

    const(ClientHello) clientHello() const
    { return *m_client_hello; }

    const(ServerHello) serverHello() const
    { return *m_server_hello; }

    const(Certificate) serverCerts() const
    { return *m_server_certs; }

    const(ServerKeyExchange) serverKex() const
    { return *m_server_kex; }

    const(CertificateReq) certReq() const
    { return *m_cert_req; }

    const(ServerHelloDone) serverHelloDone() const
    { return *m_server_hello_done; }

    const(Certificate) clientCerts() const
    { return *m_client_certs; }

    const(ClientKeyExchange) clientKex() const
    { return *m_client_kex; }

    const(CertificateVerify) clientVerify() const
    { return *m_client_verify; }

    const(ChannelID) channelID() const
    { return *m_channel_id; }

    const(NewSessionTicket) newSessionTicket() const
    { return *m_new_session_ticket; }

    const(Finished) serverFinished() const
    { return *m_server_finished; }

    const(Finished) clientFinished() const
    { return *m_client_finished; }

    /// The ciphersuite selected by serverHello() (TLSCiphersuite.init before that).
    ref const(TLSCiphersuite) ciphersuite() const { return m_ciphersuite; }

    /// The session keys computed by computeSessionKeys().
    ref const(TLSSessionKeys) sessionKeys() const { return m_session_keys; }

    /**
    * Compute the session keys from a copy of the client key exchange's
    * pre-master secret (final TLSSessionKeys argument is false here, true in
    * the resumption overload — presumably a "resuming" flag; confirm).
    */
    void computeSessionKeys()
    {
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /// Compute the session keys from the master secret of a resumed session.
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// The hash stored by setOriginalHandshakeHash().
    ref const(SecureVector!ubyte) originalHandshakeHash() const { return m_orig_hs_hash; }

    /// The running handshake hash.
    ref HandshakeHash hash() { return m_handshake_hash; }

    /// ditto
    ref const(HandshakeHash) hash() const { return m_handshake_hash; }

    /// Forwards msg to the message callback given at construction, if any.
    void noteMessage(in HandshakeMessage msg)
    {
        if (m_msg_callback)
            m_msg_callback(msg);
    }


private:

    void delegate(in HandshakeMessage) m_msg_callback;

    Unique!HandshakeIO m_handshake_io;

    uint m_hand_expecting_mask = 0;
    uint m_hand_received_mask = 0;
    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;
    // Used to save the original handshake hash in the session for ChannelID Resumption
    SecureVector!ubyte m_orig_hs_hash;

    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!ChannelID m_channel_id;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}


private:

/**
* Maps a handshake message type to its single-bit flag in the
* expecting/received masks. HANDSHAKE_NONE maps to 0.
*
* Throws: InternalError for any type without an assigned bit.
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    // Explicit import instead of relying on a transitive import of std.conv.
    import std.conv : to;

    switch(type)
    {
        case HELLO_VERIFY_REQUEST:
            return (1 << 0);

        case HELLO_REQUEST:
            return (1 << 1);

        /*
        * Same code point for both client hello styles
        */
        case CLIENT_HELLO:
            return (1 << 2);

        case SERVER_HELLO:
            return (1 << 3);

        case CERTIFICATE:
            return (1 << 4);

        case CERTIFICATE_URL:
            return (1 << 5);

        case CERTIFICATE_STATUS:
            return (1 << 6);

        case SERVER_KEX:
            return (1 << 7);

        case CERTIFICATE_REQUEST:
            return (1 << 8);

        case SERVER_HELLO_DONE:
            return (1 << 9);

        case CERTIFICATE_VERIFY:
            return (1 << 10);

        case CLIENT_KEX:
            return (1 << 11);

        case NEW_SESSION_TICKET:
            return (1 << 12);

        case HANDSHAKE_CCS:
            return (1 << 13);

        case FINISHED:
            return (1 << 14);

        // allow explicitly disabling new handshakes
        case HANDSHAKE_NONE:
            return 0;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }
}



/**
* Choose the hash to pair with sig_algo for a signature we create.
*
* Versions without negotiable signature algorithms use a fixed hash per
* algorithm. Otherwise the first hash in policy.allowedSignatureHashes()
* that the counterparty advertised together with sig_algo is chosen. The
* counterparty's list comes from cert_req for client auth, from client_hello
* otherwise. "SHA-1" is the fallback when nothing matches.
*
* Throws: InternalError for an unknown algorithm under a non-negotiating
*         version, or when the message holding the counterparty's
*         preferences is missing.
*/
string chooseHash(in string sig_algo,
                  TLSProtocolVersion negotiated_version,
                  in TLSPolicy policy,
                  bool for_client_auth,
                  in ClientHello client_hello,
                  in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        if (sig_algo == "RSA")
            return "Parallel(MD5,SHA-160)";

        if (sig_algo == "DSA")
            return "SHA-1";

        if (sig_algo == "ECDSA")
            return "SHA-1";

        throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
    }

    // Fail with a catchable error rather than dereferencing a missing message.
    if (for_client_auth && cert_req is null)
        throw new InternalError("No certificate request available to choose a client auth signature hash");

    if (!for_client_auth && client_hello is null)
        throw new InternalError("No client hello available to choose a signature hash");

    Vector!(Pair!(string, string)) supported_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    if (!supported_algos.empty())
    {
        const Vector!string hashes = policy.allowedSignatureHashes();

        /*
        * Choose our most preferred hash that the counterparty supports
        * in pairing with the signature algorithm we want to use.
        */
        foreach (hash; hashes[])
        {
            foreach (algo; supported_algos[])
            {
                if (algo.first == hash && algo.second == sig_algo)
                    return hash;
            }
        }
    }

    // TLS v1.2 default hash if the counterparty sent nothing
    return "SHA-1";
}

/**
* Map a signature algorithm and hash to the TLS padding scheme and encoding:
* RSA uses "EMSA3(hash)" with IEEE_1363; DSA and ECDSA use "EMSA1(hash)"
* with DER_SEQUENCE.
*
* Throws: InvalidArgument for any other algorithm.
*/
const(Pair!(string, SignatureFormat)) sigFormatFor(string algo_name, string hash_algo)
{
    if (algo_name == "RSA")
    {
        const string padding = "EMSA3(" ~ hash_algo ~ ")";
        return makePair(padding, IEEE_1363);
    }
    else if (algo_name == "DSA" || algo_name == "ECDSA")
    {
        const string padding = "EMSA1(" ~ hash_algo ~ ")";
        return makePair(padding, DER_SEQUENCE);
    }

    throw new InvalidArgument(algo_name ~ " is invalid/unknown for TLS signatures");
}