/**
* TLS Handshake State
*
* Copyright:
* (C) 2004-2006,2011,2012,2015 Jack Lloyd
* (C) 2014-2015 Etienne Cimon
*
* License:
* Botan is released under the Simplified BSD License (see LICENSE.md)
*/
module botan.tls.handshake_state;

import botan.constants;
static if (BOTAN_HAS_TLS):
package:

import botan.tls.handshake_hash;
import botan.tls.handshake_io;
import botan.tls.session_key;
import botan.tls.ciphersuite;
import botan.tls.exceptn;
import botan.tls.messages;
import botan.pubkey.pk_keys;
import botan.pubkey.pubkey;
import botan.kdf.kdf;
import botan.tls.messages;
import botan.tls.record;

package:
/**
* Progress tracker for one SSL/TLS handshake.
*
* Keeps two bitmasks (messages already accepted / messages allowed next),
* the negotiated version and ciphersuite, the derived session keys, the
* transcript hash, and the handshake message objects seen so far.
*/
class HandshakeState
{
public:
    /**
    * Build an empty handshake state on top of an I/O layer.
    *
    * Params:
    *  io = handshake transport; it is stored in a Unique and dropped by the destructor
    *  msg_callback = optional hook that noteMessage invokes for every recorded message
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null)
    {
        m_msg_callback = msg_callback;
        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_handshake_io = io;
        // Until a version is negotiated, use the one the I/O layer starts with.
        m_version = m_handshake_io.initialRecordVersion();
    }

    ~this() { m_handshake_io.drop(); }

    /// The handshake I/O layer given to the constructor.
    HandshakeIO handshakeIo() { return *m_handshake_io; }

    /**
    * Report whether a message of the given type has already been accepted.
    *
    * Params:
    *  handshake_msg = message type to look up
    * Returns: true iff a type mapping to the same bit was previously passed
    *          to confirmTransitionTo (CLIENT_HELLO and CLIENT_HELLO_SSLV2
    *          share a bit; HANDSHAKE_NONE never matches)
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        return (m_hand_received_mask & bitmaskForHandshakeType(handshake_msg)) != 0;
    }

    /**
    * Record the arrival of a message and check that it was expected.
    *
    * The message is marked as received before the check. After a successful
    * check the expected set is emptied, so setExpectedNext must be called
    * again before the next transition; if it is not, the next check fails,
    * which is the intended behaviour.
    *
    * Params:
    *  handshake_msg = message type that just arrived
    * Throws: TLSUnexpectedMessage if handshake_msg was not in the expected set
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);

        m_hand_received_mask |= bit;

        if ((m_hand_expecting_mask & bit) == 0)
        {
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                           to!string(handshake_msg) ~
                                           " expected " ~ to!string(m_hand_expecting_mask) ~
                                           " received " ~ to!string(m_hand_received_mask));
        }

        m_hand_expecting_mask = 0;
    }

    /**
    * Add a message type to the set allowed next. Calls accumulate until the
    * next confirmTransitionTo clears the set.
    *
    * Params:
    *  handshake_msg = message type to allow
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        m_hand_expecting_mask |= bitmaskForHandshakeType(handshake_msg);
    }

    /**
    * Pull the next handshake record from the I/O layer. The I/O layer is told
    * whether a ChangeCipherSpec is currently in the expected set.
    */
    NextRecord getNextHandshakeMsg()
    {
        const bool ccsAllowed = (m_hand_expecting_mask & bitmaskForHandshakeType(HANDSHAKE_CCS)) != 0;
        return m_handshake_io.getNextRecord(ccsAllowed);
    }

    /**
    * Session ticket associated with this handshake.
    *
    * Returns: a copy of the NewSessionTicket payload when that message is
    *          present and non-empty; otherwise the ClientHello's ticket
    */
    Vector!ubyte sessionTicket() const
    {
        auto issued = newSessionTicket();
        if (issued && !issued.ticket().empty())
            return issued.ticket().dup;

        return clientHello().sessionTicket();
    }

    /**
    * Decide how to verify a signature made by the counterparty.
    *
    * Params:
    *  key = counterparty's public key
    *  hash_algo = hash name the counterparty sent ("" when none)
    *  sig_algo = signature algorithm name the counterparty sent ("" when none)
    *  for_client_auth = true when checking a client's signature
    * Returns: pair of (EMSA padding name, signature encoding)
    * Throws: DecodingError when the sent IDs are missing or inconsistent with
    *         the key, or are present under a version that does not negotiate
    *         them; InvalidArgument when the key is not RSA, DSA or ECDSA
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        const string algo_name = key.algoName;
        const bool negotiable = this.Version().supportsNegotiableSignatureAlgorithms();

        // FIXME: the received hash/sig pair is not yet checked against the
        // preferences sent in the ClientHello or CertificateRequest.

        if (negotiable)
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != algo_name)
                throw new DecodingError("Counterparty sent inconsistent key and sig types");
        }
        else if (hash_algo != "" || sig_algo != "")
        {
            throw new DecodingError("Counterparty sent hash/sig IDs with old version");
        }

        // Under SSLv3 client authentication the "Raw" hash setting is used.
        const bool ssl3ClientAuth = for_client_auth && this.Version() == TLSProtocolVersion.SSL_V3;

        if (algo_name == "RSA")
        {
            if (ssl3ClientAuth)
                hash_algo = "Raw";
            else if (!negotiable)
                hash_algo = "Parallel(MD5,SHA-160)";

            const string emsa3 = "EMSA3(" ~ hash_algo ~ ")";
            return makePair(emsa3, IEEE_1363);
        }

        if (algo_name == "DSA" || algo_name == "ECDSA")
        {
            if (algo_name == "DSA" && ssl3ClientAuth)
                hash_algo = "Raw";
            else if (!negotiable)
                hash_algo = "SHA-1";

            const string emsa1 = "EMSA1(" ~ hash_algo ~ ")";
            return makePair(emsa1, DER_SEQUENCE);
        }

        throw new InvalidArgument(algo_name ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Decide how to produce a signature with our own key.
    *
    * Params:
    *  key = our private key
    *  hash_algo_out = receives the chosen hash name (only when the version negotiates signature algorithms)
    *  sig_algo_out = receives the key's algorithm name (same condition)
    *  for_client_auth = true when signing as the client
    *  policy = source of our allowed signature hashes
    * Returns: pair of (EMSA padding name, signature encoding)
    * Throws: InvalidArgument when the key is not RSA, DSA or ECDSA
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                        ref string hash_algo_out,
                        ref string sig_algo_out,
                        bool for_client_auth,
                        in TLSPolicy policy) const
    {
        const string sig_algo = key.algoName;
        const string hash_algo = chooseHash(sig_algo, this.Version(), policy, for_client_auth,
                                            clientHello(), certReq());

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = hash_algo;
            sig_algo_out = sig_algo;
        }

        switch (sig_algo)
        {
            case "RSA":
            {
                const string emsa3 = "EMSA3(" ~ hash_algo ~ ")";
                return makePair(emsa3, IEEE_1363);
            }

            case "DSA":
            case "ECDSA":
            {
                const string emsa1 = "EMSA1(" ~ hash_algo ~ ")";
                return makePair(emsa1, DER_SEQUENCE);
            }

            default:
                break;
        }

        throw new InvalidArgument(sig_algo ~ " is invalid/unknown for TLS signatures");
    }

    /// SRP identifier from the ClientHello if the suite's key exchange is "SRP_SHA", else "".
    const(string) srpIdentifier() const
    {
        const bool srpSuite = ciphersuite().valid() && ciphersuite().kexAlgo() == "SRP_SHA";
        if (!srpSuite)
            return "";

        return clientHello().srpIdentifier();
    }

    /**
    * The key-derivation function for the current protocol version:
    * SSL3-PRF for SSLv3, TLS-PRF for versions without ciphersuite-specific
    * PRFs, otherwise TLS-12-PRF keyed on the suite's PRF hash.
    */
    KDF protocolSpecificPrf() const
    {
        if (Version() == TLSProtocolVersion.SSL_V3)
            return getKdf("SSL3-PRF");

        // TLS v1.0, v1.1 and DTLS v1.0
        if (!Version().supportsCiphersuiteSpecificPrf())
            return getKdf("TLS-PRF");

        const string prf_algo = ciphersuite().prfAlgo();

        // Suites naming MD5 or SHA-1 as their PRF hash are mapped to SHA-256.
        if (prf_algo == "MD5" || prf_algo == "SHA-1")
            return getKdf("TLS-12-PRF(SHA-256)");

        return getKdf("TLS-12-PRF(" ~ prf_algo ~ ")");
    }

    /// Protocol version currently in effect for this handshake.
    const(TLSProtocolVersion) Version() const { return m_version; }

    /// Replace the protocol version in effect.
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * React to a HelloVerifyRequest: copy its cookie into our ClientHello,
    * reset the transcript hash, send the updated ClientHello and hash the
    * bytes returned by the send. Both messages are passed to noteMessage.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);
        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));
        noteMessage(*m_client_hello);
    }

    // ---- Message setters: each stores the message, then reports it via noteMessage ----

    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }

    /// Also selects the ciphersuite named by the ServerHello.
    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }

    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }

    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }

    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }

    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }

    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }

    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }

    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }

    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }

    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    // ---- Message getters: return whatever is stored in the matching field ----

    const(ClientHello) clientHello() const { return *m_client_hello; }
    const(ServerHello) serverHello() const { return *m_server_hello; }
    const(Certificate) serverCerts() const { return *m_server_certs; }
    const(ServerKeyExchange) serverKex() const { return *m_server_kex; }
    const(CertificateReq) certReq() const { return *m_cert_req; }
    const(ServerHelloDone) serverHelloDone() const { return *m_server_hello_done; }
    const(Certificate) clientCerts() const { return *m_client_certs; }
    const(ClientKeyExchange) clientKex() const { return *m_client_kex; }
    const(CertificateVerify) clientVerify() const { return *m_client_verify; }
    const(NewSessionTicket) newSessionTicket() const { return *m_new_session_ticket; }
    const(Finished) serverFinished() const { return *m_server_finished; }
    const(Finished) clientFinished() const { return *m_client_finished; }

    /// Ciphersuite chosen by the ServerHello (TLSCiphersuite.init before that).
    ref const(TLSCiphersuite) ciphersuite() const { return m_ciphersuite; }

    /// Keys produced by the last computeSessionKeys call.
    ref const(TLSSessionKeys) sessionKeys() const { return m_session_keys; }

    /// Derive session keys from a copy of the ClientKeyExchange pre-master secret.
    void computeSessionKeys()
    {
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /**
    * Derive session keys from a master secret being resumed.
    * NOTE(review): the trailing `true` presumably tells TLSSessionKeys the
    * input is already a master secret — confirm in session_key.
    */
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// Running hash over the handshake transcript.
    ref HandshakeHash hash() { return m_handshake_hash; }

    /// ditto
    ref const(HandshakeHash) hash() const { return m_handshake_hash; }

    /// Forward a handshake message to the user callback, if one was supplied.
    void noteMessage(in HandshakeMessage msg)
    {
        if (m_msg_callback)
            m_msg_callback(msg);
    }


private:

    void delegate(in HandshakeMessage) m_msg_callback;

    Unique!HandshakeIO m_handshake_io;

    uint m_hand_expecting_mask = 0; // bits of the message types allowed next
    uint m_hand_received_mask = 0;  // bits of the message types already accepted
    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;

    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}


private:

/*
* Map a handshake message type to its single bit in the received/expected
* masks. HANDSHAKE_NONE maps to 0; any type without an assigned bit raises
* InternalError.
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    uint position;

    switch (type)
    {
        // allow explicitly disabling new handshakes
        case HANDSHAKE_NONE:
            return 0;

        case HELLO_VERIFY_REQUEST: position = 0;  break;
        case HELLO_REQUEST:        position = 1;  break;

        // Same code point for both client hello styles
        case CLIENT_HELLO:
        case CLIENT_HELLO_SSLV2:   position = 2;  break;

        case SERVER_HELLO:         position = 3;  break;
        case CERTIFICATE:          position = 4;  break;
        case CERTIFICATE_URL:      position = 5;  break;
        case CERTIFICATE_STATUS:   position = 6;  break;
        case SERVER_KEX:           position = 7;  break;
        case CERTIFICATE_REQUEST:  position = 8;  break;
        case SERVER_HELLO_DONE:    position = 9;  break;
        case CERTIFICATE_VERIFY:   position = 10; break;
        case CLIENT_KEX:           position = 11; break;
        case NEW_SESSION_TICKET:   position = 12; break;
        case HANDSHAKE_CCS:        position = 13; break;
        case FINISHED:             position = 14; break;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }

    return 1u << position;
}



/*
* Pick the hash to sign with.
*
* Versions without negotiable signature algorithms use fixed choices:
* "Raw" for SSLv3 client auth, "Parallel(MD5,SHA-160)" for RSA, and "SHA-1"
* for DSA/ECDSA. Otherwise the first hash in our policy order that the peer
* offered with sig_algo is chosen, falling back to "SHA-1". The peer's offer
* comes from the CertificateRequest when for_client_auth is set, otherwise
* from the ClientHello.
*/
string chooseHash(in string sig_algo,
                  TLSProtocolVersion negotiated_version,
                  in TLSPolicy policy,
                  bool for_client_auth,
                  in ClientHello client_hello,
                  in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        if (for_client_auth && negotiated_version == TLSProtocolVersion.SSL_V3)
            return "Raw";

        switch (sig_algo)
        {
            case "RSA":
                return "Parallel(MD5,SHA-160)";

            case "DSA":
            case "ECDSA":
                return "SHA-1";

            default:
                throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
        }
    }

    Vector!(Pair!(string, string)) supported_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    // TLS v1.2 default hash if the counterparty sent nothing
    if (supported_algos.empty())
        return "SHA-1";

    const Vector!string hashes = policy.allowedSignatureHashes();

    // Walk our preferences in order; the first one the counterparty offered
    // alongside our signature algorithm wins.
    foreach (preferred; hashes[])
    {
        foreach (offered; supported_algos[])
        {
            if (offered.first == preferred && offered.second == sig_algo)
                return preferred;
        }
    }

    // No common (hash, sig) pair: use the same default.
    return "SHA-1";
}