1 /**
2 * TLS Extensions
3 * 
4 * Copyright:
5 * (C) 2011-2012,2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.extensions;
12 
13 import std.exception : enforce;
14 import botan.constants;
15 static if (BOTAN_HAS_TLS):
/// Numeric code identifying a TLS hello extension on the wire.
public alias HandshakeExtensionType = ushort;

/**
* Extension type codes known to this module.
*/
public enum : HandshakeExtensionType {
    // Basic hello extensions
    TLSEXT_SERVER_NAME_INDICATION    = 0,
    TLSEXT_MAX_FRAGMENT_LENGTH       = 1,
    TLSEXT_CLIENT_CERT_URL           = 2,
    TLSEXT_TRUSTED_CA_KEYS           = 3,
    TLSEXT_TRUNCATED_HMAC            = 4,
    TLSEXT_STATUS_REQUEST            = 5,

    // Certificate / key-exchange / signature negotiation
    TLSEXT_CERTIFICATE_TYPES         = 9,
    TLSEXT_USABLE_ELLIPTIC_CURVES    = 10,
    TLSEXT_EC_POINT_FORMATS          = 11,
    TLSEXT_SRP_IDENTIFIER            = 12,
    TLSEXT_SIGNATURE_ALGORITHMS      = 13,
    TLSEXT_HEARTBEAT_SUPPORT         = 15,
    TLSEXT_ALPN                      = 16,
    TLSEXT_SIGNED_CERT_TIMESTAMP     = 18,
    TLSEXT_PADDING                   = 21,
    TLSEXT_EXTENDED_MASTER_SECRET    = 23,

    TLSEXT_SESSION_TICKET            = 35,

    // Non-standard / draft extensions
    TLSEXT_NPN                       = 13172,
    TLSEXT_CHANNEL_ID                = 30032,

    TLSEXT_SAFE_RENEGOTIATION        = 65281,
}
44 
45 package:
46 
47 import memutils.vector;
48 import botan.tls.magic;
49 import botan.utils.types;
50 import memutils.hashmap;
51 import botan.tls.reader;
52 import botan.tls.exceptn;
53 import botan.tls.alert;
54 import botan.utils.types : Unique;
55 import botan.utils.get_byte;
56 import std.conv : to;
57 import std.array : Appender;
58 
/**
* Interface implemented by every TLS hello extension in this module.
*/
interface Extension
{
    /**
    * Returns: the extension's type code
    */
    HandshakeExtensionType type() const;

    /**
    * Returns: the serialized payload of this extension
    */
    Vector!ubyte serialize() const;

    /**
    * Returns: true when the extension should be left out of the encoded block
    */
    @property bool empty() const;
}
80 
/**
* Next Protocol Negotiation marker: always encoded, with an empty payload.
*/
class NPN : Extension
{
public:
    static HandshakeExtensionType staticType()
    {
        return TLSEXT_NPN;
    }

    override HandshakeExtensionType type() const
    {
        return NPN.staticType();
    }

    /// The extension is always sent, even though its payload is empty.
    override @property bool empty() const
    {
        return false;
    }

    /// Returns: an empty payload
    override Vector!ubyte serialize() const
    {
        return Vector!ubyte();
    }
}
94 
/**
* OCSP Stapling (status_request extension).
*
* The payload is always the same 5 bytes: status type 0x01 (OCSP) followed by
* a zero 2-byte responder-list length and a zero 2-byte request-extensions length.
*/
class StatusRequest : Extension
{
    static HandshakeExtensionType staticType() { return TLSEXT_STATUS_REQUEST; }

    override HandshakeExtensionType type() const { return staticType(); }

    override Vector!ubyte serialize() const
    {
        // OCSP, empty responder list, empty request extensions
        immutable ubyte[5] encoding = [0x01, 0x00, 0x00, 0x00, 0x00];

        Vector!ubyte out_buf;
        out_buf.reserve(encoding.length);
        foreach (ubyte octet; encoding)
            out_buf.pushBack(octet);
        return out_buf.move();
    }

    /// Always encoded.
    override @property bool empty() const { return false; }
}
123 
/**
 * Extended master secret extension. It has no payload in either direction.
 */
class ExtendedMasterSecret : Extension
{
    /// Creates the extension for sending.
    this() {}

    /**
    * Decodes a received extension.
    * Throws: DecodingError if the body is not empty
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size == 0)
            return;
        throw new DecodingError("Invalid extended_master_secret extension");
    }

    static HandshakeExtensionType staticType() { return TLSEXT_EXTENDED_MASTER_SECRET; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Always encoded, with an empty payload.
    override @property bool empty() const { return false; }

    override Vector!ubyte serialize() const
    {
        return Vector!ubyte();
    }
}
148 
/**
* Signed Certificate Timestamp request: always encoded, with an empty payload.
*/
class SignedCertificateTimestamp : Extension
{
public:
    /// Always encoded.
    override @property bool empty() const
    {
        return false;
    }

    /// Returns: an empty payload
    override Vector!ubyte serialize() const
    {
        return Vector!ubyte();
    }

    override HandshakeExtensionType type() const
    {
        return SignedCertificateTimestamp.staticType();
    }

    static HandshakeExtensionType staticType()
    {
        return TLSEXT_SIGNED_CERT_TIMESTAMP;
    }
}
166 
/**
* Channel ID support extension.
*
* Sent with an empty payload to advertise Channel ID. A received copy is
* accepted, and any payload it carries is skipped.
*/
class ChannelIDSupport : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_CHANNEL_ID; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Returns: an empty payload
    override Vector!ubyte serialize() const
    {
        return Vector!ubyte();
    }

    /// Creates the extension for sending.
    this() {}

    /**
    * Decodes a received extension, e.g. a server confirming it supports Channel ID.
    *
    * The payload carries nothing this class keeps, but it is consumed so the
    * reader is not left in the middle of this extension's body.
    * NOTE(review): assumes the caller keeps reading further extensions from the
    * same reader after this constructor returns — confirm against the caller.
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size > 0)
            reader.discardNext(extension_size);
    }

    /// Always encoded.
    override @property bool empty() const { return false; }
}
/**
* Encrypted Channel ID extension.
*
* The payload is the 64-byte public point (x then y, each in a 32-byte
* big-endian slot) followed by an ECDSA signature over
* SHA-256(magic [|| "Resumption\0" || original handshake hash] || handshake hash).
*/
class EncryptedChannelID : Extension
{
    import botan.pubkey.pk_keys;
    import botan.pubkey.algo.ecdsa;
    import botan.math.bigint.bigint;
public:
    static HandshakeExtensionType staticType() { return TLSEXT_CHANNEL_ID; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Params:
    *  pkey = private key used for signing; it is converted with ECDSAPrivateKey
    *  hs_hash = handshake hash to sign (moved in)
    *  orig_hs_hash = original handshake hash (moved in). When non-empty, it is
    *                 signed after the "Resumption\0" prefix.
    */
    this(PrivateKey pkey, SecureVector!ubyte hs_hash, SecureVector!ubyte orig_hs_hash) {
        m_priv = pkey;
        m_hs_hash = hs_hash.move();
        m_orig_hs_hash = orig_hs_hash.move();
    }

    /**
    * Decodes a received extension. The contents are not kept (no key is set, so
    * serialize() must not be called on such an object). The payload is consumed
    * so the reader is not left in the middle of this extension's body.
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        /*
        * The (x,y) pubkey verifies the info and its hash will be saved and used as a machine identifier
        */
        if (extension_size > 0)
            reader.discardNext(extension_size);
    }

    override Vector!ubyte serialize() const
    {
        import std.algorithm : max;
        import botan.libstate.lookup;
        import botan.rng.auto_rng : AutoSeededRNG;

        Vector!ubyte buf;
        static string magic = "TLS Channel ID signature\x00";
        static string resume_magic = "Resumption\x00";
        buf.reserve(32*4);

        // Message to sign: magic || ["Resumption\0" || original hash] || handshake hash
        SecureVector!ubyte concat = cast(ubyte[])magic;
        if (m_orig_hs_hash.length > 0) {
            concat.reserve(128);
            concat ~= resume_magic;
            concat ~= m_orig_hs_hash[];
        }
        concat ~= m_hs_hash[];

        Unique!HashFunction sha256 = retrieveHash("SHA-256").clone();
        sha256.update(concat[]);
        SecureVector!ubyte channel_id_hash = sha256.finished();

        ECDSAPrivateKey ecdsa_priv = ECDSAPrivateKey(m_priv);
        const BigInt x = ecdsa_priv.publicPoint().getAffineX();
        const BigInt y = ecdsa_priv.publicPoint().getAffineY();
        size_t part_size = max(x.bytes(), y.bytes());
        enforce(part_size <= 32);

        // Each coordinate occupies a fixed 32-byte big-endian slot. binaryEncode
        // writes only the bytes() significant bytes, so the value is right-aligned
        // within its slot and the leading bytes stay zero. Writing at the slot
        // start would shift the value whenever it is shorter than 32 bytes.
        Vector!ubyte bits = Vector!ubyte(64);
        foreach (ref ubyte b; bits[])
            b = 0;
        x.binaryEncode(bits.ptr + (32 - x.bytes()));
        y.binaryEncode(bits.ptr + (64 - y.bytes()));

        buf ~= bits[];
        auto signer = scoped!ECDSASignatureOperation(ecdsa_priv);
        auto rng = scoped!AutoSeededRNG();
        auto sig = signer.sign(channel_id_hash.ptr, channel_id_hash.length, rng).unlock();
        buf ~= sig;
        return buf.move();
    }

    /// Always encoded.
    override @property bool empty() const { return false; }

private:
    PrivateKey m_priv;
    SecureVector!ubyte m_orig_hs_hash;
    SecureVector!ubyte m_hs_hash;
}
267 
/**
* EC Point formats (RFC 4492); only uncompressed is supported.
*
* Encoded as a one-byte count followed by one byte per configured format.
* The default list holds the single format code 0x00.
*/
class SupportedPointFormats : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_EC_POINT_FORMATS; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Params:
    *  formats = point format codes to advertise; a copy is stored
    */
    this(Vector!ubyte formats = Vector!ubyte([cast(ubyte)0x00]))
    {
        m_formats = formats.clone();
    }

    override Vector!ubyte serialize() const
    {
        Vector!ubyte encoded;
        encoded.reserve(4);

        // Count byte (the length is truncated to one byte), then the codes.
        encoded.pushBack(cast(ubyte) m_formats.length);
        foreach (code; m_formats[])
            encoded.pushBack(code);

        return encoded.move();
    }

    /// Always encoded.
    override @property bool empty() const { return false; }

private:
    Vector!ubyte m_formats;
}
299 
/**
* TLS Server Name Indicator extension (RFC 3546)
*
* Holds a single DNS host name. An empty name means the extension is not encoded.
*/
class ServerNameIndicator : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_SERVER_NAME_INDICATION; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Params:
    *  host_name = DNS host name to send
    */
    this(in string host_name)
    {
        //logDebug("SNI loaded with host name: ", host_name);
        m_sni_host_name = host_name;
    }

    /**
    * Decodes a received extension.
    *
    * An empty body is accepted as-is; it is used by the server to confirm that it
    * knew the name. Otherwise the body is a 2-byte list length followed by entries
    * of (1-byte type, 2-byte-length-prefixed name). DNS entries (type 0) are stored.
    * The first entry of any other type causes the rest of the list to be skipped.
    *
    * Throws: DecodingError if the list length disagrees with extension_size, or if
    *         a name entry extends past the end of the declared list
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size == 0)
            return;

        ushort name_bytes = reader.get_ushort();

        if (name_bytes + 2 != extension_size)
            throw new DecodingError("Bad encoding of SNI extension");

        while (name_bytes)
        {
            const ubyte name_type = reader.get_byte();
            name_bytes--;

            if (name_type == 0) // DNS
            {
                m_sni_host_name = reader.getString(2, 1, 65535);
                const size_t entry_len = 2 + m_sni_host_name.length;
                // Without this check the ushort counter would wrap around and the
                // loop would keep reading past the end of the extension.
                if (entry_len > name_bytes)
                    throw new DecodingError("Bad encoding of SNI extension");
                name_bytes = cast(ushort)(name_bytes - entry_len);
            }
            else // some other unknown name type: skip the rest of the list
            {
                reader.discardNext(name_bytes);
                name_bytes = 0;
            }
        }
    }

    /// Returns: the stored host name ("" if none)
    string hostName() const { return m_sni_host_name; }

    /**
    * Returns: 2-byte list length, name type 0 (DNS), 2-byte name length, name bytes
    */
    override Vector!ubyte serialize() const
    {
        Vector!ubyte buf;

        size_t name_len = m_sni_host_name.length;

        buf.pushBack(get_byte(0, cast(ushort) (name_len+3)));
        buf.pushBack(get_byte(1, cast(ushort) (name_len+3)));
        buf.pushBack(0); // DNS

        buf.pushBack(get_byte(0, cast(ushort) name_len));
        buf.pushBack(get_byte(1, cast(ushort) name_len));

        buf ~= (cast(const(ubyte)*)m_sni_host_name.ptr)[0 .. m_sni_host_name.length];

        return buf.move();
    }

    /// Not encoded when no host name is set.
    override @property bool empty() const { return m_sni_host_name == ""; }
private:
    string m_sni_host_name;
}
371 
/**
* SRP identifier extension (RFC 5054)
*
* Carries the SRP identifier as a 1-byte-length-prefixed string.
*/
class SRPIdentifier : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_SRP_IDENTIFIER; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Params:
    *  identifier = SRP identifier to send
    */
    this(in string identifier)
    {
        m_srp_identifier = identifier;
    }

    /**
    * Decodes a received extension.
    *
    * Throws: DecodingError if the identifier plus its length byte does not fill
    *         exactly extension_size bytes
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        m_srp_identifier = reader.getString(1, 1, 255);

        if (m_srp_identifier.length + 1 != extension_size)
            throw new DecodingError("Bad encoding for SRP identifier extension");
    }
    // (A stray duplicate bodiless declaration of the reader constructor was removed here.)

    /// Returns: the SRP identifier
    string identifier() const { return m_srp_identifier; }

    override Vector!ubyte serialize() const
    {
        Vector!ubyte buf;

        const(ubyte)* srp_bytes = cast(const(ubyte)*) m_srp_identifier.ptr;

        appendTlsLengthValue(buf, srp_bytes, m_srp_identifier.length, 1);

        return buf.move();
    }

    /// Not encoded when the identifier is empty.
    override @property bool empty() const { return m_srp_identifier == ""; }
private:
    string m_srp_identifier;
}
415 
/**
* Renegotiation Indication Extension (RFC 5746)
*
* Carries the renegotiation data as a 1-byte-length-prefixed vector.
*/
class RenegotiationExtension : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_SAFE_RENEGOTIATION; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Creates an extension with no renegotiation data.
    this() {}

    /**
    * Params:
    *  bits = renegotiation data (moved in)
    */
    this(Vector!ubyte bits)
    {
        m_reneg_data = bits.move();
    }

    /**
    * Decodes a received extension.
    * Throws: DecodingError if the inner length does not match extension_size
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        m_reneg_data = reader.getRange!ubyte(1, 0, 255);

        const size_t encoded_len = m_reneg_data.length + 1;
        if (encoded_len != extension_size)
            throw new DecodingError("Bad encoding for secure renegotiation extn");
    }

    /// Returns: the stored renegotiation data
    ref const(Vector!ubyte) renegotiationInfo() const { return m_reneg_data; }

    override Vector!ubyte serialize() const
    {
        Vector!ubyte encoded;
        appendTlsLengthValue(encoded, m_reneg_data, 1);
        return encoded.move();
    }

    // NOTE(review): with no data this reports empty, so the extension is not
    // encoded at all. RFC 5746 has handshakes where an *empty* renegotiation_info
    // must be sent; confirm callers cover that case (e.g. with the SCSV).
    override @property bool empty() const { return m_reneg_data.empty; }

private:
    Vector!ubyte m_reneg_data;
}
455 
/**
* Maximum Fragment Length Negotiation Extension (RFC 4366 sec 3.2)
*
* The payload is one code byte: 1 => 512, 2 => 1024, 3 => 2048, 4 => 4096 bytes.
*/
class MaximumFragmentLength : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_MAX_FRAGMENT_LENGTH; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Always encoded.
    override @property bool empty() const { return false; }

    /// Returns: the negotiated maximum fragment size in bytes
    size_t fragmentSize() const { return m_max_fragment; }

    /**
    * Returns: the single code byte for the configured fragment size
    * Throws: InvalidArgument if the size is not 512, 1024, 2048 or 4096
    */
    override Vector!ubyte serialize() const
    {
        ubyte code;
        switch (m_max_fragment)
        {
            case 512:  code = 1; break;
            case 1024: code = 2; break;
            case 2048: code = 3; break;
            case 4096: code = 4; break;
            default:
                throw new InvalidArgument("Bad setting " ~ to!string(m_max_fragment) ~ " for maximum fragment size");
        }

        return Vector!ubyte([code]);
    }

    /**
    * Params:
    *  max_fragment = specifies what maximum fragment size to
    *          advertise. Currently must be one of 512, 1024, 2048, or
    *          4096.
    */
    this(size_t max_fragment)
    {
        m_max_fragment = max_fragment;
    }

    /**
    * Decodes a received extension.
    *
    * Throws: DecodingError if the body is not exactly one byte;
    *         TLSException(ILLEGAL_PARAMETER) if the code is not 1..4.
    *         Code 0 used to be accepted and yielded a 0-byte fragment size.
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size != 1)
            throw new DecodingError("Bad size for maximum fragment extension");

        const ubyte val = reader.get_byte();

        switch (val)
        {
            case 1: m_max_fragment = 512;  break;
            case 2: m_max_fragment = 1024; break;
            case 3: m_max_fragment = 2048; break;
            case 4: m_max_fragment = 4096; break;
            default:
                throw new TLSException(TLSAlert.ILLEGAL_PARAMETER, "Bad value in maximum fragment extension");
        }
    }

private:
    size_t m_max_fragment;
}
516 
/**
* ALPN (RFC 7301)
*
* Holds an ordered list of protocol names. The client sends the list it offers.
* The server's response is expected to carry exactly one name (see singleProtocol).
*/
class ApplicationLayerProtocolNotification : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_ALPN; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Returns: the stored protocol names, in order
    ref const(Vector!string) protocols() const { return m_protocols; }

    /**
    * Empty extension (no protocols; not encoded because empty() is true)
    */
    this() {}

    /**
    * List of protocols, used by client
    */
    this(Vector!string protocols)
    {
        m_protocols = protocols.move();
    }

    /**
    * Single protocol, used by server
    */
    this(string protocol) {
        m_protocols.length = 1;
        m_protocols[0] = protocol;
    }

    /**
    * Decodes a received extension: a 2-byte list length, then 1-byte-length-prefixed names.
    *
    * Throws: DecodingError on inconsistent lengths or an empty protocol name
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size == 0)
            return; // empty extension

        // The list length alone needs 2 bytes; don't read it out of a 1-byte body.
        if (extension_size < 2)
            throw new DecodingError("Bad encoding of ALPN extension, bad length field");

        const ushort name_bytes = reader.get_ushort();

        size_t bytes_remaining = extension_size - 2;

        if (name_bytes != bytes_remaining)
            throw new DecodingError("Bad encoding of ALPN extension, bad length field");

        while (bytes_remaining)
        {
            const string p = reader.getString(1, 0, 255);

            if (bytes_remaining < p.length + 1)
                throw new DecodingError("Bad encoding of ALPN, length field too long");

            // RFC 7301 section 3.1: empty protocol names must not be included.
            if (p.length == 0)
                throw new DecodingError("Empty ALPN protocol not allowed");

            bytes_remaining -= (p.length + 1);
            //logDebug("Got protocol: ", p);
            m_protocols.pushBack(p);
        }
    }

    /**
    * Returns: the only stored protocol
    * Throws: TLSException(HANDSHAKE_FAILURE) unless exactly one protocol is stored
    */
    ref string singleProtocol() const
    {
        if (m_protocols.length != 1)
            throw new TLSException(TLSAlert.HANDSHAKE_FAILURE, "Server sent " ~ m_protocols.length.to!string ~ " protocols in ALPN extension response");

        return m_protocols[0];
    }

    /**
    * Returns: 2-byte list length followed by 1-byte-length-prefixed names;
    *          empty names are skipped
    * Throws: TLSException(INTERNAL_ERROR) for a name of 256 bytes or more
    */
    override Vector!ubyte serialize() const
    {
        Vector!ubyte buf = Vector!ubyte(2); // placeholder for the list length

        foreach (ref p; m_protocols)
        {
            if (p.length >= 256)
                throw new TLSException(TLSAlert.INTERNAL_ERROR, "ALPN name too long");
            if (p != "")
                appendTlsLengthValue(buf, cast(const(ubyte)*) p.ptr, p.length, 1);
        }
        ushort len = cast(ushort)( buf.length - 2 );
        buf[0] = get_byte!ushort(0, len);
        buf[1] = get_byte!ushort(1, len);

        return buf.move();
    }

    /// Not encoded when no protocols are stored.
    override @property bool empty() const { return m_protocols.empty; }
private:
    Vector!string m_protocols;
}
602 
/**
* TLSSession Ticket Extension (RFC 5077)
*
* The payload is the raw ticket bytes with no inner length prefix.
*/
class SessionTicket : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_SESSION_TICKET; }

    override HandshakeExtensionType type() const { return staticType(); }

    /// Creates an empty extension (used by both client and server).
    this() {}

    /**
    * Extension carrying a ticket, used by client.
    * Params:
    *  session_ticket = ticket bytes (moved in)
    */
    this(Vector!ubyte session_ticket)
    {
        m_ticket = session_ticket.move();
    }

    /// Deserializes a ticket by reading the whole extension body as ticket bytes.
    this(ref TLSDataReader reader, ushort extension_size)
    {
        m_ticket = reader.getElem!(ubyte, Vector!ubyte)(extension_size);
    }

    /// Returns: the stored ticket bytes
    ref const(Vector!ubyte) contents() const { return m_ticket; }

    /// Returns: a copy of the ticket bytes
    override Vector!ubyte serialize() const
    {
        return m_ticket.clone();
    }

    /// Always encoded, even when the ticket is empty.
    override @property bool empty() const { return false; }

private:
    Vector!ubyte m_ticket;
}
645 
/**
* Supported Elliptic Curves Extension (RFC 4492)
*
* Holds curve names. The names map to 16-bit ids 1..29 through a single table,
* so curveIdToName and nameToCurveId always agree.
*/
class SupportedEllipticCurves : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_USABLE_ELLIPTIC_CURVES; }

    override HandshakeExtensionType type() const { return staticType(); }

    // Curve names indexed by (curve id - 1); ids 1..29 are contiguous.
    private static immutable string[] s_curve_names = [
        /* unsupported (ids 1..14) */
        "sect163k1", "sect163r1", "sect163r2", "sect193r1", "sect193r2",
        "sect233k1", "sect233r1", "sect239k1", "sect283k1", "sect283r1",
        "sect409k1", "sect409r1", "sect571k1", "sect571r1",
        /* supported (ids 15..29) */
        "secp160k1", "secp160r1", "secp160r2", "secp192k1", "secp192r1",
        "secp224k1", "secp224r1", "secp256k1", "secp256r1", "secp384r1",
        "secp521r1", "brainpool256r1", "brainpool384r1", "brainpool512r1",
        "x25519",
    ];

    /**
    * Returns: the curve name for a wire id, or "" for an id we don't know or support
    */
    static string curveIdToName(ushort id)
    {
        if (id == 0 || id > s_curve_names.length)
            return ""; // something we don't know or support
        return s_curve_names[id - 1];
    }

    /**
    * Returns: the wire id for a curve name
    * Throws: InvalidArgument for an unknown name
    */
    static ushort nameToCurveId(in string name)
    {
        foreach (idx, known; s_curve_names)
        {
            if (known == name)
                return cast(ushort)(idx + 1);
        }

        throw new InvalidArgument("name_to_curve_id unknown name " ~ name);
    }

    /// Returns: the stored curve names, in order
    ref const(Vector!string) curves() const { return m_curves; }

    /**
    * Returns: 2-byte list length, the optional GREASE value, then one 2-byte id per curve
    * Throws: InvalidArgument if a stored curve name is unknown
    */
    override Vector!ubyte serialize() const
    {
        Vector!ubyte buf;
        buf.reserve(2 + (m_grease > 0 ? 2 : 0) + m_curves.length * 2);
        buf.length = 2; // placeholder for the list length

        if (m_grease > 0) {
            buf.pushBack(get_byte(0, m_grease));
            buf.pushBack(get_byte(1, m_grease));
        }

        foreach (i; 0 .. m_curves.length)
        {
            // nameToCurveId never returns 0; unknown names throw instead.
            const ushort id = nameToCurveId(m_curves[i]);
            buf.pushBack(get_byte(0, id));
            buf.pushBack(get_byte(1, id));
        }

        buf[0] = get_byte(0, cast(ushort) (buf.length-2));
        buf[1] = get_byte(1, cast(ushort) (buf.length-2));

        return buf.move();
    }

    /**
    * Params:
    *  curves = curve names to advertise (moved in)
    *  grease = optional GREASE value written before the curve ids; 0 disables it
    */
    this(Vector!string curves, const ushort grease = 0)
    {
        m_curves = curves.move();
        m_grease = grease;
    }

    /**
    * Decodes a received extension; unknown curve ids are ignored.
    * Throws: DecodingError on an inconsistent or odd list length
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        ushort len = reader.get_ushort();
        //logDebug("Got elliptic curves len: ", len, " ext size: ", extension_size);
        if (len + 2 != extension_size)
            throw new DecodingError("Inconsistent length field in elliptic curve list");

        if (len % 2 == 1)
            throw new DecodingError("Elliptic curve list of strange size");

        len /= 2;
        // Reserve only after validation, and one slot per curve id (not per byte).
        m_curves.reserve(cast(size_t)len);

        foreach (size_t i; 0 .. len)
        {
            const ushort id = reader.get_ushort();
            const string name = curveIdToName(id);
            //logDebug("Got curve name: ", name);

            if (name != "")
                m_curves.pushBack(name);
        }
    }

    /// Not encoded when no curves are stored.
    override @property bool empty() const { return m_curves.empty; }
private:
    Vector!string m_curves;
    ushort m_grease;
}
864 
/**
* Signature Algorithms Extension for TLS 1.2 (RFC 5246)
*
* Holds an ordered list of (hash name, signature name) pairs. serialize() writes
* two bytes per pair: hash code then signature code, except RSA-PSS (signature
* code 8), which is written signature byte first. A non-empty override vector
* given to the template constructor is written verbatim instead of the pairs.
*/
class SignatureAlgorithms : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_SIGNATURE_ALGORITHMS; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Returns: the hash name for a wire hash code, or "" if the code is unknown
    */
    static string hashAlgoName(ubyte code)
    {
        switch(code)
        {
            case 1:
                return "MD5"; // still recognized here
            case 2:
                return "SHA-1";
            case 3:
                return "SHA-224";
            case 4:
                return "SHA-256";
            case 5:
                return "SHA-384";
            case 6:
                return "SHA-512";
            default:
                return "";
        }
    }

    /**
    * Returns: the wire code for a hash name
    * Throws: InternalError for an unknown name
    */
    static ubyte hashAlgoCode(in string name)
    {
        if (name == "MD5")
            return 1;

        if (name == "SHA-1")
            return 2;

        if (name == "SHA-224")
            return 3;

        if (name == "SHA-256")
            return 4;

        if (name == "SHA-384")
            return 5;

        if (name == "SHA-512")
            return 6;

        throw new InternalError("Unknown hash ID " ~ name ~ " for signature_algorithms");
    }

    /**
    * Returns: the signature algorithm name for a wire code, or "" if the code is unknown
    */
    static string sigAlgoName(ubyte code)
    {
        switch(code)
        {
            case 1:
                return "RSA";
            case 2:
                return "DSA";
            case 3:
                return "ECDSA";
            case 8:
                return "RSA-PSS";
            default:
                return "";
        }
    }

    /**
    * Returns: the wire code for a signature algorithm name
    * Throws: InternalError for an unknown name
    */
    static ubyte sigAlgoCode(in string name)
    {
        if (name == "RSA")
            return 1;

        if (name == "DSA")
            return 2;

        if (name == "ECDSA")
            return 3;

        if (name == "RSA-PSS")
            return 8;

        throw new InternalError("Unknown sig ID " ~ name ~ " for signature_algorithms");
    }

    /// Returns: the stored (hash, signature) name pairs, in order
    ref const(Vector!( Pair!(string, string)  )) supportedSignatureAlgorthms() const
    {
        return m_supported_algos;
    }

    override Vector!ubyte serialize() const
    {
        Vector!ubyte buf = Vector!ubyte(2); // placeholder for the 2-byte list length
        if (m_signature_algos_override.length > 0) {
            buf ~= cast(ubyte[])m_signature_algos_override[];
        }
        else for (size_t i = 0; i != m_supported_algos.length; ++i)
        {
            try
            {
                const ubyte hash_code = hashAlgoCode(m_supported_algos[i].first);
                const ubyte sig_code = sigAlgoCode(m_supported_algos[i].second);
                if (sig_code == 8) { // RSA-PSS: only SHA-256 and stronger, signature byte first
                    if (hash_code < 4) continue;
                    buf.pushBack(sig_code);
                    buf.pushBack(hash_code);
                } else {
                    buf.pushBack(hash_code);
                    buf.pushBack(sig_code);
                }
            }
            catch (Exception)
            {
                // Names without a wire code make the *AlgoCode helpers throw;
                // such pairs are deliberately skipped.
            }
        }

        buf[0] = get_byte(0, cast(ushort) (buf.length-2));
        buf[1] = get_byte(1, cast(ushort) (buf.length-2));

        return buf.move();
    }

    /// Always encoded.
    override @property bool empty() const { return false; }

    /**
    * Builds the pair list from every hash x signature combination.
    *
    * When the first hash is neither SHA-512 nor SHA-256, pairs are grouped by
    * signature and all SHA-1 pairs (except RSA-PSS) are moved to the end.
    * Otherwise pairs are grouped by hash in the given order.
    *
    * Params:
    *  hashes = hash names
    *  sigs = signature algorithm names
    *  sig_algos = raw bytes that, when non-empty, serialize() writes verbatim
    */
    this()(auto const ref Vector!string hashes, auto const ref Vector!string sigs, auto const ref Vector!ubyte sig_algos)
    {
        // Append rather than slice-assign: the member starts empty, and assigning
        // into its zero-length slice could not grow it to hold the override bytes.
        m_signature_algos_override ~= cast(ubyte[])sig_algos[];

        // Guard the hashes[0] lookup; with no hashes both orderings add nothing.
        if (hashes.length > 0 && hashes[0] != "SHA-512" && hashes[0] != "SHA-256") {
            for (size_t j = 0; j != sigs.length; ++j)
                for (size_t i = 0; i != hashes.length; ++i)
                    if (hashes[i] != "SHA-1")
                        m_supported_algos.pushBack(makePair(hashes[i], sigs[j]));

            for (size_t j = 0; j != sigs.length; ++j)
                for (size_t i = 0; i != hashes.length; ++i)
                    if (hashes[i] == "SHA-1" && sigs[j] != "RSA-PSS")
                        m_supported_algos.pushBack(makePair(hashes[i], sigs[j]));
        } else {
            for (size_t i = 0; i != hashes.length; ++i)
                for (size_t j = 0; j != sigs.length; ++j)
                    m_supported_algos.pushBack(makePair(hashes[i], sigs[j]));
        }
    }

    /**
    * Decodes a received extension; unknown (hash, signature) codes are ignored.
    * Throws: DecodingError on an inconsistent or odd list length
    */
    this(ref TLSDataReader reader,
         ushort extension_size)
    {
        ushort len = reader.get_ushort();

        if (len + 2 != extension_size)
            throw new DecodingError("Bad encoding on signature algorithms extension");

        // Entries are 2 bytes each. An odd length would make `len -= 2` below wrap
        // around and keep reading past the end of the extension.
        if (len % 2 != 0)
            throw new DecodingError("Bad encoding on signature algorithms extension");

        while (len)
        {
            ubyte hash_byte = reader.get_byte();
            ubyte sig_byte = reader.get_byte();
            if (hash_byte == 0x08) { // RSA-PSS is encoded signature byte first
                import std.algorithm : swap;
                swap(hash_byte, sig_byte);
            }
            const string hash_code = hashAlgoName(hash_byte);
            const string sig_code = sigAlgoName(sig_byte);

            len -= 2;

            // If not something we know, ignore it completely
            if (hash_code == "" || sig_code == "")
                continue;
            //logDebug("Got signature: ", hash_code, " => ",sig_code);
            m_supported_algos.pushBack(makePair(hash_code, sig_code));
        }
    }

    /// Params: algos = (hash, signature) name pairs (moved in)
    this(Vector!( Pair!(string, string)  ) algos)
    {
        m_supported_algos = algos.move();
    }

private:
    Vector!( Pair!(string, string) ) m_supported_algos;
    Vector!ubyte m_signature_algos_override;
}
1053 
/**
* Heartbeat Extension (RFC 6520)
*
* The payload is a single mode byte: 1 when the peer is allowed to send
* heartbeats, 2 when it is not.
*/
class HeartbeatSupportIndicator : Extension
{
public:
    static HandshakeExtensionType staticType() { return TLSEXT_HEARTBEAT_SUPPORT; }

    override HandshakeExtensionType type() const { return staticType(); }

    /**
    * Params:
    *  peer_allowed_to_send = whether the peer may send heartbeat requests
    */
    this(bool peer_allowed_to_send)
    {
        m_peer_allowed_to_send = peer_allowed_to_send;
    }

    /**
    * Decodes a received extension.
    * Throws: DecodingError if the body is not one byte;
    *         TLSException(ILLEGAL_PARAMETER) for a mode other than 1 or 2
    */
    this(ref TLSDataReader reader, ushort extension_size)
    {
        if (extension_size != 1)
            throw new DecodingError("Strange size for heartbeat extension");

        const ubyte mode = reader.get_byte();

        switch (mode)
        {
            case 1:
                m_peer_allowed_to_send = true;
                break;
            case 2:
                m_peer_allowed_to_send = false;
                break;
            default:
                throw new TLSException(TLSAlert.ILLEGAL_PARAMETER, "Unknown heartbeat code " ~ to!string(mode));
        }
    }

    /// Returns: whether the peer may send heartbeat requests
    bool peerAllowedToSend() const { return m_peer_allowed_to_send; }

    override Vector!ubyte serialize() const
    {
        Vector!ubyte encoded = Vector!ubyte(1);
        if (m_peer_allowed_to_send)
            encoded[0] = 1;
        else
            encoded[0] = 2;
        return encoded.move();
    }

    /// Always encoded.
    override @property bool empty() const { return false; }

private:
    bool m_peer_allowed_to_send;
}
1096 
1097 /**
1098 * Represents a block of extensions in a hello message
1099 */
struct TLSExtensions
{
public:
    /// Returns: a copy of the stored extension type codes, in insertion order.
    Vector!HandshakeExtensionType extensionTypes() const
    {
        return m_extensions.types.clone;
    }

    /**
    * Returns: the stored extension of type T.staticType(), cast to T,
    * or T.init if no extension of that type is stored.
    */
    T get(T)() const
    {
        HandshakeExtensionType type = T.staticType();

        return cast(T)m_extensions.get(type, T.init);
    }

    /**
    * Stores extn, replacing any extension previously stored under the same type.
    * The new extension is appended, so it is serialized after the existing ones.
    */
    void add(Extension extn)
    {
        assert(extn);

        if (m_extensions.get(extn.type(), null))
            m_extensions.remove(extn.type());
        m_extensions.add(extn.type(), extn);
    }

    /**
    * Encodes the block: a 16-bit total length, then each non-empty extension
    * as (16-bit type, 16-bit length, body). Optional GREASE entries configured
    * through grease() are written before and after the regular extensions.
    *
    * Returns: the encoded block, or an empty vector if nothing was written.
    * Throws: Exception (via enforce) if an extension body or the whole block
    *         does not fit in its 16-bit length field.
    */
    Vector!ubyte serialize() const
    {
        Vector!ubyte buf = Vector!ubyte(2); // 2 bytes for length field, filled in below

        // grease first: type code followed by a zero-length body
        if (m_grease_first > 0) {
            buf.pushBack(get_byte(0, m_grease_first));
            buf.pushBack(get_byte(1, m_grease_first));
            buf.pushBack(cast(ubyte)0);
            buf.pushBack(cast(ubyte)0);
        }

        foreach (const ref Extension extn; m_extensions.extensions[])
        {
            if (extn.empty)
                continue;

            const ushort extn_code = extn.type();
            const Vector!ubyte extn_val = extn.serialize();

            // cast(ushort) below would silently truncate an oversized body length
            enforce(extn_val.length <= ushort.max,
                    "TLS extension " ~ to!string(extn_code) ~ " is too large to serialize");

            buf.pushBack(get_byte(0, extn_code));
            buf.pushBack(get_byte(1, extn_code));

            buf.pushBack(get_byte(0, cast(ushort) extn_val.length));
            buf.pushBack(get_byte(1, cast(ushort) extn_val.length));

            buf ~= extn_val[];
        }

        // grease last: type code followed by a one-byte zero body
        if (m_grease_last > 0) {
            ushort grease_last = m_grease_last;
            // flip bits so the two GREASE entries never share a type code
            if (m_grease_first == m_grease_last)
                grease_last ^= 0x1010;
            buf.pushBack(get_byte(0, grease_last));
            buf.pushBack(get_byte(1, grease_last));
            buf.pushBack(cast(ubyte) 0);
            buf.pushBack(cast(ubyte) 1);
            buf.pushBack(cast(ubyte) 0);
        }

        enforce(buf.length - 2 <= ushort.max, "TLS extensions block is too large to serialize");
        const ushort extn_size = cast(ushort) (buf.length - 2);

        buf[0] = get_byte(0, extn_size);
        buf[1] = get_byte(1, extn_size);

        // avoid sending a completely empty extensions block
        if (buf.length == 2)
            return Vector!ubyte();

        return buf.move();
    }

    /**
    * Parses an extensions block (16-bit total length, then entries) from reader.
    * Does nothing if the reader is already exhausted. Recognised extensions are
    * stored via add(); unrecognised ones are skipped.
    *
    * Throws: DecodingError if the declared total length does not match the
    *         remaining input, or if an extension type occurs more than once.
    */
    void deserialize(ref TLSDataReader reader)
    {
        if (!reader.hasRemaining())
            return;

        const ushort all_extn_size = reader.get_ushort();

        if (reader.remainingBytes() != all_extn_size)
            throw new DecodingError("Bad extension size");

        // every type code seen in this block, including unrecognised ones
        Vector!HandshakeExtensionType seen;

        while (reader.hasRemaining())
        {
            const ushort extension_code = reader.get_ushort();
            const ushort extension_size = reader.get_ushort();

            // RFC 5246 7.4.1.4: "There MUST NOT be more than one extension of the
            // same type." Without this check a peer could silently replace an
            // extension parsed earlier in the same block.
            foreach (HandshakeExtensionType prev; seen[])
            {
                if (prev == extension_code)
                    throw new DecodingError("Peer sent duplicated extensions");
            }
            seen ~= extension_code;

            //logDebug("Got extension: ", extension_code);
            Extension extn = makeExtension(reader, extension_code, extension_size);

            if (extn)
                this.add(extn);
            else // unknown/unhandled extension
                reader.discardNext(extension_size);
        }
    }

    /**
    * Sets the GREASE type codes that serialize() writes before (first) and
    * after (last) the regular extensions; 0 disables the respective entry.
    */
    void grease(const ushort first, const ushort last) {
        m_grease_first = first;
        m_grease_last = last;
    }

    /// Pre-allocates room for n extensions in both parallel vectors.
    void reserve(size_t n)
    {
        m_extensions.types.reserve(n);
        m_extensions.extensions.reserve(n);
    }

    this(ref TLSDataReader reader) { deserialize(reader); }

private:
    HandshakeExtensions m_extensions;
    ushort m_grease_first;
    ushort m_grease_last;
}
1218 
private struct HandshakeExtensions {
private:
	// Parallel vectors: types[i] is the type code under which extensions[i] was added.
	Vector!HandshakeExtensionType types;
	Vector!Extension extensions;

	/// Returns: the index of the first entry with the given type, or size_t.max if absent.
	size_t indexOfType(HandshakeExtensionType type) const {
		size_t i;
		foreach (HandshakeExtensionType t; types[]) {
			if (t == type)
				return i;
			i++;
		}
		return size_t.max;
	}

	/// Returns: the extension stored under type, or dflt if there is none.
	Extension get(HandshakeExtensionType type, Extension dflt) const {
		const size_t idx = indexOfType(type);
		if (idx == size_t.max)
			return dflt;
		return cast() extensions[idx];
	}

	/// Appends an entry; does not check for an existing entry of the same type.
	void add(HandshakeExtensionType type, Extension ext)
	{
		types ~= type;
		extensions ~= ext;
	}

	/**
	* Removes the first entry with the given type, preserving the order of the rest.
	* Logs an error (and changes nothing) if no such entry exists.
	*/
	void remove(HandshakeExtensionType type) {
		const size_t idx = indexOfType(type);
		if (idx == size_t.max) {
			logError("Could not find a TLS extension we wanted to remove...");
			return;
		}

		Vector!HandshakeExtensionType kept_types;
		Vector!Extension kept_extensions;
		kept_types.reserve(types.length - 1);
		kept_extensions.reserve(extensions.length - 1);
		kept_types ~= types[0 .. idx];
		kept_extensions ~= extensions[0 .. idx];
		if (idx + 1 < types.length) {
			kept_types ~= types[idx + 1 .. types.length];
			kept_extensions ~= extensions[idx + 1 .. extensions.length];
		}

		// Move-assign the rebuilt vectors so both actually shrink by one. Element-wise
		// slice assignment (`types[] = kept_types[]`) cannot change a vector's length.
		types = kept_types.move();
		extensions = kept_extensions.move();
	}

}
1264 
1265 private:
1266 
/**
* Decodes one extension body of `size` bytes from `reader`, choosing the
* parser by the extension type `code`.
*
* Returns: the decoded extension, or null for a code this module does not
* parse. In that case nothing is read from `reader`, and the caller is
* expected to skip the body.
*/
Extension makeExtension(ref TLSDataReader reader, ushort code, ushort size)
{
    Extension parsed = null;

    switch (code)
    {
        case TLSEXT_SERVER_NAME_INDICATION:
            parsed = new ServerNameIndicator(reader, size);
            break;

        case TLSEXT_MAX_FRAGMENT_LENGTH:
            parsed = new MaximumFragmentLength(reader, size);
            break;

        case TLSEXT_USABLE_ELLIPTIC_CURVES:
            parsed = new SupportedEllipticCurves(reader, size);
            break;

        case TLSEXT_SRP_IDENTIFIER:
            parsed = new SRPIdentifier(reader, size);
            break;

        case TLSEXT_SIGNATURE_ALGORITHMS:
            parsed = new SignatureAlgorithms(reader, size);
            break;

        case TLSEXT_HEARTBEAT_SUPPORT:
            parsed = new HeartbeatSupportIndicator(reader, size);
            break;

        case TLSEXT_ALPN:
            parsed = new ApplicationLayerProtocolNotification(reader, size);
            break;

        case TLSEXT_EXTENDED_MASTER_SECRET:
            parsed = new ExtendedMasterSecret(reader, size);
            break;

        case TLSEXT_SESSION_TICKET:
            parsed = new SessionTicket(reader, size);
            break;

        case TLSEXT_CHANNEL_ID:
            parsed = new ChannelIDSupport(reader, size);
            break;

        case TLSEXT_SAFE_RENEGOTIATION:
            parsed = new RenegotiationExtension(reader, size);
            break;

        default:
            // not known: leave `parsed` null
            break;
    }

    return parsed;
}