1 /**
2 * TLS Data Reader
3 * 
4 * Copyright:
5 * (C) 2010-2011,2014 Jack Lloyd
6 * (C) 2014-2015 Etienne Cimon
7 *
8 * License:
9 * Botan is released under the Simplified BSD License (see LICENSE.md)
10 */
11 module botan.tls.reader;
12 
13 import botan.constants;
14 static if (BOTAN_HAS_TLS):
15 
16 import botan.utils.exceptn;
17 import memutils.vector;
18 import botan.utils.loadstor;
19 import botan.utils.types;
20 import botan.utils.get_byte;
21 import std.exception;
22 import std.conv : to;
/**
* Helper class for decoding TLS protocol messages.
*
* Consumes a byte buffer from the front. Every getter first checks that
* enough unread bytes remain, then reads the value in big-endian order and
* advances an internal offset. Truncated or malformed input is reported by
* throwing DecodingError, whose message is prefixed with the message type
* name given to the constructor.
*
* NOTE(review): only a pointer to the input buffer is stored, so the buffer
* passed to the constructor must outlive this reader.
*/
struct TLSDataReader
{
public:
    /**
    * Params:
    *  type      = human-readable message name, used in error messages
    *  buf_input = buffer to decode; it is referenced, not copied
    */
    this(string type, const ref Vector!ubyte buf_input) 
    {
        m_typename = type;
        m_buf = &buf_input; 
        m_offset = 0;
    }

    /// Throws DecodingError if any unread bytes are left in the buffer.
    void assertDone() const
    {
        if (hasRemaining())
            throw decodeError("Extra bytes at end of message");
    }

    /// Returns: the number of bytes that have not been consumed yet.
    size_t remainingBytes() const
    {
        return m_buf.length - m_offset;
    }

    /// Returns: true if at least one unread byte remains.
    bool hasRemaining() const
    {
        return (remainingBytes() > 0);
    }

    /// Skips `bytes` bytes; throws DecodingError if fewer remain.
    void discardNext(size_t bytes)
    {
        assertAtLeast(bytes);
        m_offset += bytes;
    }

    /**
    * Reads a 4-byte big-endian unsigned integer.
    *
    * Returns: the full 32-bit value. It is deliberately not narrowed to
    *          ushort, which would discard the upper 16 bits.
    */
    uint get_uint()
    {
        assertAtLeast(4);
        const uint result = make_uint((*m_buf)[m_offset  ], (*m_buf)[m_offset+1],
                                      (*m_buf)[m_offset+2], (*m_buf)[m_offset+3]);
        m_offset += 4;
        return result;
    }

    /// Reads a 2-byte big-endian unsigned integer.
    ushort get_ushort()
    {
        assertAtLeast(2);
        ushort result = make_ushort((*m_buf)[m_offset], (*m_buf)[m_offset+1]);
        m_offset += 2;
        return result;
    }

    /// Reads a single byte.
    ubyte get_byte()
    {
        assertAtLeast(1);
        ubyte result = (*m_buf)[m_offset];
        m_offset += 1;
        return result;
    }

    /**
    * Reads `num_elems` consecutive values of type T into a new Container.
    *
    * Multi-byte elements are decoded with loadBigEndian; single-byte
    * elements are copied directly from the buffer.
    *
    * Throws: DecodingError if the request is too large to size, or if fewer
    *         than `num_elems * T.sizeof` bytes remain.
    */
    Container getElem(T, Container)(size_t num_elems)
    {
        // Guard the byte-count multiplication against size_t wrap-around,
        // which would otherwise let an oversized request pass assertAtLeast.
        static if (T.sizeof > 1)
        {
            if (num_elems > size_t.max / T.sizeof)
                throw decodeError("Element count too large");
        }
        const size_t byte_count = num_elems * T.sizeof;
        assertAtLeast(byte_count);

        Container result = Container(num_elems);

        static if (T.sizeof > 1)
        {
            // NOTE(review): assumes loadBigEndian's second argument is an
            // element index relative to the given pointer.
            foreach (size_t i; 0 .. num_elems)
                result[i] = loadBigEndian!T(&(*m_buf)[m_offset], i);
        }
        else
            result[] = (*m_buf)[m_offset .. m_offset + num_elems];

        m_offset += byte_count;
        return result;
    }

    /**
    * Reads a length-prefixed sequence of T.
    *
    * Params:
    *  len_bytes = width of the length prefix in bytes (1 or 2)
    *  min_elems = minimum accepted element count
    *  max_elems = maximum accepted element count
    */
    Vector!T getRange(T)(size_t len_bytes,
                         size_t min_elems,
                         size_t max_elems)
    {
        const size_t num_elems = getNumElems(len_bytes, T.sizeof, min_elems, max_elems);

        return getElem!(T, Vector!T)(num_elems);
    }

    /// Same as getRange; returns the decoded elements as a Vector!T.
    Vector!T getRangeVector(T)(size_t len_bytes,
                               size_t min_elems,
                               size_t max_elems)
    {
        const size_t num_elems = getNumElems(len_bytes, T.sizeof, min_elems, max_elems);

        return getElem!(T, Vector!T)(num_elems);
    }

    /**
    * Reads a length-prefixed byte string and returns it as an immutable copy.
    * The bytes are not validated as UTF-8.
    */
    string getString(size_t len_bytes,
                     size_t min_bytes,
                     size_t max_bytes)
    {
        Vector!ubyte v = getRangeVector!ubyte(len_bytes, min_bytes, max_bytes);

        return (cast(immutable(char)*) v.ptr)[0 .. v.length].idup;
    }

    /// Reads exactly `size` elements of T (no length prefix).
    Vector!T getFixed(T)(size_t size)
    {
        return getElem!(T, Vector!T)(size);
    }

private:
    /// Reads a 1- or 2-byte big-endian length prefix; other widths are rejected.
    size_t getLengthField(size_t len_bytes)
    {
        assertAtLeast(len_bytes);

        if (len_bytes == 1)
            return get_byte();
        else if (len_bytes == 2)
            return get_ushort();

        throw decodeError("Bad length size");
    }

    /**
    * Reads a length prefix and converts the byte length into an element count.
    *
    * Throws: DecodingError if the byte length is not a multiple of `T_size`,
    *         or if the element count lies outside [min_elems, max_elems].
    */
    size_t getNumElems(size_t len_bytes,
                       size_t T_size,
                       size_t min_elems,
                       size_t max_elems)
    {
        const size_t byte_length = getLengthField(len_bytes);

        if (byte_length % T_size != 0)
            throw decodeError("Size isn't multiple of T");

        const size_t num_elems = byte_length / T_size;
        if (num_elems < min_elems || num_elems > max_elems)
            throw decodeError("Length field outside parameters");

        return num_elems;
    }

    /// Throws DecodingError unless at least `n` unread bytes remain.
    void assertAtLeast(size_t n) const
    {
        const size_t left = remainingBytes();
        if (left < n)
            throw decodeError("Expected " ~ to!string(n) ~ " bytes remaining, only " ~
                              to!string(left) ~ " left");
    }

    /// Builds (does not throw) a DecodingError prefixed with the message type.
    DecodingError decodeError(in string why) const
    {
        return new DecodingError("Invalid " ~ m_typename ~ ": " ~ why);
    }

    string m_typename;          // message name used in error messages
    const Vector!ubyte* m_buf;  // borrowed input buffer (not owned)
    size_t m_offset;            // index of the next unread byte in *m_buf
}
175 
/**
* Appends a TLS length-prefixed vector to `buf`.
*
* First writes the total byte length of the elements as a big-endian
* integer `tag_size` bytes wide. Then writes every element of `vals`,
* each in big-endian byte order.
*
* Params:
*  buf       = output buffer, appended to in place
*  vals      = pointer to the first element to encode
*  vals_size = number of elements (not bytes) available at `vals`
*  tag_size  = width of the length prefix in bytes; must be 1 or 2
*
* Throws: InvalidArgument if `tag_size` is not 1 or 2, or if the total
*         byte length does not fit in a prefix of that width.
*/
void appendTlsLengthValue(T, Alloc)(ref Vector!( ubyte, Alloc ) buf, in T* vals, 
                                    size_t vals_size, size_t tag_size)
{
    enum size_t elemWidth = T.sizeof;
    const size_t totalBytes = elemWidth * vals_size;

    if (!(tag_size == 1 || tag_size == 2))
        throw new InvalidArgument("appendTlsLengthValue: invalid tag size");

    const size_t maxEncodable = (tag_size == 1) ? 255 : 65535;
    if (totalBytes > maxEncodable)
        throw new InvalidArgument("appendTlsLengthValue: value too large");

    // Emit only the lowest `tag_size` bytes of totalBytes, most significant first.
    const size_t firstByte = totalBytes.sizeof - tag_size;
    for (size_t k = firstByte; k < totalBytes.sizeof; ++k)
        buf.pushBack(get_byte(k, totalBytes));

    foreach (size_t idx; 0 .. vals_size)
    {
        foreach (size_t b; 0 .. elemWidth)
            buf.pushBack(get_byte(b, vals[idx]));
    }
}
199 
/**
* Appends the contents of `vals` to `buf` as a TLS length-prefixed vector.
*
* Convenience overload: forwards the vector's storage pointer and element
* count to the pointer-based appendTlsLengthValue.
*/
void appendTlsLengthValue(T, Alloc, Alloc2)(ref Vector!( ubyte, Alloc ) buf, 
                                            auto const ref Vector!( T, Alloc2 ) vals, 
                                            size_t tag_size)
{
    const size_t count = vals.length;
    appendTlsLengthValue(buf, vals.ptr, count, tag_size);
}
206 
/**
* Appends `str` to `buf` as a TLS length-prefixed byte string.
*
* The string's code units are written as raw bytes, prefixed by their
* count in a `tag_size`-byte big-endian length field.
*/
void appendTlsLengthValue(Alloc)(ref Vector!( ubyte, Alloc ) buf, 
                                 in string str, size_t tag_size)
{
    // Reinterpret the string's storage as bytes; no copy is made here.
    const(ubyte)* raw = cast(const(ubyte)*) str.ptr;
    appendTlsLengthValue(buf, raw, str.length, tag_size);
}