1 /**
2 * TLS Handshake State
3 * 
4 * Copyright:
5 * (C) 2004-2006,2011,2012,2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.handshake_state;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 package:
16 
17 import botan.tls.handshake_hash;
18 import botan.tls.handshake_io;
19 import botan.tls.session_key;
20 import botan.tls.ciphersuite;
21 import botan.tls.exceptn;
22 import botan.tls.messages;
23 import botan.pubkey.pk_keys;
24 import botan.pubkey.pubkey;
25 import botan.kdf.kdf;
26 import botan.tls.messages;
27 import botan.tls.record;
28 
29 package:
/**
* TLS Handshake State
*
* Tracks the progress of a single TLS handshake: which handshake message
* types have been received and which are expected next, the handshake
* messages themselves, the negotiated protocol version and ciphersuite,
* the running handshake hash and the derived session keys.
*/
class HandshakeState
{
public:
    /**
    * Initialize the TLS Handshake State
    *
    * Params:
    *  io = handshake I/O layer; stored in a Unique!HandshakeIO member and
    *       queried for the initial record version
    *  msg_callback = optional delegate that noteMessage() invokes for every
    *                 handshake message recorded in this state (may be null)
    */
    this(HandshakeIO io, void delegate(in HandshakeMessage) msg_callback = null)
    {
        m_ciphersuite = TLSCiphersuite.init;
        m_session_keys = TLSSessionKeys.init;
        m_msg_callback = msg_callback;
        m_handshake_io = io;
        m_version = m_handshake_io.initialRecordVersion();
        m_handshake_hash.reset();
    }

    /// Returns: the handshake I/O layer held by this state
    HandshakeIO handshakeIo() { return *m_handshake_io; }

    /**
    * Return true iff we have received a particular message already
    * Params:
    *  handshake_msg = the message type
    */
    bool receivedHandshakeMsg(HandshakeType handshake_msg) const
    {
        const uint mask = bitmaskForHandshakeType(handshake_msg);

        return cast(bool)(m_hand_received_mask & mask);
    }

    /**
    * Confirm that we were expecting this message type
    *
    * The type is recorded as received first; then, if it was not part of
    * the currently expected set, TLSUnexpectedMessage is thrown. On success
    * the expected set is cleared, so setExpectedNext() must be called again
    * before the next transition can succeed.
    *
    * Params:
    *  handshake_msg = the message type
    */
    void confirmTransitionTo(HandshakeType handshake_msg)
    {
        const uint mask = bitmaskForHandshakeType(handshake_msg);

        m_hand_received_mask |= mask;

        const bool ok = cast(bool)(m_hand_expecting_mask & mask); // overlap?

        if (!ok)
            throw new TLSUnexpectedMessage("Unexpected state transition in handshake, got " ~
                                         to!string(handshake_msg) ~
                                         " expected " ~ to!string(m_hand_expecting_mask) ~
                                         " received " ~ to!string(m_hand_received_mask));

        /* We don't know what to expect next, so force a call to
           setExpectedNext; if it doesn't happen, the next transition
           check will always fail which is what we want.
        */
        m_hand_expecting_mask = 0;
    }

    /**
    * Record that we are expecting a particular message type next
    *
    * Calls accumulate: several types may be expected at once until the
    * next successful confirmTransitionTo() clears the set.
    *
    * Params:
    *  handshake_msg = the message type
    */
    void setExpectedNext(HandshakeType handshake_msg)
    {
        m_hand_expecting_mask |= bitmaskForHandshakeType(handshake_msg);
    }

    /**
    * Fetch the next record from the handshake I/O layer, telling it whether
    * a ChangeCipherSpec (HANDSHAKE_CCS) is currently among the expected types.
    */
    NextRecord getNextHandshakeMsg()
    {
        const bool expecting_ccs = cast(bool)(bitmaskForHandshakeType(HANDSHAKE_CCS) & m_hand_expecting_mask);

        return m_handshake_io.getNextRecord(expecting_ccs);
    }

    /**
    * Returns: a copy of the ticket from the NewSessionTicket message if one
    * was recorded and its ticket is non-empty; otherwise the session ticket
    * carried in the client hello.
    */
    Vector!ubyte sessionTicket() const
    {
        if (newSessionTicket() && !newSessionTicket().ticket().empty())
            return newSessionTicket().ticket().dup;

        return clientHello().sessionTicket();
    }

    /**
    * Determine how to verify a signature produced by the counterparty.
    *
    * Params:
    *  key = the counterparty's public key
    *  hash_algo = hash name sent by the peer ("" if none was sent)
    *  sig_algo = signature algorithm name sent by the peer ("" if none)
    *  for_client_auth = unused here; kept for interface compatibility
    *
    * Returns: the padding/EMSA specification and the signature encoding.
    *
    * Throws:
    *  DecodingError if the hash/sig IDs are missing for a version that
    *  requires them, present for a version that does not allow them, or
    *  inconsistent with the key type.
    *  InvalidArgument if the resulting algorithm is not usable for TLS.
    */
    const(Pair!(string, SignatureFormat))
        understandSigFormat(in PublicKey key, string hash_algo, string sig_algo, bool for_client_auth) const
    {
        string algo_name = key.algoName;

        /*
        FIXME: This should check what was sent against the client hello
        preferences, or the certificate request, to ensure it was allowed
        by those restrictions.

        Or not?
        */

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            if (hash_algo == "")
                throw new DecodingError("Counterparty did not send hash/sig IDS");

            if (sig_algo != algo_name)
            {
                // The only tolerated mismatch is an RSA key used for an RSA-PSS
                // signature (RSA-PSS signatures are made with ordinary RSA keys).
                // Accepting any other mismatch would let the peer choose a
                // padding scheme unrelated to its certified key type.
                if (algo_name == "RSA" && sig_algo == "RSA-PSS")
                    algo_name = sig_algo;
                else
                    throw new DecodingError("Counterparty sent inconsistent key and sig types");
            }
        }
        else
        {
            if (hash_algo != "" || sig_algo != "")
                throw new DecodingError("Counterparty sent hash/sig IDs with old version");
        }

        if (algo_name == "RSA")
        {
            // Versions without negotiable signature algorithms use the fixed
            // MD5+SHA-1 concatenation for RSA.
            if (!this.Version().supportsNegotiableSignatureAlgorithms())
            {
                hash_algo = "Parallel(MD5,SHA-160)";
            }

            const string padding = "EMSA3(" ~ hash_algo ~ ")";
            return makePair(padding, IEEE_1363);
        }
        else if (algo_name == "RSA-PSS")
        {
            const string padding = "PSSR(" ~ hash_algo ~ ")";
            return makePair(padding, IEEE_1363);
        }
        else if (algo_name == "DSA" || algo_name == "ECDSA")
        {
            // Versions without negotiable signature algorithms use SHA-1 here.
            if (!this.Version().supportsNegotiableSignatureAlgorithms())
            {
                hash_algo = "SHA-1";
            }

            const string padding = "EMSA1(" ~ hash_algo ~ ")";

            return makePair(padding, DER_SEQUENCE);
        }

        throw new InvalidArgument(algo_name ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Choose the padding and encoding for a signature we will produce.
    *
    * The hash is selected by chooseHash() from our policy and the peer's
    * advertised algorithms (the certificate request when for_client_auth is
    * true, otherwise the client hello). When the negotiated version supports
    * negotiable signature algorithms, the chosen hash and signature names are
    * written to hash_algo_out / sig_algo_out; otherwise those are untouched.
    *
    * Throws: InvalidArgument if the key's algorithm is not usable for TLS.
    */
    const(Pair!(string, SignatureFormat))
        chooseSigFormat(in PrivateKey key,
                          ref string hash_algo_out,
                          ref string sig_algo_out,
                          bool for_client_auth,
                          in TLSPolicy policy) const
    {
        const string sig_algo = key.algoName;

        const string hash_algo = chooseHash(sig_algo,
                                             this.Version(),
                                             policy,
                                             for_client_auth,
                                             clientHello(),
                                             certReq());

        if (this.Version().supportsNegotiableSignatureAlgorithms())
        {
            hash_algo_out = hash_algo;
            sig_algo_out = sig_algo;
        }

        if (sig_algo == "RSA")
        {
            const string padding = "EMSA3(" ~ hash_algo ~ ")";

            return makePair(padding, IEEE_1363);
        }
        else if (sig_algo == "RSA-PSS")
        {
            const string padding = "PSSR(" ~ hash_algo ~ ")";

            return makePair(padding, IEEE_1363);
        }
        else if (sig_algo == "DSA" || sig_algo == "ECDSA")
        {
            const string padding = "EMSA1(" ~ hash_algo ~ ")";

            return makePair(padding, DER_SEQUENCE);
        }

        throw new InvalidArgument(sig_algo ~ " is invalid/unknown for TLS signatures");
    }

    /**
    * Returns: the SRP identifier from the client hello when the negotiated
    * ciphersuite is valid and uses the "SRP_SHA" key exchange, otherwise "".
    */
    const(string) srpIdentifier() const
    {
        if (ciphersuite().valid() && ciphersuite().kexAlgo() == "SRP_SHA")
            return clientHello().srpIdentifier();

        return "";
    }

    /**
    * Returns: the PRF for the negotiated version. Versions with
    * ciphersuite-specific PRFs use TLS-12-PRF with the suite's PRF hash
    * (SHA-256 when the suite names MD5 or SHA-1); older versions use TLS-PRF.
    */
    KDF protocolSpecificPrf() const
    {
        if (Version().supportsCiphersuiteSpecificPrf())
        {
            const string prf_algo = ciphersuite().prfAlgo();

            if (prf_algo == "MD5" || prf_algo == "SHA-1")
                return getKdf("TLS-12-PRF(SHA-256)");

            return getKdf("TLS-12-PRF(" ~ prf_algo ~ ")");
        }
        else
        {
            // TLS v1.0, v1.1 and DTLS v1.0
            return getKdf("TLS-PRF");
        }
    }

    /// Returns: the protocol version currently associated with this handshake
    const(TLSProtocolVersion) Version() const { return m_version; }

    /// Store the original handshake hash (used for ChannelID resumption)
    void setOriginalHandshakeHash(SecureVector!ubyte orig_hs_hash)
    {
        m_orig_hs_hash = orig_hs_hash.move();
    }

    /// Replace the protocol version associated with this handshake
    void setVersion(in TLSProtocolVersion _version)
    {
        m_version = _version;
    }

    /**
    * Handle a HelloVerifyRequest: update the stored client hello with the
    * cookie, reset the handshake hash, re-send the client hello and feed
    * the bytes returned by send() into the fresh hash.
    */
    void helloVerifyRequest(in HelloVerifyRequest hello_verify)
    {
        noteMessage(hello_verify);

        m_client_hello.updateHelloCookie(hello_verify);
        hash().reset();
        hash().update(handshakeIo().send(*m_client_hello));
        noteMessage(*m_client_hello);
    }

    /*
    * Message setters: each stores the message and reports it through
    * noteMessage(). serverHello() additionally resolves the negotiated
    * ciphersuite from the id in the message.
    */

    void clientHello(ClientHello clientHello)
    {
        m_client_hello = clientHello;
        noteMessage(*m_client_hello);
    }

    void serverHello(ServerHello server_hello)
    {
        m_server_hello = server_hello;
        m_ciphersuite = TLSCiphersuite.byId(m_server_hello.ciphersuite());
        noteMessage(*m_server_hello);
    }

    void serverCerts(Certificate server_certs)
    {
        m_server_certs = server_certs;
        noteMessage(*m_server_certs);
    }

    void serverKex(ServerKeyExchange server_kex)
    {
        m_server_kex = server_kex;
        noteMessage(*m_server_kex);
    }

    void certReq(CertificateReq cert_req)
    {
        m_cert_req = cert_req;
        noteMessage(*m_cert_req);
    }

    void serverHelloDone(ServerHelloDone server_hello_done)
    {
        m_server_hello_done = server_hello_done;
        noteMessage(*m_server_hello_done);
    }

    void clientCerts(Certificate client_certs)
    {
        m_client_certs = client_certs;
        noteMessage(*m_client_certs);
    }

    void clientKex(ClientKeyExchange client_kex)
    {
        m_client_kex = client_kex;
        noteMessage(*m_client_kex);
    }

    void clientVerify(CertificateVerify client_verify)
    {
        m_client_verify = client_verify;
        noteMessage(*m_client_verify);
    }

    void channelID(ChannelID channel_id)
    {
        m_channel_id = channel_id;
        noteMessage(*m_channel_id);
    }

    void newSessionTicket(NewSessionTicket new_session_ticket)
    {
        m_new_session_ticket = new_session_ticket;
        noteMessage(*m_new_session_ticket);
    }

    void serverFinished(Finished server_finished)
    {
        m_server_finished = server_finished;
        noteMessage(*m_server_finished);
    }

    void clientFinished(Finished client_finished)
    {
        m_client_finished = client_finished;
        noteMessage(*m_client_finished);
    }

    /*
    * Message getters: return the stored message (dereferenced from its
    * Unique holder). NOTE(review): behaviour when the message was never
    * set depends on Unique's dereference semantics — confirm before
    * relying on a null result.
    */

    const(ClientHello) clientHello() const
    { return *m_client_hello; }

    const(ServerHello) serverHello() const
    { return *m_server_hello; }

    const(Certificate) serverCerts() const
    { return *m_server_certs; }

    const(ServerKeyExchange) serverKex() const
    { return *m_server_kex; }

    const(CertificateReq) certReq() const
    { return *m_cert_req; }

    const(ServerHelloDone) serverHelloDone() const
    { return *m_server_hello_done; }

    const(Certificate) clientCerts() const
    { return *m_client_certs; }

    const(ClientKeyExchange) clientKex() const
    { return *m_client_kex; }

    const(CertificateVerify) clientVerify() const
    { return *m_client_verify; }

    const(ChannelID) channelID() const
    { return *m_channel_id; }

    const(NewSessionTicket) newSessionTicket() const
    { return *m_new_session_ticket; }

    const(Finished) serverFinished() const
    { return *m_server_finished; }

    const(Finished) clientFinished() const
    { return *m_client_finished; }

    /// Returns: the negotiated ciphersuite (TLSCiphersuite.init until serverHello is set)
    ref const(TLSCiphersuite) ciphersuite() const { return m_ciphersuite; }

    /// Returns: the session keys computed by computeSessionKeys()
    ref const(TLSSessionKeys) sessionKeys() const { return m_session_keys; }

    /// Derive session keys from the client key exchange's pre-master secret
    void computeSessionKeys()
    {
        m_session_keys = TLSSessionKeys(this, clientKex().preMasterSecret().dup, false);
    }

    /// Derive session keys from a master secret being resumed
    void computeSessionKeys()(auto ref SecureVector!ubyte resume_master_secret)
    {
        m_session_keys = TLSSessionKeys(this, resume_master_secret, true);
    }

    /// Returns: the hash stored by setOriginalHandshakeHash()
    ref const(SecureVector!ubyte) originalHandshakeHash() const { return m_orig_hs_hash; }

    /// Returns: the running handshake hash
    ref HandshakeHash hash() { return m_handshake_hash; }

    /// ditto
    ref const(HandshakeHash) hash() const { return m_handshake_hash; }

    /// Forward msg to the message callback, if one was supplied
    void noteMessage(in HandshakeMessage msg)
    {
        if (m_msg_callback)
            m_msg_callback(msg);
    }


private:

    void delegate(in HandshakeMessage) m_msg_callback;

    Unique!HandshakeIO m_handshake_io;

    // Bitmasks built from bitmaskForHandshakeType()
    uint m_hand_expecting_mask = 0;
    uint m_hand_received_mask = 0;
    TLSProtocolVersion m_version;
    TLSCiphersuite m_ciphersuite;
    TLSSessionKeys m_session_keys;
    HandshakeHash m_handshake_hash;
    // Used to save the original handshake hash in the session for ChannelID Resumption
    SecureVector!ubyte m_orig_hs_hash;

    Unique!ClientHello m_client_hello;
    Unique!ServerHello m_server_hello;
    Unique!Certificate m_server_certs;
    Unique!ServerKeyExchange m_server_kex;
    Unique!CertificateReq m_cert_req;
    Unique!ServerHelloDone m_server_hello_done;
    Unique!Certificate m_client_certs;
    Unique!ClientKeyExchange m_client_kex;
    Unique!CertificateVerify m_client_verify;
    Unique!ChannelID m_channel_id;
    Unique!NewSessionTicket m_new_session_ticket;
    Unique!Finished m_server_finished;
    Unique!Finished m_client_finished;
}
440 
441 
442 private:
443 
/*
* Map a handshake message type to its single-bit flag in the
* expected/received masks kept by HandshakeState.
*
* HANDSHAKE_NONE maps to 0, so expecting it adds nothing to the mask
* (this allows explicitly disabling new handshakes). Any type not listed
* below throws InternalError.
*/
uint bitmaskForHandshakeType(HandshakeType type)
{
    uint bit_index;

    switch(type)
    {
        case HELLO_VERIFY_REQUEST: bit_index = 0;  break;
        case HELLO_REQUEST:        bit_index = 1;  break;
        // NOTE(review): the original comment said both client hello styles
        // share this code point; only CLIENT_HELLO is handled here.
        case CLIENT_HELLO:         bit_index = 2;  break;
        case SERVER_HELLO:         bit_index = 3;  break;
        case CERTIFICATE:          bit_index = 4;  break;
        case CERTIFICATE_URL:      bit_index = 5;  break;
        case CERTIFICATE_STATUS:   bit_index = 6;  break;
        case SERVER_KEX:           bit_index = 7;  break;
        case CERTIFICATE_REQUEST:  bit_index = 8;  break;
        case SERVER_HELLO_DONE:    bit_index = 9;  break;
        case CERTIFICATE_VERIFY:   bit_index = 10; break;
        case CLIENT_KEX:           bit_index = 11; break;
        case NEW_SESSION_TICKET:   bit_index = 12; break;
        case HANDSHAKE_CCS:        bit_index = 13; break;
        case FINISHED:             bit_index = 14; break;

        case HANDSHAKE_NONE:
            return 0;

        default:
            throw new InternalError("Unknown handshake type " ~ to!string(type));
    }

    return 1u << bit_index;
}
504 
505 
506 
/*
* Pick the hash to pair with our signature algorithm.
*
* Without negotiable signature algorithms the hash is fixed per algorithm
* (MD5+SHA-1 for RSA, SHA-1 for DSA/ECDSA; anything else throws
* InternalError). Otherwise the first hash in our policy's allowed list
* that the peer advertised together with sig_algo wins. The peer's list
* comes from the certificate request for client auth, else from the client
* hello. "SHA-1" is returned when the peer advertised nothing, and also
* when none of its advertised pairs matched.
*/
string chooseHash(in string sig_algo,
                   TLSProtocolVersion negotiated_version,
                   in TLSPolicy policy,
                   bool for_client_auth,
                   in ClientHello client_hello,
                   in CertificateReq cert_req)
{
    if (!negotiated_version.supportsNegotiableSignatureAlgorithms())
    {
        switch (sig_algo)
        {
            case "RSA":
                return "Parallel(MD5,SHA-160)";
            case "DSA":
            case "ECDSA":
                return "SHA-1";
            default:
                throw new InternalError("Unknown TLS signature algo " ~ sig_algo);
        }
    }

    Vector!(Pair!(string, string)) peer_algos = for_client_auth ? cert_req.supportedAlgos() : client_hello.supportedAlgos();

    // Nothing advertised: fall back to the TLS v1.2 default.
    if (peer_algos.empty())
        return "SHA-1";

    const Vector!string preferred = policy.allowedSignatureHashes();

    // Our preference order decides; the peer's list only filters.
    foreach (candidate; preferred[])
    {
        foreach (offered; peer_algos[])
        {
            if (offered.second == sig_algo && offered.first == candidate)
                return candidate;
        }
    }

    return "SHA-1";
}