1 /**
2 * TLS Unit tests
3 * 
4 * Copyright:
5 * (C) 2014-2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.test;
12 import botan.constants;
13 static if (BOTAN_TEST && BOTAN_HAS_TLS):
14 
15 import botan.test;
16 import botan.rng.auto_rng;
17 import botan.tls.server;
18 import botan.tls.client;
19 import botan.cert.x509.pkcs10;
20 import botan.cert.x509.x509self;
21 import botan.cert.x509.x509_ca;
22 import botan.pubkey.algo.rsa;
23 import botan.codec.hex;
24 import botan.utils.types;
25 import std.stdio;
26 import std.datetime;
27 
/**
* Credentials manager used by the TLS handshake tests.
*
* Serves one server certificate (followed by its CA certificate as the rest
* of the chain) together with the matching private key, trusts only that CA,
* and logs instead of propagating certificate verification failures.
* All remaining credential queries are forwarded unchanged to the base class.
*/
class TLSCredentialsManagerTest : TLSCredentialsManager
{
public:
    /**
    * Params:
    *   server_cert = leaf certificate presented for "tls-server" requests
    *   ca_cert     = CA certificate; appended to the server chain and added
    *                 to the single in-memory trust store
    *   server_key  = private key returned by privateKeyFor; stored in m_key
    */
    this(X509Certificate server_cert, X509Certificate ca_cert, PrivateKey server_key) 
    {
        m_server_cert = server_cert;
        m_ca_cert = ca_cert;
        m_key = server_key;

        auto trusted = new CertificateStoreInMemory;
        trusted.addCertificate(m_ca_cert);
        m_stores.pushBack(trusted);
    }

    /// Returns a copy of the trust store list (one store holding the CA certificate).
    override Vector!CertificateStore trustedCertificateAuthorities(in string, in string)
    {
        return m_stores.dup;
    }

    /**
    * Returns [server certificate, CA certificate] when `type` is "tls-server"
    * and `cert_key_types` contains the algorithm name of m_key; otherwise
    * returns an empty chain.
    */
    override Vector!X509Certificate certChain(const ref Vector!string cert_key_types, in string type, in string) 
    {
        Vector!X509Certificate result;

        if (type != "tls-server")
            return result.move();

        foreach (offered; cert_key_types[])
        {
            if (offered != m_key.algoName)
                continue;

            // First matching key type is enough; the chain is pushed only once.
            result.pushBack(m_server_cert);
            result.pushBack(m_ca_cert);
            break;
        }

        return result.move();
    }

    /**
    * Runs the base-class verification, but any Exception it throws is only
    * logged so that the tests continue with the locally generated certificates.
    */
    override void verifyCertificateChain(in string type, in string purported_hostname,
                                         const ref Vector!X509Certificate cert_chain)
    {
        try
            super.verifyCertificateChain(type, purported_hostname, cert_chain);
        catch (Exception ex)
            logError("Certificate verification failed - " ~ ex.msg ~ " - but will ignore");
    }

    /// Always returns the single server private key, whatever is asked for.
    override PrivateKey privateKeyFor(in X509Certificate, in string, in string)
    {
        return *m_key;
    }

    // ---- Interface fallthrough: forwarded to TLSCredentialsManager as-is ----

    override Vector!X509Certificate certChainSingleType(in string cert_key_type,
                                                        in string type,
                                                        in string context)
    {
        return super.certChainSingleType(cert_key_type, type, context);
    }

    override bool attemptSrp(in string type, in string context)
    {
        return super.attemptSrp(type, context);
    }

    override string srpIdentifier(in string type, in string context)
    {
        return super.srpIdentifier(type, context);
    }

    override string srpPassword(in string type, in string context, in string identifier)
    {
        return super.srpPassword(type, context, identifier);
    }

    override bool srpVerifier(in string type,
                              in string context,
                              in string identifier,
                              ref string group_name,
                              ref BigInt verifier,
                              ref Vector!ubyte salt,
                              bool generate_fake_on_unknown)
    {
        return super.srpVerifier(type, context, identifier, group_name,
                                 verifier, salt, generate_fake_on_unknown);
    }

    override string pskIdentityHint(in string type, in string context)
    {
        return super.pskIdentityHint(type, context);
    }

    override string pskIdentity(in string type, in string context, in string identity_hint)
    {
        return super.pskIdentity(type, context, identity_hint);
    }

    override SymmetricKey psk(in string type, in string context, in string identity)
    {
        return super.psk(type, context, identity);
    }

public:
    X509Certificate m_server_cert, m_ca_cert;  // leaf certificate and its issuing CA
    Unique!PrivateKey m_key;                    // private key matching m_server_cert
    Vector!CertificateStore m_stores;           // trust stores (one, containing m_ca_cert)
}
124 
/**
* Builds a throw-away PKI for the TLS tests.
*
* Creates a self-signed "Test CA" certificate and a "localhost" server
* certificate signed by that CA. Both use 1024-bit RSA keys and SHA-256,
* and the server certificate is valid from the current UTC time for 365 days.
*
* Params:
*   rng = random number generator used for key generation and signing
*
* Returns: a TLSCredentialsManagerTest holding the server certificate,
*          the CA certificate and the server private key.
*/
TLSCredentialsManager createCreds(RandomNumberGenerator rng)
{
    enum string hash_fn = "SHA-256";

    // --- Certificate authority (self-signed) ---
    auto ca_key = RSAPrivateKey(rng, 1024);

    X509CertOptions ca_opts;
    ca_opts.common_name = "Test CA";
    ca_opts.country = "US";
    ca_opts.CAKey(1);

    X509Certificate ca_cert = x509self.createSelfSignedCert(ca_opts, *ca_key, hash_fn, rng);

    // --- Server certificate (PKCS#10 request signed by the CA above) ---
    auto server_key = RSAPrivateKey(rng, 1024).release();

    X509CertOptions server_opts;
    server_opts.common_name = "localhost";
    server_opts.country = "US";

    PKCS10Request req = x509self.createCertReq(server_opts, server_key, hash_fn, rng);

    X509CA ca = X509CA(ca_cert, *ca_key, hash_fn);

    auto valid_from = Clock.currTime(UTC());
    X509Certificate server_cert = ca.signRequest(req, rng,
                                                 X509Time(valid_from),
                                                 X509Time(valid_from + 365.days));

    return new TLSCredentialsManagerTest(server_cert, ca_cert, server_key);
}
154 
/**
* Runs one in-memory TLS handshake between a TLSServer and a TLSClient.
*
* Records are shuttled between the two endpoints through local queues. Once
* a side is active it sends one application byte: "1" from client to server
* and "2" from server to client. The client offers "test/1" and "test/2" as
* application protocols, and the server's chooser answers "test/3".
*
* Params:
*   rng           = random number generator shared by both endpoints
*   offer_version = protocol version offered by the client; the negotiated
*                   session version is expected to equal it
*   creds         = credentials manager used by both endpoints
*   policy        = TLS policy used by both endpoints
*
* Returns: 0 on success, 1 on failure. Failures include an exception while
*          processing records, a wrong negotiated version or application
*          protocol, unexpected application data, and a handshake that does
*          not finish within a bounded number of rounds.
*/
size_t basicTestHandshake(RandomNumberGenerator rng,
                            TLSProtocolVersion offer_version,
                            TLSCredentialsManager creds,
                            TLSPolicy policy)
{
    // Upper bound on ping-pong rounds, so a stalled handshake fails the test instead of hanging it.
    enum size_t MAX_ROUNDS = 100;

    Unique!TLSSessionManagerInMemory server_sessions = new TLSSessionManagerInMemory(rng);
    Unique!TLSSessionManagerInMemory client_sessions = new TLSSessionManagerInMemory(rng);

    Vector!ubyte c2s_q, s2c_q, c2s_data, s2c_data;

    // Failures detected inside callbacks (previously these were only logged, never counted).
    size_t errors = 0;

    auto handshake_complete = delegate(in TLSSession session) {
        if (session.Version() != offer_version)
        {
            logError("Wrong version negotiated");
            ++errors;
        }
        return true;
    };

    auto print_alert = delegate(in TLSAlert alert, in ubyte[])
    {
        if (alert.isValid())
            logError("TLSServer recvd alert " ~ alert.typeString());
    };

    auto save_server_data = delegate(in ubyte[] buf) {
        c2s_data ~= cast(ubyte[])buf;
    };

    auto save_client_data = delegate(in ubyte[] buf) {
        s2c_data ~= cast(ubyte[])buf;
    };

    auto next_protocol_chooser = delegate(in Vector!string protos) {
        if (protos.length != 2)
        {
            logError("Bad protocol size");
            ++errors;
        }
        // Only index protos once its size is known to be 2.
        else if (protos[0] != "test/1" || protos[1] != "test/2")
        {
            logError("Bad protocol values: ", protos[]);
            ++errors;
        }
        return "test/3";
    };

    Vector!string protocols_offered = ["test/1", "test/2"];

    Unique!TLSServer server = new TLSServer((in ubyte[] buf) { s2c_q ~= cast(ubyte[]) buf; },
                                save_server_data,
                                print_alert,
                                handshake_complete,
                                *server_sessions,
                                creds,
                                policy,
                                rng,
                                next_protocol_chooser);

    Unique!TLSClient client = new TLSClient((in ubyte[] buf) { c2s_q ~= cast(ubyte[]) buf; },
                                save_client_data,
                                print_alert,
                                handshake_complete,
                                *client_sessions,
                                creds,
                                policy,
                                rng,
                                TLSServerInformation(),
                                offer_version,
                                protocols_offered.move);

    foreach (_; 0 .. MAX_ROUNDS)
    {
        if (client.isActive())
            client.send("1");
        if (server.isActive())
        {
            if (server.applicationProtocol() != "test/3")
            {
                logError("Wrong protocol " ~ server.applicationProtocol());
                ++errors;
            }
            server.send("2");
        }

        /*
        * Use this as a temp value to hold the queues as otherwise they
        * might end up appending more in response to messages during the
        * handshake.
        */
        Vector!ubyte input;
        input[] = c2s_q;
        c2s_q.clear();
        try
        {
            server.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            logError("TLSServer error - " ~ e.toString());
            return 1; // a processing failure is a test failure, not a silent stop
        }

        input.clear();
        input[] = s2c_q;
        s2c_q.clear();

        try
        {
            client.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            logError("TLSClient error - " ~ e.toString());
            return 1;
        }

        if (c2s_data.length && c2s_data[0] != '1')
        {
            logError("Error");
            return 1;
        }

        if (s2c_data.length && s2c_data[0] != '2')
        {
            logError("Error");
            return 1;
        }

        // Both directions delivered application data: the exchange is complete.
        if (s2c_data.length && c2s_data.length)
            return (errors > 0) ? 1 : 0;
    }

    logError("Handshake did not complete within the round limit");
    return 1;
}
285 
/**
* Permissive TLS policy for the handshake tests: every protocol version is
* acceptable and the fallback SCSV is never sent.
*/
class TestPolicy : TLSPolicy
{
public:
    /// Accepts any version, so each version offered by the tests can be negotiated.
    override bool acceptableProtocolVersion(TLSProtocolVersion) const
    {
        return true;
    }

    /// Always false: the fallback SCSV is not sent.
    override bool sendFallbackSCSV(in TLSProtocolVersion) const
    {
        return false;
    }
}
292 
static if (BOTAN_HAS_TESTS && !SKIP_TLS_TEST) unittest
{
    import core.memory : GC;
    GC.collect();
    import botan.libstate.global_state;
    auto state = globalState(); // ensure initialized
    logDebug("Testing tls/test.d ...");

    Unique!TestPolicy default_policy = new TestPolicy;
    Unique!AutoSeededRNG rng = new AutoSeededRNG;
    Unique!TLSCredentialsManager basic_creds = createCreds(*rng);

    // Protocol versions exercised, oldest first; the reported test count is derived from this list.
    TLSProtocolVersion[4] versions = [
        TLSProtocolVersion(TLSProtocolVersion.SSL_V3),
        TLSProtocolVersion(TLSProtocolVersion.TLS_V10),
        TLSProtocolVersion(TLSProtocolVersion.TLS_V11),
        TLSProtocolVersion(TLSProtocolVersion.TLS_V12),
    ];

    size_t errors = 0;
    foreach (ver; versions)
        errors += basicTestHandshake(*rng, ver, *basic_creds, *default_policy);

    testReport("TLS", versions.length, errors);
}