1 /**
2 Copyright: Copyright (c) 2017, Oleg Butko. All rights reserved.
3 Copyright: Copyright (c) 2018-2019, Joakim Brännström. All rights reserved.
4 License: MIT
5 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
6 Author: Oleg Butko (deviator)
7 */
8 module miniorm.api;
9 
10 import logger = std.experimental.logger;
11 
12 import std.array : Appender;
13 import std.datetime : SysTime, Duration;
14 import std.range;
15 
16 import miniorm.exception;
17 import miniorm.queries;
18 
19 import d2sqlite3;
20 
21 version (unittest) {
22     import std.algorithm : map;
23     import unit_threaded.assertions;
24 }
25 
/** Lightweight ORM front-end for an SQLite database (via d2sqlite3).
 *
 * Query objects (`Count`, `Select`, `Delete`, `Insert`) are executed through
 * the `run` overloads. Prepared INSERT statements are cached, keyed by their
 * SQL text, and finalized by `close`, by assignment and by the destructor.
 * Members not defined here are forwarded to the wrapped `Database` through
 * `alias this`.
 *
 * NOTE(review): `cachedStmt` is an associative array and D copies those by
 * reference, so copies of a `Miniorm` presumably share one statement cache
 * once it is populated; destroying one copy finalizes statements the other
 * still references. Avoid copying instances.
 */
struct Miniorm {
    /// Prepared statements, keyed by their SQL text.
    private Statement[string] cachedStmt;
    /// True means that all queries are logged.
    private bool log_;

    /// The wrapped database connection.
    Database db;

    /// Returns: a reference to the wrapped database connection.
    ref Database getUnderlyingDb() {
        return db;
    }

    // Forward members not defined here (e.g. `lastInsertRowid`) to `db`.
    alias getUnderlyingDb this;

    /// Wrap an already opened database connection.
    this(Database db) {
        this.db = db;
    }

    /// Open the database at `path` with the SQLite open `flags`.
    this(string path, int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) {
        this(Database(path, flags));
    }

    ~this() {
        cleanupCache;
    }

    /// Toggle logging.
    void log(bool v) {
        this.log_ = v;
    }

    /// Returns: True if logging is activated
    private bool isLog() {
        return log_;
    }

    /// Finalize every cached prepared statement and empty the cache.
    private void cleanupCache() {
        foreach (ref s; cachedStmt.byValue)
            s.finalize;
        cachedStmt = null;
    }

    /** Replace the wrapped connection with the one of `rhs`.
     *
     * Statements cached for the previous connection are finalized first. The
     * logging flag of `this` is left unchanged.
     */
    void opAssign(ref typeof(this) rhs) {
        cleanupCache;
        db = rhs.db;
    }

    /// Execute raw `sql`; `dg`, when given, is forwarded to `Database.run`.
    void run(string sql, bool delegate(ResultRange) dg = null) {
        if (isLog)
            logger.trace(sql);
        db.run(sql, dg);
    }

    /// Finalize the cached statements and close the connection.
    void close() {
        cleanupCache;
        db.close();
    }

    /// Returns: the first column of the first row produced by the COUNT query `v`.
    size_t run(T)(Count!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        return db.executeCheck(sql).front.front.as!size_t;
    }

    /** Execute the SELECT query `v`.
     *
     * Returns: a lazy range converting each result row into a `T`. Columns are
     * read by position, in the order produced by `fieldToCol!("", T)`.
     */
    auto run(T)(Select!T v) {
        import std.algorithm : map;
        import std.array : join;
        import std.format : format;

        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);

        auto result = db.executeCheck(sql);

        // Convert one result row into a T.
        static T qconv(typeof(result.front) e) {
            import miniorm.schema : fieldToCol;

            T ret;
            // Generate one assignment per column: DATETIME columns are parsed
            // from text, static arrays are copied from the raw bytes (truncated
            // to the shorter of the two lengths), everything else uses peek!ET.
            static string rr() {
                string[] res;
                res ~= "import std.traits : isStaticArray, OriginalType;";
                res ~= "import miniorm.api : fromSqLiteDateTime;";
                foreach (i, a; fieldToCol!("", T)()) {
                    res ~= `{`;
                    if (a.columnType == "DATETIME") {
                        res ~= `{ ret.%1$s = fromSqLiteDateTime(e.peek!string(%2$d)); }`.format(a.identifier,
                                i);
                    } else {
                        res ~= q{alias ET = typeof(ret.%s);}.format(a.identifier);
                        res ~= q{static if (isStaticArray!ET)};
                        res ~= `
                            {
                                import std.algorithm : min;
                                auto ubval = e[%2$d].as!(ubyte[]);
                                auto etval = cast(typeof(ET.init[]))ubval;
                                auto ln = min(ret.%1$s.length, etval.length);
                                ret.%1$s[0..ln] = etval[0..ln];
                            }
                            `.format(a.identifier, i);
                        res ~= q{else static if (is(ET == enum))};
                        res ~= format(q{ret.%1$s = cast(ET) e.peek!ET(%2$d);}, a.identifier, i);
                        res ~= q{else};
                        res ~= format(q{ret.%1$s = e.peek!ET(%2$d);}, a.identifier, i);
                    }
                    res ~= `}`;
                }
                return res.join("\n");
            }

            mixin(rr());
            return ret;
        }

        return result.map!qconv;
    }

    /// Execute the DELETE query `v`.
    void run(T)(Delete!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        db.run(sql);
    }

    /// Insert the values `arr`, given either as separate arguments or as an array.
    void run(AggregateInsert all = AggregateInsert.no, T0, T1)(Insert!T0 v, T1[] arr...)
            if (!isInputRange!T1) {
        procInsert!all(v, arr);
    }

    /// Insert every element of the input range `rng`.
    void run(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T v, R rng)
            if (isInputRange!R) {
        procInsert!all(v, rng);
    }

    /** Bind the elements of `rng` to the cached prepared statement of `q` and
     * execute it.
     *
     * With `AggregateInsert.yes` a single statement with `rng.length` value
     * tuples is executed (an empty range inserts nothing). Otherwise the
     * statement is executed once per element. Unless `q` is an
     * INSERT OR REPLACE, primary key fields are not bound.
     */
    private void procInsert(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T q, R rng)
            if ((all && hasLength!R) || !all) {
        import std.exception : collectException;

        // generate code for binding values in a struct to a prepared
        // statement.
        // Expects an external variable "n" to exist that keeps track of the
        // index. This is requied when the binding is for multiple values.
        // Expects the statement to be named "stmt".
        // Expects the variable to read values from to be named "v".
        // Indexing start from 1 according to the sqlite manual.
        static string genBinding(T)(bool replace) {
            import miniorm.schema : fieldToCol;

            string s;
            foreach (i, v; fieldToCol!("", T)) {
                if (!replace && v.isPrimaryKey)
                    continue;
                if (v.columnType == "DATETIME")
                    s ~= "stmt.bind(n+1, v." ~ v.identifier ~ ".toUTC.toSqliteDateTime);";
                else
                    s ~= "stmt.bind(n+1, v." ~ v.identifier ~ ");";
                s ~= "++n;";
            }
            return s;
        }

        alias Elem = ElementType!R;

        const replace = q.query.opt == InsertOpt.InsertOrReplace;

        static if (all == AggregateInsert.yes) {
            // Nothing to insert: avoid generating a VALUES clause without tuples.
            if (rng.length == 0)
                return;
            q = q.values(rng.length);
        } else
            q = q.values(1);

        const sql = q.toSql.toString;

        if (sql !in cachedStmt)
            cachedStmt[sql] = db.prepare(sql);
        auto stmt = cachedStmt[sql];

        static if (all == AggregateInsert.yes) {
            int n;
            foreach (v; rng) {
                if (replace) {
                    mixin(genBinding!Elem(true));
                } else {
                    mixin(genBinding!Elem(false));
                }
            }
            if (isLog)
                logger.trace(sql, " -> ", rng);
            // SQLite refuses to bind to a statement that was stepped but not
            // reset. Reset it even when execution fails, otherwise every later
            // insert through this cached statement fails as well.
            scope (failure)
                stmt.reset().collectException;
            stmt.execute();
            stmt.reset();
        } else {
            foreach (v; rng) {
                int n;
                if (replace) {
                    mixin(genBinding!Elem(true));
                } else {
                    mixin(genBinding!Elem(false));
                }
                if (isLog)
                    logger.trace(sql, " -> ", v);
                // See above: keep the cached statement usable after a failure.
                scope (failure)
                    stmt.reset().collectException;
                stmt.execute();
                stmt.reset();
            }
        }
    }
}
236 
/** Whether one aggregated insert or multiple inserts should be generated.
 *
 * no:
 * ---
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * ---
 *
 * yes:
 * ---
 * INSERT INTO foo ('v0') VALUES (?), (?), (?)
 * ---
 */
enum AggregateInsert {
    /// Execute the INSERT statement once per element.
    no = 0,
    /// Execute a single INSERT statement with one value tuple per element.
    yes = 1
}
255 
256 version (unittest) {
257     import miniorm.schema;
258 
259     import std.conv : text, to;
260     import std.range;
261     import std.algorithm;
262     import std.datetime;
263     import std.array;
264     import std.stdio;
265 
266     import unit_threaded.assertions;
267 }
268 
@("shall operate on a database allocated in std.experimental.allocators without any errors")
unittest {
    struct One {
        ulong id;
        string text;
    }

    import std.experimental.allocator;
    import std.experimental.allocator.mallocator;
    import std.experimental.allocator.building_blocks.scoped_allocator;

    // TODO: fix this
    //Miniorm* db;
    //ScopedAllocator!Mallocator scalloc;
    //db = scalloc.make!Miniorm(":memory:");
    //scope (exit) {
    //    db.close;
    //    scalloc.dispose(db);
    //}

    // TODO: replace the one below with the above code.
    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    // The ids passed to a plain insert are not stored: the assertions below
    // expect the database-assigned ids, which are all < 100.
    db.run(insert!One.insert, iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    auto ones = db.run(select!One).array;
    ones.length.shouldEqual(10);
    assert(ones.all!(a => a.id < 100));
    db.getUnderlyingDb.lastInsertRowid.shouldEqual(ones[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);
    // INSERT OR REPLACE binds the ids, so they are all >= 100.
    db.run(insertOrReplace!One, iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i))));
    ones = db.run(select!One).array;
    ones.length.shouldEqual(499);
    assert(ones.all!(a => a.id >= 100));
    db.lastInsertRowid.shouldEqual(ones[$ - 1].id);
}
308 
@("shall insert and extract datetime from the table")
unittest {
    import core.thread : Thread;
    import core.time : dur;
    import std.datetime : Clock;

    struct One {
        ulong id;
        SysTime time;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // Take a reference point, then wait so the stored timestamp is strictly later.
    const before = Clock.currTime;
    Thread.sleep(1.dur!"msecs");
    db.run(insert!One.insert, One(0, Clock.currTime));

    auto rows = db.run(select!One).array;
    rows.length.shouldEqual(1);
    rows[0].time.shouldBeGreaterThan(before);
}
332 
unittest {
    struct One {
        ulong id;
        string text;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    db.run(count!One).shouldEqual(0);

    // One aggregated INSERT carrying ten value tuples.
    auto firstBatch = iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i)));
    db.run!(AggregateInsert.yes)(insert!One.insert, firstBatch);
    db.run(count!One).shouldEqual(10);

    auto rows = db.run(select!One).array;
    assert(rows.length == 10);
    assert(rows.all!(a => a.id < 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    // The same through INSERT OR REPLACE, where the given ids are kept.
    auto secondBatch = iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i)));
    db.run!(AggregateInsert.yes)(insertOrReplace!One, secondBatch);
    rows = db.run(select!One).array;
    assert(rows.length == 499);
    assert(rows.all!(a => a.id >= 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);
}
365 
@("shall convert the database type to the enum when retrieving via select")
unittest {
    static struct Foo {
        enum MyEnum : string {
            foo = "batman",
            bar = "robin",
        }

        ulong id;
        MyEnum enum_;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Foo);

    // Store a non-default member so a silent fallback to MyEnum.init would be caught.
    db.run(insert!Foo.insert, Foo(0, Foo.MyEnum.bar));

    auto fetched = db.run(select!Foo).array;
    fetched.length.shouldEqual(1);
    fetched[0].enum_.shouldEqual(Foo.MyEnum.bar);
}
387 
unittest {
    struct Limit {
        int min, max;
    }

    struct Limits {
        Limit volt, curr;
    }

    struct Settings {
        ulong id;
        Limits limits;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    assert(db.run(count!Settings) == 0);

    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 12), Limit(-10, 10))));
    assert(db.run(count!Settings) == 1);

    // Row 10 is replaced; rows 11 and 12 are new.
    foreach (row; [
            Settings(10, Limits(Limit(0, 2), Limit(-3, 3))),
            Settings(11, Limits(Limit(0, 11), Limit(-11, 11))),
            Settings(12, Limits(Limit(0, 12), Limit(-12, 12)))
        ])
        db.run(insertOrReplace!Settings, row);

    assert(db.run(count!Settings) == 3);
    // Nested fields are addressed through dotted column names.
    assert(db.run(count!Settings.where(`"limits.volt.max" = 2`)) == 1);
    assert(db.run(count!Settings.where(`"limits.volt.max" > 10`)) == 2);
    db.run(delete_!Settings.where(`"limits.volt.max" < 10`));
    assert(db.run(count!Settings) == 2);
}
418 
unittest {
    struct Settings {
        ulong id;
        int[5] data;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);

    // A static array field has to survive the round trip unchanged.
    db.run(insert!Settings.insert, Settings(0, [1, 2, 3, 4, 5]));
    assert(db.run(count!Settings) == 1);

    auto fetched = db.run(select!Settings).front;
    assert(fetched.data == [1, 2, 3, 4, 5]);
}
434 
/** Parse a timestamp stored as `YYYY-MM-DD HH:MM:SS.mmm`.
 *
 * The fractional part is optional: `YYYY-MM-DD HH:MM:SS`, as returned for
 * example by SQLite's `datetime()`, yields zero milliseconds. The fraction is
 * read as an integer number of milliseconds, so `.5` and `.005` both mean 5 ms.
 * This keeps values written by older, unpadded versions of `toSqliteDateTime`
 * readable.
 *
 * Params:
 *  raw_dt = the textual timestamp.
 *
 * Returns: the timestamp, interpreted as UTC.
 *
 * Throws: on malformed input or out-of-range fields (e.g. `FormatException`,
 * `DateTimeException`).
 */
SysTime fromSqLiteDateTime(string raw_dt) {
    import core.time : dur;
    import std.datetime : DateTime, UTC;
    import std.format : formattedRead;

    int year, month, day, hour, minute, second, msecs;
    // formattedRead consumes the parsed prefix of `rest`.
    auto rest = raw_dt;
    formattedRead(rest, "%s-%s-%s %s:%s:%s", year, month, day, hour, minute, second);
    // The fractional seconds are optional.
    if (rest.length != 0)
        formattedRead(rest, ".%s", msecs);
    auto dt = DateTime(year, month, day, hour, minute, second);

    return SysTime(dt, msecs.dur!"msecs", UTC());
}
446 
/** Format `ts` as `YYYY-MM-DD HH:MM:SS.mmm` for storage in SQLite.
 *
 * The milliseconds are zero-padded to three digits. SQLite's date and time
 * functions expect this layout, and the fixed width keeps the lexicographic
 * order of the stored text equal to the chronological order.
 *
 * The fields are taken from `ts` as-is, in its own time zone.
 * NOTE(review): `fromSqLiteDateTime` interprets the text as UTC, so callers
 * presumably need to pass a UTC time (the insert path in this module calls
 * `.toUTC` first).
 */
string toSqliteDateTime(SysTime ts) {
    import std.format : format;

    return format("%04s-%02s-%02s %02s:%02s:%02s.%03s", ts.year,
            cast(ushort) ts.month, ts.day, ts.hour, ts.minute, ts.second,
            ts.fracSecs.total!"msecs");
}
454 
/// Thrown by `spinSql` when the query has not succeeded before the timeout.
class SpinSqlTimeout : Exception {
    /** Params:
     *  msg = optional description, `null` by default.
     *  file = source file of the throw site, filled in automatically.
     *  line = source line of the throw site, filled in automatically.
     */
    this(string msg = null, string file = __FILE__, size_t line = __LINE__) @safe pure nothrow {
        // Forward file/line so the exception reports where it was thrown
        // instead of the location of this constructor.
        super(msg, file, line);
    }
}
460 
/** Execute an SQL query until it succeeds or `timeout` expires.
 *
 * The message of every exception thrown by `query` is passed to `logFn`.
 * After each failure the call is retried following a random pause of 50 to
 * just under 150 ms. The query is always attempted at least once, even when
 * `timeout` is zero or negative.
 *
 * Note: a query that always fails (e.g. because of an error in the SQL) is
 * retried until the timeout expires. With the default `Duration.max` that
 * means forever.
 *
 * Params:
 *  query = callable executing the SQL; its return value is forwarded.
 *  logFn = called with the message of every caught exception.
 *  timeout = how long to keep retrying.
 *
 * Returns: whatever `query` returns on its first successful call.
 *
 * Throws: `SpinSqlTimeout` when `timeout` expires without a successful call.
 */
auto spinSql(alias query, alias logFn = logger.warning)(Duration timeout = Duration.max) {
    import core.thread : Thread;
    import core.time : dur;
    import std.datetime.stopwatch : StopWatch, AutoStart;
    import std.exception : collectException;
    import std.random : uniform;

    const sw = StopWatch(AutoStart.yes);

    do {
        try {
            return query();
        } catch (Exception e) {
            logFn(e.msg).collectException;
            // even though the database have a builtin sleep it still result in too much spam.
            Thread.sleep(uniform(50, 150).dur!"msecs");
        }
    }
    while (sw.peek < timeout);

    throw new SpinSqlTimeout();
}