1 /**
2 * TLS Handshake State
3 * 
4 * Copyright:
5 * (C) 2004-2006,2011,2012,2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.handshake_state;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 package:
16 
17 import botan.tls.handshake_hash;
18 import botan.tls.handshake_io;
19 import botan.tls.session_key;
20 import botan.tls.ciphersuite;
21 import botan.tls.exceptn;
22 import botan.tls.messages;
23 import botan.pubkey.pk_keys;
24 import botan.pubkey.pubkey;
25 import botan.kdf.kdf;
26 import botan.tls.messages;
27 import botan.tls.record;
28 
29 package:
/**
* TLS Handshake State
*
* Bookkeeping for a single TLS handshake: which handshake message types have
* been received and which are permitted next, the protocol version and
* ciphersuite in use, the running handshake hash, the derived session keys,
* and the handshake messages recorded so far.
*/
class HandshakeState
{
public:
    /**
    * Initialize the TLS Handshake State.
    *
    * Params:
    *  io = handshake I/O layer; stored in this state's Unique!HandshakeIO
    *       field, and the initial record version is queried from it
    *  msg_callback = optional observer invoked by noteMessage for every
    *                 message recorded through this state
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null)
    {
        m_msg_callback = msg_callback;
        m_handshake_io = io;
        m_version = m_handshake_io.initialRecordVersion();

        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_handshake_hash.reset();
    }

    /// Returns the handshake I/O layer supplied at construction.
    HandshakeIO handshakeIo()
    {
        return *m_handshake_io;
    }

    /**
    * Tells whether a message of the given type has already been received.
    *
    * Params:
    *  handshake_msg = the message type
    * Returns: true iff the type's bit is set in the received mask
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        return (m_hand_received_mask & bitmaskForHandshakeType(handshake_msg)) != 0;
    }

    /**
    * Marks a message type as received and verifies it was expected.
    *
    * Params:
    *  handshake_msg = the message type
    * Throws: TLSUnexpectedMessage if the type was not in the expected set
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);

        // Receipt is recorded before validation, so the "received" mask in
        // the error message already includes this message.
        m_hand_received_mask |= bit;

        if ((m_hand_expecting_mask & bit) == 0)
        {
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                           to!string(handshake_msg) ~
                                           " expected " ~ to!string(m_hand_expecting_mask) ~
                                           " received " ~ to!string(m_hand_received_mask));
        }

        // The next step is unknown here: clearing the expected set forces a
        // fresh setExpectedNext call, otherwise the next check always fails.
        m_hand_expecting_mask = 0;
    }

    /**
    * Adds a message type to the set of types accepted by the next transition.
    *
    * Params:
    *  handshake_msg = the message type
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        const uint bit = bitmaskForHandshakeType(handshake_msg);
        m_hand_expecting_mask = m_hand_expecting_mask | bit;
    }

    /// Pulls the next record from the I/O layer, telling it whether a
    /// ChangeCipherSpec is currently among the expected message types.
    NextRecord getNextHandshakeMsg()
    {
        const uint ccs_bit = bitmaskForHandshakeType(HANDSHAKE_CCS);
        return m_handshake_io.getNextRecord((m_hand_expecting_mask & ccs_bit) != 0);
    }

    /// Returns a copy of the ticket from a recorded, non-empty
    /// NewSessionTicket; otherwise the session ticket of the ClientHello.
    Vector!ubyte sessionTicket() const
    {
        auto ticket_msg = newSessionTicket();

        if (!ticket_msg || ticket_msg.ticket().empty())
            return clientHello().sessionTicket();

        return ticket_msg.ticket().dup;
    }

    /**
    * Determines the padding/encoding used to verify a counterparty's
    * signature made with `key`.
    *
    * Params:
    *  key = the counterparty's public key
    *  hash_algo = hash name sent by the counterparty ("" if none)
    *  sig_algo = signature algorithm name sent by the counterparty ("" if none)
    *  for_client_auth = currently unused by this implementation
    * Returns: the EMSA padding string and the signature format
    * Throws: DecodingError if the sent IDs are missing, inconsistent or sent
    *         with a version that does not negotiate them; InvalidArgument for
    *         key types other than RSA, DSA and ECDSA
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        const string algo_name = key.algoName;
        const bool negotiable = this.Version().supportsNegotiableSignatureAlgorithms();

        /*
        FIXME: This should check what was sent against the client hello
        preferences, or the certificate request, to ensure it was allowed
        by those restrictions.

        Or not?
        */

        if (negotiable)
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != algo_name)
                throw new DecodingError("Counterparty sent inconsistent key and sig types");
        }
        else if (hash_algo != "" || sig_algo != "")
        {
            throw new DecodingError("Counterparty sent hash/sig IDs with old version");
        }

        switch (algo_name)
        {
            case "RSA":
            {
                // Without negotiable signature algorithms the hash is fixed.
                const string digest = negotiable ? hash_algo : "Parallel(MD5,SHA-160)";
                const string padding = "EMSA3(" ~ digest ~ ")";
                return makePair(padding, IEEE_1363);
            }
            case "DSA", "ECDSA":
            {
                const string digest = negotiable ? hash_algo : "SHA-1";
                const string padding = "EMSA1(" ~ digest ~ ")";
                return makePair(padding, DER_SEQUENCE);
            }
            default:
                break;
        }

        throw new InvalidArgument(algo_name ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Chooses the padding/encoding for a signature we will produce with `key`.
    *
    * Params:
    *  key = our private key
    *  hash_algo_out = receives the chosen hash when the version supports
    *                  negotiable signature algorithms; otherwise untouched
    *  sig_algo_out = receives the key's algorithm name under the same condition
    *  for_client_auth = true when signing as the client (CertificateVerify)
    *  policy = policy supplying the allowed signature hashes
    * Returns: the EMSA padding string and the signature format
    * Throws: InvalidArgument for key types other than RSA, DSA and ECDSA
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                        ref string hash_algo_out,
                        ref string sig_algo_out,
                        bool for_client_auth,
                        in TLSPolicy policy) const
    {
        const string sig_algo = key.algoName;
        const string hash_algo = chooseHash(sig_algo, this.Version(), policy,
                                            for_client_auth, clientHello(), certReq());

        // Explicit hash/sig identifiers are only reported to the caller when
        // the version supports negotiable signature algorithms.
        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = hash_algo;
            sig_algo_out = sig_algo;
        }

        switch (sig_algo)
        {
            case "RSA":
            {
                const string padding = "EMSA3(" ~ hash_algo ~ ")";
                return makePair(padding, IEEE_1363);
            }
            case "DSA", "ECDSA":
            {
                const string padding = "EMSA1(" ~ hash_algo ~ ")";
                return makePair(padding, DER_SEQUENCE);
            }
            default:
                break;
        }

        throw new InvalidArgument(sig_algo ~ " is invalid/unknown for TLS signatures");
    }

    /// Returns the ClientHello's SRP identifier when the negotiated
    /// ciphersuite is valid and uses the "SRP_SHA" key exchange, else "".
    const(string) srpIdentifier() const
    {
        const bool uses_srp = ciphersuite().valid() && ciphersuite().kexAlgo() == "SRP_SHA";
        return uses_srp ? clientHello().srpIdentifier() : "";
    }

    /**
    * Returns the PRF for this handshake.
    *
    * Versions with ciphersuite-specific PRFs use TLS-12-PRF keyed on the
    * ciphersuite's PRF hash, with "MD5" and "SHA-1" mapped to SHA-256;
    * all other versions use "TLS-PRF".
    */
    KDF protocolSpecificPrf() const
    {
        // TLS v1.0, v1.1 and DTLS v1.0
        if (!Version().supportsCiphersuiteSpecificPrf())
            return getKdf("TLS-PRF");

        const string prf_algo = ciphersuite().prfAlgo();
        const string prf_hash = (prf_algo == "MD5" || prf_algo == "SHA-1") ? "SHA-256" : prf_algo;
        return getKdf("TLS-12-PRF(" ~ prf_hash ~ ")");
    }

    /// The protocol version currently in effect for this handshake.
    const(TLSProtocolVersion) Version() const
    {
        return m_version;
    }

    /// Stores the original handshake hash (kept for ChannelID resumption).
    void setOriginalHandshakeHash(SecureVector!ubyte orig_hs_hash)
    {
        m_orig_hs_hash = orig_hs_hash.move();
    }

    /// Replaces the protocol version in effect.
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * Handles a HelloVerifyRequest: puts its cookie into the stored
    * ClientHello, restarts the handshake hash, re-sends the ClientHello and
    * hashes the bytes returned by the send.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);

        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));

        noteMessage(*m_client_hello);
    }

    // ---- Setters: each stores the message and reports it via noteMessage ----

    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }

    /// Also resolves the negotiated ciphersuite from the ServerHello's id.
    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }

    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }

    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }

    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }

    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }

    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }

    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }

    void channelID(ChannelID channel_id)
    {
        m_channel_id = channel_id;
        noteMessage(*m_channel_id);
    }

    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }

    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }

    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    // ---- Getters: dereference the stored message wrappers ----
    // NOTE(review): when no message was recorded these presumably yield null;
    // callers such as sessionTicket() rely on that — confirm Unique semantics.

    const(ClientHello) clientHello() const
    {
        return *m_client_hello;
    }

    const(ServerHello) serverHello() const
    {
        return *m_server_hello;
    }

    const(Certificate) serverCerts() const
    {
        return *m_server_certs;
    }

    const(ServerKeyExchange) serverKex() const
    {
        return *m_server_kex;
    }

    const(CertificateReq) certReq() const
    {
        return *m_cert_req;
    }

    const(ServerHelloDone) serverHelloDone() const
    {
        return *m_server_hello_done;
    }

    const(Certificate) clientCerts() const
    {
        return *m_client_certs;
    }

    const(ClientKeyExchange) clientKex() const
    {
        return *m_client_kex;
    }

    const(CertificateVerify) clientVerify() const
    {
        return *m_client_verify;
    }

    const(ChannelID) channelID() const
    {
        return *m_channel_id;
    }

    const(NewSessionTicket) newSessionTicket() const
    {
        return *m_new_session_ticket;
    }

    const(Finished) serverFinished() const
    {
        return *m_server_finished;
    }

    const(Finished) clientFinished() const
    {
        return *m_client_finished;
    }

    /// The ciphersuite resolved from the ServerHello (TLSCiphersuite.init before that).
    ref const(TLSCiphersuite) ciphersuite() const
    {
        return m_ciphersuite;
    }

    /// The session keys last computed by computeSessionKeys.
    ref const(TLSSessionKeys) sessionKeys() const
    {
        return m_session_keys;
    }

    /// Derives session keys from a copy of the ClientKeyExchange's pre-master secret.
    void computeSessionKeys()
    {
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /// Derives session keys from a resumed master secret (the `true` flag
    /// presumably marks resumption — confirm against TLSSessionKeys).
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// The hash stored by setOriginalHandshakeHash.
    ref const(SecureVector!ubyte) originalHandshakeHash() const
    {
        return m_orig_hs_hash;
    }

    /// Mutable access to the running handshake hash.
    ref HandshakeHash hash()
    {
        return m_handshake_hash;
    }

    /// Read-only access to the running handshake hash.
    ref const(HandshakeHash) hash() const
    {
        return m_handshake_hash;
    }

    /// Forwards `msg` to the observer delegate, if one was supplied.
    void noteMessage(in HandshakeMessage msg)
    {
        if (!m_msg_callback)
            return;

        m_msg_callback(msg);
    }


private:

    // Observer passed to the constructor (may be null).
    void delegate(in HandshakeMessage) m_msg_callback;

    // Handshake I/O layer supplied at construction.
    Unique!HandshakeIO m_handshake_io;

    // Bitmasks (see bitmaskForHandshakeType) of permitted-next and
    // already-received handshake message types.
    uint m_hand_expecting_mask = 0;
    uint m_hand_received_mask = 0;
    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;
    // Used to save the original handshake hash in the session for ChannelID Resumption
    SecureVector!ubyte m_orig_hs_hash;

    // Handshake messages recorded through the setters above.
    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!ChannelID m_channel_id;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}
429 
430 
431 private:
432 
/**
* Maps a handshake message type to the single bit that represents it in the
* expected/received masks kept by HandshakeState.
*
* Params:
*  type = the handshake message type
* Returns: a one-bit mask, or 0 for HANDSHAKE_NONE
* Throws: InternalError for any type that has no assigned bit
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    switch(type)
    {
        case HELLO_VERIFY_REQUEST:
            return (1 << 0);

        case HELLO_REQUEST:
            return (1 << 1);

        case CLIENT_HELLO:
            return (1 << 2);

        case SERVER_HELLO:
            return (1 << 3);

        case CERTIFICATE:
            return (1 << 4);

        case CERTIFICATE_URL:
            return (1 << 5);

        case CERTIFICATE_STATUS:
            return (1 << 6);

        case SERVER_KEX:
            return (1 << 7);

        case CERTIFICATE_REQUEST:
            return (1 << 8);

        case SERVER_HELLO_DONE:
            return (1 << 9);

        case CERTIFICATE_VERIFY:
            return (1 << 10);

        case CLIENT_KEX:
            return (1 << 11);

        case NEW_SESSION_TICKET:
            return (1 << 12);

        case HANDSHAKE_CCS:
            return (1 << 13);

        case FINISHED:
            return (1 << 14);

        // HandshakeState records ChannelID messages, and every received
        // message passes through confirmTransitionTo, so ChannelID needs its
        // own bit; without it the transition check throws InternalError.
        // NOTE(review): assumes CHANNEL_ID is the HandshakeType member for
        // the ChannelID message — confirm in the TLS magic module.
        case CHANNEL_ID:
            return (1 << 15);

        // allow explicitly disabling new handshakes
        case HANDSHAKE_NONE:
            return 0;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }
}
493 
494 
495 
/**
* Selects the hash to pair with signature algorithm `sig_algo`.
*
* Versions without negotiable signature algorithms use a fixed hash per
* algorithm. Otherwise the first hash in our policy's preference order that
* the counterparty advertised together with `sig_algo` is chosen. The
* fallback is "SHA-1", used when nothing was advertised or nothing matched.
*
* Params:
*  sig_algo = signature algorithm name ("RSA", "DSA" or "ECDSA")
*  negotiated_version = the protocol version in use
*  policy = supplies our preferred signature hashes
*  for_client_auth = if true, the counterparty's algorithms are read from
*                    cert_req, otherwise from client_hello
*  client_hello = the ClientHello (consulted when !for_client_auth)
*  cert_req = the CertificateRequest (consulted when for_client_auth)
* Throws: InternalError for an unknown algorithm on a pre-negotiation version
*/
string chooseHash(in string sig_algo,
                  TLSProtocolVersion negotiated_version,
                  in TLSPolicy policy,
                  bool for_client_auth,
                  in ClientHello client_hello,
                  in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        switch (sig_algo)
        {
            case "RSA":
                return "Parallel(MD5,SHA-160)";
            case "DSA", "ECDSA":
                return "SHA-1";
            default:
                throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
        }
    }

    Vector!(Pair!(string, string)) peer_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    // TLS v1.2 default hash if the counterparty sent nothing
    if (peer_algos.empty())
        return "SHA-1";

    // Walk our preferences in order; the first one the peer pairs with
    // sig_algo wins.
    const Vector!string preferred = policy.allowedSignatureHashes();
    foreach (candidate; preferred[])
    {
        foreach (offered; peer_algos[])
        {
            if (offered.first == candidate && offered.second == sig_algo)
                return candidate;
        }
    }

    // No overlap with our policy: fall back to the same default.
    return "SHA-1";
}