1 /**
2 * TLS Handshake State
3 * 
4 * Copyright:
5 * (C) 2004-2006,2011,2012,2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.handshake_state;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 package:
16 
17 import botan.tls.handshake_hash;
18 import botan.tls.handshake_io;
19 import botan.tls.session_key;
20 import botan.tls.ciphersuite;
21 import botan.tls.exceptn;
22 import botan.tls.messages;
23 import botan.pubkey.pk_keys;
24 import botan.pubkey.pubkey;
25 import botan.kdf.kdf;
26 import botan.tls.messages;
27 import botan.tls.record;
28 
29 package:
/**
* SSL/TLS Handshake State
*
* Holds the messages exchanged so far, the negotiated version and
* ciphersuite, the running handshake hash, the derived session keys, and
* two bitmasks recording which handshake message types have been received
* and which are expected next.
*/
class HandshakeState
{
public:
    /**
    * Initialize the SSL/TLS Handshake State
    * Params:
    *  io = the handshake I/O object; the initial protocol version is read
    *       from its initialRecordVersion()
    *  msg_callback = optional delegate called by noteMessage for each
    *       handshake message (may be null)
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null) 
    {
        m_msg_callback = msg_callback;
        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_handshake_io = io;
        m_version = m_handshake_io.initialRecordVersion();
    }

    /// Calls drop() on the handshake I/O holder.
    ~this()
    {
        m_handshake_io.drop();
    }

    /// Returns: the handshake I/O object held by this state
    HandshakeIO handshakeIo()
    {
        return *m_handshake_io;
    }

    /**
    * Return true iff we have received a particular message already
    * Params:
    *  handshake_msg = the message type
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);

        return (m_hand_received_mask & bit) != 0;
    }

    /**
    * Confirm that we were expecting this message type
    *
    * The message is recorded as received before the check is made.
    * Params:
    *  handshake_msg = the message type
    * Throws: TLSUnexpectedMessage if the type is not in the expected set
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);

        m_hand_received_mask |= bit;

        // No overlap with the expected set means the peer is out of sequence.
        if ((m_hand_expecting_mask & bit) == 0)
        {
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                         to!string(handshake_msg) ~
                                         " expected " ~ to!string(m_hand_expecting_mask) ~
                                         " received " ~ to!string(m_hand_received_mask));
        }

        /* We don't know what to expect next, so clear the expected set and
            force a call to setExpectedNext; if that doesn't happen, the
            next transition check will always fail, which is what we want.
        */
        m_hand_expecting_mask = 0;
    }

    /**
    * Record that we are expecting a particular message type next
    *
    * Calls accumulate: each one adds to the current expected set.
    * Params:
    *  handshake_msg = the message type
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);

        m_hand_expecting_mask = m_hand_expecting_mask | bit;
    }

    /**
    * Fetch the next record from the handshake I/O object, telling it
    * whether a change cipher spec is currently in the expected set.
    */
    NextRecord getNextHandshakeMsg()
    {
        const uint ccs_bit = bitmaskForHandshakeType(HANDSHAKE_CCS);
        const bool expecting_ccs = (m_hand_expecting_mask & ccs_bit) != 0;

        return m_handshake_io.getNextRecord(expecting_ccs);
    }

    /**
    * Returns: a copy of the ticket from the stored NewSessionTicket message
    *  when there is one and its ticket is non-empty, otherwise the session
    *  ticket from the client hello
    */
    Vector!ubyte sessionTicket() const
    {
        auto issued = newSessionTicket();

        if (issued && !issued.ticket().empty())
            return issued.ticket().dup;

        return clientHello().sessionTicket();
    }

    /**
    * Work out the padding scheme and signature format to use when verifying
    * a signature made by the counterparty.
    * Params:
    *  key = the counterparty's public key
    *  hash_algo = hash name sent by the counterparty ("" if none)
    *  sig_algo = signature algorithm name sent by the counterparty ("" if none)
    *  for_client_auth = true when the signature is for client authentication
    * Returns: pair of (padding name, signature format)
    * Throws: DecodingError if the hash/sig identifiers are inconsistent with
    *  the negotiated version or the key type; InvalidArgument if the key
    *  algorithm is not RSA, DSA or ECDSA
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        const string key_algo = key.algoName;
        const bool negotiable = this.Version().supportsNegotiableSignatureAlgorithms();
        const bool sslv3_client_auth = for_client_auth && this.Version() == TLSProtocolVersion.SSL_V3;

        /*
        FIXME: This should check what was sent against the client hello
        preferences, or the certificate request, to ensure it was allowed
        by those restrictions.

        Or not?
        */

        if (negotiable)
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != key_algo)
                throw new DecodingError("Counterparty sent inconsistent key and sig types");
        }
        else if (hash_algo != "" || sig_algo != "")
        {
            throw new DecodingError("Counterparty sent hash/sig IDs with old version");
        }

        if (key_algo == "RSA")
        {
            // Versions without negotiation use a fixed hash for RSA.
            if (sslv3_client_auth)
                hash_algo = "Raw";
            else if (!negotiable)
                hash_algo = "Parallel(MD5,SHA-160)";

            const string emsa = "EMSA3(" ~ hash_algo ~ ")";
            return makePair(emsa, IEEE_1363);
        }

        if (key_algo == "DSA" || key_algo == "ECDSA")
        {
            // Only DSA (not ECDSA) gets the "Raw" treatment under SSLv3 client auth.
            if (key_algo == "DSA" && sslv3_client_auth)
                hash_algo = "Raw";
            else if (!negotiable)
                hash_algo = "SHA-1";

            const string emsa = "EMSA1(" ~ hash_algo ~ ")";
            return makePair(emsa, DER_SEQUENCE);
        }

        throw new InvalidArgument(key_algo ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Pick the padding scheme and signature format for a signature we are
    * about to create.
    * Params:
    *  key = our private key
    *  hash_algo_out = set to the chosen hash only when the negotiated
    *       version supports negotiable signature algorithms
    *  sig_algo_out = set to the key's algorithm name under the same condition
    *  for_client_auth = true when the signature is for client authentication
    *  policy = policy passed on to chooseHash
    * Returns: pair of (padding name, signature format)
    * Throws: InvalidArgument if the key algorithm is not RSA, DSA or ECDSA
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                          ref string hash_algo_out,
                          ref string sig_algo_out,
                          bool for_client_auth,
                          in TLSPolicy policy) const
    {
        const string key_algo = key.algoName;

        const string digest = chooseHash(key_algo,
                                          this.Version(),
                                          policy,
                                          for_client_auth,
                                          clientHello(),
                                          certReq());

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = digest;
            sig_algo_out = key_algo;
        }

        switch (key_algo)
        {
            case "RSA":
            {
                const string emsa = "EMSA3(" ~ digest ~ ")";
                return makePair(emsa, IEEE_1363);
            }

            case "DSA", "ECDSA":
            {
                const string emsa = "EMSA1(" ~ digest ~ ")";
                return makePair(emsa, DER_SEQUENCE);
            }

            default:
                break;
        }

        throw new InvalidArgument(key_algo ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Returns: the SRP identifier from the client hello when the current
    *  ciphersuite is valid and its key exchange is "SRP_SHA", else ""
    */
    const(string) srpIdentifier() const
    {
        if (!ciphersuite().valid() || ciphersuite().kexAlgo() != "SRP_SHA")
            return "";

        return clientHello().srpIdentifier();
    }

    /**
    * Returns: the KDF obtained from getKdf for the negotiated version:
    *  "SSL3-PRF" for SSLv3, a "TLS-12-PRF(...)" variant when the version
    *  supports ciphersuite-specific PRFs, and "TLS-PRF" otherwise
    */
    KDF protocolSpecificPrf() const
    {
        if (Version() == TLSProtocolVersion.SSL_V3)
            return getKdf("SSL3-PRF");

        // TLS v1.0, v1.1 and DTLS v1.0
        if (!Version().supportsCiphersuiteSpecificPrf())
            return getKdf("TLS-PRF");

        const string prf_algo = ciphersuite().prfAlgo();

        // MD5 and SHA-1 suites are mapped onto the SHA-256 based PRF.
        const bool use_sha256 = (prf_algo == "MD5" || prf_algo == "SHA-1");

        return getKdf(use_sha256 ? "TLS-12-PRF(SHA-256)" : ("TLS-12-PRF(" ~ prf_algo ~ ")"));
    }

    /// Returns: the protocol version currently recorded for this handshake
    const(TLSProtocolVersion) Version() const
    {
        return m_version;
    }

    /// Replace the recorded protocol version.
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * Handle a hello verify request: note it, update the stored client
    * hello's cookie from it, reset the handshake hash, send the client
    * hello again (hashing what was sent) and note the client hello.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);

        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));

        noteMessage(*m_client_hello);
    }


    /*
    * Setters: each stores the message and passes it to noteMessage.
    */

    /// Store the client hello.
    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }
    
    /// Store the server hello and look up the ciphersuite it names.
    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }
    
    /// Store the server certificate message.
    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }
    
    /// Store the server key exchange message.
    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }
    
    /// Store the certificate request message.
    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    /// Store the server hello done message.
    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }
    
    /// Store the client certificate message.
    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }
    
    /// Store the client key exchange message.
    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }
    
    /// Store the certificate verify message.
    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }
    
    /// Store the new session ticket message.
    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }
    
    /// Store the server finished message.
    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }
    
    /// Store the client finished message.
    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    /*
    * Getters: each returns whatever the corresponding setter last stored.
    */

    /// Returns: the stored client hello
    const(ClientHello) clientHello() const
    {
        return *m_client_hello;
    }

    /// Returns: the stored server hello
    const(ServerHello) serverHello() const
    {
        return *m_server_hello;
    }

    /// Returns: the stored server certificate message
    const(Certificate) serverCerts() const
    {
        return *m_server_certs;
    }

    /// Returns: the stored server key exchange message
    const(ServerKeyExchange) serverKex() const
    {
        return *m_server_kex;
    }

    /// Returns: the stored certificate request message
    const(CertificateReq) certReq() const
    {
        return *m_cert_req;
    }

    /// Returns: the stored server hello done message
    const(ServerHelloDone) serverHelloDone() const
    {
        return *m_server_hello_done;
    }

    /// Returns: the stored client certificate message
    const(Certificate) clientCerts() const
    {
        return *m_client_certs;
    }

    /// Returns: the stored client key exchange message
    const(ClientKeyExchange) clientKex() const
    {
        return *m_client_kex;
    }

    /// Returns: the stored certificate verify message
    const(CertificateVerify) clientVerify() const
    {
        return *m_client_verify;
    }

    /// Returns: the stored new session ticket message
    const(NewSessionTicket) newSessionTicket() const
    {
        return *m_new_session_ticket;
    }

    /// Returns: the stored server finished message
    const(Finished) serverFinished() const
    {
        return *m_server_finished;
    }

    /// Returns: the stored client finished message
    const(Finished) clientFinished() const
    {
        return *m_client_finished;
    }

    /// Returns: the ciphersuite set by serverHello (TLSCiphersuite.init until then)
    ref const(TLSCiphersuite) ciphersuite() const
    {
        return m_ciphersuite;
    }

    /// Returns: the session keys set by computeSessionKeys (TLSSessionKeys.init until then)
    ref const(TLSSessionKeys) sessionKeys() const
    {
        return m_session_keys;
    }

    /**
    * Build the session keys from a copy of the pre-master secret held by
    * the stored client key exchange message.
    */
    void computeSessionKeys()
    {
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /**
    * Build the session keys from the master secret of a resumed session.
    * Params:
    *  resume_master_secret = the master secret of the session being resumed
    */
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// Returns: the running handshake hash (mutable)
    ref HandshakeHash hash()
    {
        return m_handshake_hash;
    }

    /// Returns: the running handshake hash (read-only)
    ref const(HandshakeHash) hash() const
    {
        return m_handshake_hash;
    }

    /**
    * Pass a handshake message to the callback given at construction, if any.
    * Params:
    *  msg = the message to report
    */
    void noteMessage(in HandshakeMessage msg)
    {
        if (m_msg_callback is null)
            return;

        m_msg_callback(msg);
    }


private:

    // Optional per-message callback; null means none.
    void delegate(in HandshakeMessage) m_msg_callback;

    Unique!HandshakeIO m_handshake_io;

    // Bitmasks built from bitmaskForHandshakeType values.
    uint m_hand_expecting_mask = 0;
    uint m_hand_received_mask = 0;

    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;

    // Handshake messages stored by the setters above.
    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}
423 
424 
425 private:
426 
/**
* Map a handshake message type to the single bit that represents it in the
* expected/received masks.
* Params:
*  type = the handshake message type
* Returns: a mask with exactly one bit set, or 0 for HANDSHAKE_NONE
* Throws: InternalError if the type has no assigned bit
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    uint bit_index;

    switch (type)
    {
        // allow explicitly disabling new handshakes
        case HANDSHAKE_NONE:
            return 0;

        case HELLO_VERIFY_REQUEST:
            bit_index = 0;
            break;

        case HELLO_REQUEST:
            bit_index = 1;
            break;

        // Both client hello styles share one bit
        case CLIENT_HELLO, CLIENT_HELLO_SSLV2:
            bit_index = 2;
            break;

        case SERVER_HELLO:
            bit_index = 3;
            break;

        case CERTIFICATE:
            bit_index = 4;
            break;

        case CERTIFICATE_URL:
            bit_index = 5;
            break;

        case CERTIFICATE_STATUS:
            bit_index = 6;
            break;

        case SERVER_KEX:
            bit_index = 7;
            break;

        case CERTIFICATE_REQUEST:
            bit_index = 8;
            break;

        case SERVER_HELLO_DONE:
            bit_index = 9;
            break;

        case CERTIFICATE_VERIFY:
            bit_index = 10;
            break;

        case CLIENT_KEX:
            bit_index = 11;
            break;

        case NEW_SESSION_TICKET:
            bit_index = 12;
            break;

        case HANDSHAKE_CCS:
            bit_index = 13;
            break;

        case FINISHED:
            bit_index = 14;
            break;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }

    return 1u << bit_index;
}
488 
489 
490 
/**
* Choose the hash to use for a signature we are about to create.
*
* For versions without negotiable signature algorithms the hash is fixed by
* the signature algorithm. Otherwise the first hash in the policy's
* allowed list that the counterparty advertised together with sig_algo is
* used, falling back to "SHA-1".
* Params:
*  sig_algo = name of the signature algorithm ("RSA", "DSA", "ECDSA")
*  negotiated_version = the protocol version in use
*  policy = supplies the ordered list of allowed signature hashes
*  for_client_auth = true to consult cert_req, false to consult client_hello
*  client_hello = source of supported algorithms when not doing client auth
*  cert_req = source of supported algorithms when doing client auth
* Returns: the hash name
* Throws: InternalError for an unknown sig_algo on a non-negotiable version
*/
string chooseHash(in string sig_algo,
                   TLSProtocolVersion negotiated_version,
                   in TLSPolicy policy,
                   bool for_client_auth,
                   in ClientHello client_hello,
                   in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        // NOTE(review): this returns "Raw" for any key type, whereas
        // HandshakeState.understandSigFormat only uses "Raw" for RSA and
        // DSA under SSLv3 client auth - confirm ECDSA cannot reach here.
        if (for_client_auth && negotiated_version == TLSProtocolVersion.SSL_V3)
            return "Raw";

        switch (sig_algo)
        {
            case "RSA":
                return "Parallel(MD5,SHA-160)";

            case "DSA", "ECDSA":
                return "SHA-1";

            default:
                throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
        }
    }

    Vector!(Pair!(string, string)) peer_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    // TLS v1.2 default hash if the counterparty sent nothing
    if (peer_algos.empty())
        return "SHA-1";

    const Vector!string preferred = policy.allowedSignatureHashes();

    /*
    * Walk our hashes in preference order and take the first one the
    * counterparty offered in combination with our signature algorithm.
    */
    foreach (candidate; preferred[])
    {
        foreach (offer; peer_algos[])
        {
            if (offer.first == candidate && offer.second == sig_algo)
                return candidate;
        }
    }

    // Nothing matched: fall back to the same default.
    return "SHA-1";
}