1 /**
2 * TLS Unit tests
3 * 
4 * Copyright:
5 * (C) 2014-2015 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.test;
12 import botan.constants;
13 static if (BOTAN_TEST && BOTAN_HAS_TLS):
14 
15 import botan.math.bigint.bigint;
16 import botan.test;
17 import botan.rng.auto_rng;
18 import botan.tls.server;
19 import botan.tls.client;
20 import botan.cert.x509.pkcs10;
21 import botan.cert.x509.x509self;
22 import botan.cert.x509.x509_ca;
23 import botan.pubkey.algo.rsa;
24 import botan.codec.hex;
25 import botan.utils.types;
26 import std.stdio;
27 import std.datetime;
28 
/**
 * Credentials manager for the TLS handshake tests.
 *
 * Holds one server certificate, the CA certificate that issued it, and the
 * server's private key. Only that CA is trusted, and certificate-chain
 * verification failures are logged rather than propagated. The remaining
 * TLSCredentialsManager hooks forward to the base class, except hasPsk,
 * which always reports false.
 */
class TLSCredentialsManagerTest : TLSCredentialsManager
{
public:
    /**
     * Params:
     *   server_cert = certificate returned for "tls-server" chain requests
     *   ca_cert     = issuing CA; appended to server chains and added to the trust store
     *   server_key  = private key for server_cert, stored in the Unique!PrivateKey field m_key
     */
    this(X509Certificate server_cert, X509Certificate ca_cert, PrivateKey server_key) 
    {
        auto ca_store = new CertificateStoreInMemory;

        m_server_cert = server_cert;
        m_ca_cert = ca_cert;
        m_key = server_key;

        // The trust store contains nothing but the test CA.
        ca_store.addCertificate(m_ca_cert);
        m_stores.pushBack(ca_store);
    }

    /// Returns a copy of the trust-store list; the type and context arguments are ignored.
    override Vector!CertificateStore trustedCertificateAuthorities(in string, in string)
    {
        return m_stores.dup;
    }

    /**
     * For type "tls-server", returns [server cert, CA cert] when the server key's
     * algorithm name appears in cert_key_types. Otherwise returns an empty chain.
     */
    override Vector!X509Certificate certChain(const ref Vector!string cert_key_types, in string type, in string) 
    {
        Vector!X509Certificate chain;

        if (type != "tls-server")
            return chain.move();

        foreach (offered_type; cert_key_types[])
        {
            if (offered_type != m_key.algoName)
                continue;

            chain.pushBack(m_server_cert);
            chain.pushBack(m_ca_cert);
            break; // one match is enough; the chain is only added once
        }

        return chain.move();
    }

    /// Runs the base-class verification, but logs and swallows any Exception it
    /// throws so verification problems never abort the test handshake.
    override void verifyCertificateChain(in string type, in string purported_hostname,
                                         const ref Vector!X509Certificate cert_chain)
    {
        try
            super.verifyCertificateChain(type, purported_hostname, cert_chain);
        catch (Exception ex)
            logError("Certificate verification failed - " ~ ex.msg ~ " - but will ignore");
    }

    /// Always hands back the single server key, whatever certificate/type/context is asked for.
    override PrivateKey privateKeyFor(in X509Certificate, in string, in string)
    {
        return *m_key;
    }

    // ---- Hooks below simply forward to the base-class implementation ----

    override Vector!X509Certificate certChainSingleType(in string cert_key_type,
                                                        in string type,
                                                        in string context)
    {
        return super.certChainSingleType(cert_key_type, type, context);
    }

    override bool attemptSrp(in string type, in string context)
    {
        return super.attemptSrp(type, context);
    }

    override string srpIdentifier(in string type, in string context)
    {
        return super.srpIdentifier(type, context);
    }

    override string srpPassword(in string type, in string context, in string identifier)
    {
        return super.srpPassword(type, context, identifier);
    }

    override bool srpVerifier(in string type,
                              in string context,
                              in string identifier,
                              ref string group_name,
                              ref BigInt verifier,
                              ref Vector!ubyte salt,
                              bool generate_fake_on_unknown)
    {
        return super.srpVerifier(type, context, identifier,
                                 group_name, verifier, salt,
                                 generate_fake_on_unknown);
    }

    override string pskIdentityHint(in string type, in string context)
    {
        return super.pskIdentityHint(type, context);
    }

    override string pskIdentity(in string type, in string context, in string identity_hint)
    {
        return super.pskIdentity(type, context, identity_hint);
    }

    override SymmetricKey psk(in string type, in string context, in string identity)
    {
        return super.psk(type, context, identity);
    }

    /// Pre-shared keys are not offered by this test manager.
    override bool hasPsk()
    {
        return false;
    }

public:
    X509Certificate m_server_cert, m_ca_cert;
    Unique!PrivateKey m_key;
    Vector!CertificateStore m_stores;
}
127 
/**
 * Builds the credentials used by the handshake tests.
 *
 * Generates a 1024-bit RSA key and a self-signed "Test CA" certificate. It
 * then generates a second 1024-bit RSA key and has that CA sign a certificate
 * request for "localhost". All signatures use SHA-256, and the server
 * certificate is valid from the current UTC time for 365 days.
 *
 * Params:
 *   rng = randomness source for key generation and signing
 * Returns: a TLSCredentialsManagerTest holding the server cert, CA cert and server key
 */
TLSCredentialsManager createCreds(RandomNumberGenerator rng)
{
    enum hash_fn = "SHA-256";

    // Subject options for both certificates (no randomness involved).
    X509CertOptions authority_opts;
    authority_opts.common_name = "Test CA";
    authority_opts.country = "US";
    authority_opts.CAKey(1); // marks the cert as a CA; the argument is presumably a path limit

    X509CertOptions leaf_opts;
    leaf_opts.common_name = "localhost";
    leaf_opts.country = "US";

    // Self-signed CA.
    auto authority_key = RSAPrivateKey(rng, 1024);
    X509Certificate authority_cert = x509self.createSelfSignedCert(authority_opts, *authority_key, hash_fn, rng);

    // Server key and certificate request; the raw key is released so the
    // credentials manager can hold it.
    auto leaf_key = RSAPrivateKey(rng, 1024).release();
    PKCS10Request leaf_req = x509self.createCertReq(leaf_opts, leaf_key, hash_fn, rng);

    X509CA authority = X509CA(authority_cert, *authority_key, hash_fn);

    auto issued_at = Clock.currTime(UTC());
    X509Time not_before = X509Time(issued_at);
    X509Time not_after = X509Time(issued_at + 365.days);

    X509Certificate leaf_cert = authority.signRequest(leaf_req, rng, not_before, not_after);

    return new TLSCredentialsManagerTest(leaf_cert, authority_cert, leaf_key);
}
157 
/**
 * Runs one in-memory handshake between a TLSClient and a TLSServer and checks
 * that application data flows in both directions.
 *
 * The peers are connected through two byte queues (c2s_q / s2c_q), which are
 * drained alternately each round. Once active, the client sends "1" and the
 * server sends "2". The exchange ends when both sides have received data.
 *
 * Params:
 *   rng           = randomness source shared by both peers
 *   offer_version = version offered by the client and expected to be negotiated
 *   creds         = credentials manager shared by both peers
 *   policy        = TLS policy shared by both peers
 * Returns: 0 on success. Returns 1 if a peer throws, wrong data arrives, a
 *          version/protocol check fails, or no data is exchanged within
 *          MAX_ROUNDS rounds.
 */
size_t basicTestHandshake(RandomNumberGenerator rng,
                          TLSProtocolVersion offer_version,
                          TLSCredentialsManager creds,
                          TLSPolicy policy)
{
    // A handshake plus one application-data exchange needs only a few rounds.
    // Exceeding this bound means the peers stalled, so fail instead of spinning forever.
    enum size_t MAX_ROUNDS = 100;

    Unique!TLSSessionManagerInMemory server_sessions = new TLSSessionManagerInMemory(rng);
    Unique!TLSSessionManagerInMemory client_sessions = new TLSSessionManagerInMemory(rng);
    
    Vector!ubyte c2s_q, s2c_q, c2s_data, s2c_data;

    // Mismatches detected by callbacks/checks that do not abort the exchange.
    // Previously these were only logged and the test still reported success.
    size_t failures = 0;
    
    auto handshake_complete = delegate(in TLSSession session) {
        if (session.Version() != offer_version)
        {
            logError("Wrong version negotiated");
            ++failures;
        }
        return true;
    };
    
    // Shared by client and server.
    auto print_alert = delegate(in TLSAlert alert, in ubyte[])
    {
        if (alert.isValid())
            logError("TLSServer recvd alert " ~ alert.typeString());
    };
    
    auto save_server_data = delegate(in ubyte[] buf) {
        c2s_data ~= cast(ubyte[])buf;
    };
    
    auto save_client_data = delegate(in ubyte[] buf) {
        s2c_data ~= cast(ubyte[])buf;
    };

    auto next_protocol_chooser = delegate(in Vector!string protos) {
        if (protos.length != 2)
        {
            logError("Bad protocol size");
            ++failures;
        }
        // Only index when exactly two entries exist, to avoid out-of-bounds access.
        else if (protos[0] != "test/1" || protos[1] != "test/2")
        {
            logError("Bad protocol values: ", protos[]);
            ++failures;
        }
        return "test/3";
    };

    Vector!string protocols_offered = ["test/1", "test/2"];

    Unique!TLSServer server = new TLSServer((in ubyte[] buf) { s2c_q ~= cast(ubyte[]) buf; },
                                save_server_data,
                                print_alert,
                                handshake_complete,
                                *server_sessions,
                                creds,
                                policy,
                                rng,
                                next_protocol_chooser);

    Unique!TLSClient client = new TLSClient((in ubyte[] buf) { c2s_q ~= cast(ubyte[]) buf; },
                                save_client_data,
                                print_alert,
                                handshake_complete,
                                *client_sessions,
                                creds,
                                policy,
                                rng,
                                TLSServerInformation(),
                                offer_version,
                                protocols_offered.move);

    bool exchanged = false;
    foreach (_; 0 .. MAX_ROUNDS)
    {
        if (client.isActive())
            client.send("1");
        if (server.isActive())
        {
            if (server.applicationProtocol() != "test/3")
            {
                logError("Wrong protocol " ~ server.applicationProtocol());
                ++failures;
            }
            server.send("2");
        }
        
        /*
        * Use this as a temp value to hold the queues as otherwise they
        * might end up appending more in response to messages during the
        * handshake.
        */
        Vector!ubyte input;
        input[] = c2s_q;
        c2s_q.clear();
        try
        {
            server.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            logError("TLSServer error - " ~ e.toString());
            return 1; // a peer failing the handshake is a test failure
        }
        
        input.clear();
        input[] = s2c_q;
        s2c_q.clear();
        
        try
        {
            client.receivedData(input.ptr, input.length);
        }
        catch(Exception e)
        {
            logError("TLSClient error - " ~ e.toString());
            return 1; // a peer failing the handshake is a test failure
        }
        
        if (c2s_data.length && c2s_data[0] != '1')
        {
            logError("Error");
            return 1;
        }
        
        if (s2c_data.length && s2c_data[0] != '2')
        {
            logError("Error");
            return 1;
        }
        
        if (s2c_data.length && c2s_data.length)
        {
            exchanged = true;
            break;
        }
    }

    if (!exchanged)
    {
        logError("TLS handshake did not complete");
        return 1;
    }
    
    return failures ? 1 : 0;
}
288 
/**
 * TLS policy for the handshake tests: every protocol version is acceptable and
 * the fallback SCSV is never sent. All other settings are inherited from TLSPolicy.
 */
class TestPolicy : TLSPolicy
{
public:
    /// Accepts any protocol version.
    override bool acceptableProtocolVersion(TLSProtocolVersion) const
    {
        return true;
    }

    /// Never sends the fallback SCSV, whatever version is offered.
    override bool sendFallbackSCSV(in TLSProtocolVersion) const
    {
        return false;
    }
}
295 
static if (BOTAN_HAS_TESTS && !SKIP_TLS_TEST) unittest
{
    import core.memory : GC;
    GC.collect();
    import botan.libstate.global_state;
    auto state = globalState(); // ensure initialized
    logDebug("Testing tls/test.d ...");
    size_t errors = 0;
    
    Unique!TestPolicy default_policy = new TestPolicy;
    Unique!AutoSeededRNG rng = new AutoSeededRNG;
    Unique!TLSCredentialsManager basic_creds = createCreds(*rng);

    // One handshake per protocol version. The reported test count is derived
    // from this list so it cannot drift from the number of handshakes actually
    // run (it was previously hard-coded to 4 while only 3 handshakes ran).
    auto versions = [TLSProtocolVersion(TLSProtocolVersion.TLS_V10),
                     TLSProtocolVersion(TLSProtocolVersion.TLS_V11),
                     TLSProtocolVersion(TLSProtocolVersion.TLS_V12)];

    foreach (offered; versions)
        errors += basicTestHandshake(*rng, offered, *basic_creds, *default_policy);
    
    testReport("TLS", versions.length, errors);
}