1 /**
2 * TLS Unit tests
3 * 
4 * Copyright:
5 * (C) 2014-2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.test;
12 import botan.constants;
13 static if (BOTAN_TEST && BOTAN_HAS_TLS):
14 
15 import botan.test;
16 import botan.rng.auto_rng;
17 import botan.tls.server;
18 import botan.tls.client;
19 import botan.cert.x509.pkcs10;
20 import botan.cert.x509.x509self;
21 import botan.cert.x509.x509_ca;
22 import botan.pubkey.algo.rsa;
23 import botan.codec.hex;
24 import botan.utils.types;
25 import std.stdio;
26 import std.datetime;
27 
/**
 * Credentials manager used by the TLS handshake tests.
 *
 * Hands out one server certificate (followed by its issuing CA certificate)
 * for "tls-server" requests, returns the matching private key, and trusts the
 * CA through a single in-memory certificate store. Chain-verification
 * failures are logged and then ignored. Every other hook forwards to the base
 * TLSCredentialsManager, except hasPsk which always answers false.
 */
class TLSCredentialsManagerTest : TLSCredentialsManager
{
public:
    /**
     * Params:
     *   server_cert = leaf certificate returned first by certChain
     *   ca_cert     = issuing certificate; returned after the leaf and added
     *                 to the only trusted store
     *   server_key  = key for server_cert, kept in m_key (a Unique!PrivateKey)
     */
    this(X509Certificate server_cert, X509Certificate ca_cert, PrivateKey server_key)
    {
        m_server_cert = server_cert;
        m_ca_cert = ca_cert;
        m_key = server_key;

        auto trusted = new CertificateStoreInMemory;
        trusted.addCertificate(m_ca_cert);
        m_stores.pushBack(trusted);
    }

    /// Returns a duplicate of m_stores; both string arguments are ignored.
    override Vector!CertificateStore trustedCertificateAuthorities(in string, in string)
    {
        return m_stores.dup;
    }

    /**
     * Returns [server cert, CA cert] when `type` is "tls-server" and one of
     * `cert_key_types` equals the stored key's algorithm name; otherwise an
     * empty chain. The third argument is ignored.
     */
    override Vector!X509Certificate certChain(const ref Vector!string cert_key_types, in string type, in string)
    {
        Vector!X509Certificate result;

        if (type != "tls-server")
            return result.move();

        bool key_type_offered = false;
        foreach (offered; cert_key_types[])
        {
            if (offered == m_key.algoName)
            {
                key_type_offered = true;
                break;
            }
        }

        if (key_type_offered)
        {
            result.pushBack(m_server_cert);
            result.pushBack(m_ca_cert);
        }

        return result.move();
    }

    /**
     * Runs the base-class verification. Any Exception it throws is logged
     * and swallowed, so the tests continue despite verification problems.
     */
    override void verifyCertificateChain(in string type, in string purported_hostname,
                                         const ref Vector!X509Certificate cert_chain)
    {
        try
            super.verifyCertificateChain(type, purported_hostname, cert_chain);
        catch (Exception ex)
            logError("Certificate verification failed - " ~ ex.msg ~ " - but will ignore");
    }

    /// Returns the single stored private key regardless of certificate or arguments.
    override PrivateKey privateKeyFor(in X509Certificate, in string, in string)
    {
        return *m_key;
    }

    // ---- Plain forwards to the base implementation ----

    override Vector!X509Certificate certChainSingleType(in string cert_key_type,
                                                        in string type,
                                                        in string context)
    {
        return super.certChainSingleType(cert_key_type, type, context);
    }

    override bool attemptSrp(in string type, in string context)
    {
        return super.attemptSrp(type, context);
    }

    override string srpIdentifier(in string type, in string context)
    {
        return super.srpIdentifier(type, context);
    }

    override string srpPassword(in string type, in string context, in string identifier)
    {
        return super.srpPassword(type, context, identifier);
    }

    override bool srpVerifier(in string type,
                              in string context,
                              in string identifier,
                              ref string group_name,
                              ref BigInt verifier,
                              ref Vector!ubyte salt,
                              bool generate_fake_on_unknown)
    {
        return super.srpVerifier(type, context, identifier,
                                 group_name, verifier, salt,
                                 generate_fake_on_unknown);
    }

    override string pskIdentityHint(in string type, in string context)
    {
        return super.pskIdentityHint(type, context);
    }

    override string pskIdentity(in string type, in string context, in string identity_hint)
    {
        return super.pskIdentity(type, context, identity_hint);
    }

    override SymmetricKey psk(in string type, in string context, in string identity)
    {
        return super.psk(type, context, identity);
    }

    /// This test manager never offers pre-shared keys.
    override bool hasPsk()
    {
        return false;
    }

public:
    /// Leaf certificate placed first in the "tls-server" chain.
    X509Certificate m_server_cert;
    /// Issuing certificate; second in the chain and the trusted store's content.
    X509Certificate m_ca_cert;
    /// Private key returned by privateKeyFor.
    Unique!PrivateKey m_key;
    /// Trusted stores; duplicated by trustedCertificateAuthorities.
    Vector!CertificateStore m_stores;
}
126 
/**
 * Builds a throwaway PKI for the handshake tests and wraps it in a
 * TLSCredentialsManagerTest.
 *
 * Creates:
 * - a self-signed "Test CA" certificate (RSA 1024-bit key, SHA-256), marked
 *   as a CA via CAKey(1);
 * - a "localhost" server certificate (RSA 1024-bit key), issued from a PKCS#10
 *   request and signed by that CA with SHA-256, valid from now (UTC) for 365
 *   days.
 *
 * Params:
 *   rng = source of randomness for key generation, the CSR and the signatures
 * Returns: credentials manager holding the server certificate, the CA
 *          certificate and the server private key.
 */
TLSCredentialsManager createCreds(RandomNumberGenerator rng)
{
    // --- Certificate authority -------------------------------------------
    auto authority_key = RSAPrivateKey(rng, 1024);

    X509CertOptions authority_opts;
    authority_opts.common_name = "Test CA";
    authority_opts.country = "US";
    authority_opts.CAKey(1);

    X509Certificate authority_cert =
        x509self.createSelfSignedCert(authority_opts, *authority_key, "SHA-256", rng);

    // --- Server (leaf) key and signing request ----------------------------
    // release() hands the raw key out; it ends up owned by the credentials manager.
    auto leaf_key = RSAPrivateKey(rng, 1024).release();

    X509CertOptions leaf_opts;
    leaf_opts.common_name = "localhost";
    leaf_opts.country = "US";

    PKCS10Request csr = x509self.createCertReq(leaf_opts, leaf_key, "SHA-256", rng);

    // --- CA signs the request ---------------------------------------------
    X509CA authority = X509CA(authority_cert, *authority_key, "SHA-256");

    auto issued_at = Clock.currTime(UTC());
    X509Time not_before = X509Time(issued_at);
    X509Time not_after = X509Time(issued_at + 365.days);

    X509Certificate leaf_cert = authority.signRequest(csr, rng, not_before, not_after);

    return new TLSCredentialsManagerTest(leaf_cert, authority_cert, leaf_key);
}
156 
/**
 * Runs one in-memory TLS handshake between a TLSClient and a TLSServer, then
 * exchanges one byte of application data in each direction.
 *
 * The two endpoints are connected through byte queues (c2s_q / s2c_q). Each
 * loop round drains each queue into the peer. The client offers the
 * application protocols "test/1" and "test/2", and the server's chooser
 * answers "test/3".
 *
 * Params:
 *   rng           = random number generator shared by both endpoints
 *   offer_version = version the client offers; the negotiated session
 *                   version is expected to equal it
 *   creds         = credentials manager shared by client and server
 *   policy        = TLS policy shared by client and server
 *
 * Returns: 0 on success. Returns 1 if any of the following happens:
 *   - the negotiated version, the offered protocol list or the chosen
 *     protocol is wrong;
 *   - either endpoint throws while processing records;
 *   - the wrong application byte arrives;
 *   - the exchange does not finish within max_rounds loop rounds.
 */
size_t basicTestHandshake(RandomNumberGenerator rng,
                          TLSProtocolVersion offer_version,
                          TLSCredentialsManager creds,
                          TLSPolicy policy)
{
    // Upper bound on loop rounds. A stalled handshake then fails instead of
    // spinning forever; a successful exchange needs only a handful.
    enum size_t max_rounds = 1000;

    Unique!TLSSessionManagerInMemory server_sessions = new TLSSessionManagerInMemory(rng);
    Unique!TLSSessionManagerInMemory client_sessions = new TLSSessionManagerInMemory(rng);

    Vector!ubyte c2s_q, s2c_q, c2s_data, s2c_data;

    // Set by the callbacks and checks below. The callbacks cannot report an
    // error back to this loop any other way.
    bool failed = false;

    auto handshake_complete = delegate(in TLSSession session) {
        if (session.Version() != offer_version)
        {
            logError("Wrong version negotiated");
            failed = true;
        }
        return true;
    };

    // Shared by both endpoints, despite the "TLSServer" wording of the message.
    auto print_alert = delegate(in TLSAlert alert, in ubyte[])
    {
        if (alert.isValid())
            logError("TLSServer recvd alert " ~ alert.typeString());
    };

    auto save_server_data = delegate(in ubyte[] buf) {
        c2s_data ~= cast(ubyte[])buf;
    };

    auto save_client_data = delegate(in ubyte[] buf) {
        s2c_data ~= cast(ubyte[])buf;
    };

    auto next_protocol_chooser = delegate(in Vector!string protos) {
        // Check the length first so protos[0] / protos[1] are only read when
        // both exist.
        if (protos.length != 2)
        {
            logError("Bad protocol size");
            failed = true;
        }
        else if (protos[0] != "test/1" || protos[1] != "test/2")
        {
            logError("Bad protocol values: " ~ protos[0] ~ ", " ~ protos[1]);
            failed = true;
        }
        return "test/3";
    };

    Vector!string protocols_offered = ["test/1", "test/2"];

    Unique!TLSServer server = new TLSServer((in ubyte[] buf) { s2c_q ~= cast(ubyte[]) buf; },
                                save_server_data,
                                print_alert,
                                handshake_complete,
                                *server_sessions,
                                creds,
                                policy,
                                rng,
                                next_protocol_chooser);

    Unique!TLSClient client = new TLSClient((in ubyte[] buf) { c2s_q ~= cast(ubyte[]) buf; },
                                save_client_data,
                                print_alert,
                                handshake_complete,
                                *client_sessions,
                                creds,
                                policy,
                                rng,
                                TLSServerInformation(),
                                offer_version,
                                protocols_offered.move);

    foreach (_; 0 .. max_rounds)
    {
        if (client.isActive())
            client.send("1");
        if (server.isActive())
        {
            if (server.applicationProtocol() != "test/3")
            {
                logError("Wrong protocol " ~ server.applicationProtocol());
                failed = true;
            }
            server.send("2");
        }

        /*
        * Use this as a temp value to hold the queues as otherwise they
        * might end up appending more in response to messages during the
        * handshake.
        */
        Vector!ubyte input;
        input[] = c2s_q;
        c2s_q.clear();
        try
        {
            server.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            // An exception here means the handshake failed: report it.
            logError("TLSServer error - " ~ e.toString());
            return 1;
        }

        input.clear();
        input[] = s2c_q;
        s2c_q.clear();

        try
        {
            client.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            logError("TLSClient error - " ~ e.toString());
            return 1;
        }

        if (c2s_data.length && c2s_data[0] != '1')
        {
            logError("Error");
            return 1;
        }

        if (s2c_data.length && s2c_data[0] != '2')
        {
            logError("Error");
            return 1;
        }

        // Both directions delivered their byte: the exchange is complete.
        if (s2c_data.length && c2s_data.length)
        {
            if (failed)
                return 1;
            return 0;
        }
    }

    logError("Handshake did not complete within the round limit");
    return 1;
}
287 
/**
 * Permissive TLS policy for the handshake tests.
 *
 * Accepts every protocol version, so each version under test can be
 * negotiated, and answers false to sendFallbackSCSV.
 */
class TestPolicy : TLSPolicy
{
public:
    /// Always false: the fallback SCSV is never requested by this policy.
    override bool sendFallbackSCSV(in TLSProtocolVersion) const
    {
        return false;
    }

    /// Always true: no protocol version is rejected.
    override bool acceptableProtocolVersion(TLSProtocolVersion) const
    {
        return true;
    }
}
294 
/*
 * Runs one basicTestHandshake per supported protocol version with a shared
 * permissive policy, RNG and test credentials, then reports the totals.
 */
static if (BOTAN_HAS_TESTS && !SKIP_TLS_TEST) unittest
{
    import core.memory : GC;
    GC.collect();
    import botan.libstate.global_state;
    auto state = globalState(); // ensure initialized
    logDebug("Testing tls/test.d ...");
    size_t errors = 0;

    Unique!TestPolicy default_policy = new TestPolicy;
    Unique!AutoSeededRNG rng = new AutoSeededRNG;
    Unique!TLSCredentialsManager basic_creds = createCreds(*rng);

    // The reported test count is taken from this list. It previously said 4
    // although only three versions are exercised.
    auto versions = [TLSProtocolVersion.TLS_V10,
                     TLSProtocolVersion.TLS_V11,
                     TLSProtocolVersion.TLS_V12];

    foreach (code; versions)
        errors += basicTestHandshake(*rng, TLSProtocolVersion(code), *basic_creds, *default_policy);

    testReport("TLS", versions.length, errors);
}