1 /**
2 Copyright: Copyright (c) 2017, Oleg Butko. All rights reserved.
3 Copyright: Copyright (c) 2018-2019, Joakim Brännström. All rights reserved.
4 License: MIT
5 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
6 Author: Oleg Butko (deviator)
7 */
8 module miniorm.api;
9 
10 import core.time : dur;
11 import logger = std.experimental.logger;
12 import std.array : Appender;
13 import std.datetime : SysTime, Duration;
14 import std.range;
15 
16 import miniorm.exception;
17 import miniorm.queries;
18 
19 import d2sqlite3;
20 
21 version (unittest) {
22     import std.algorithm : map;
23     import unit_threaded.assertions;
24 }
25 
/** Minimal ORM on top of a d2sqlite3 `Database`.
 *
 * The wrapped database is reachable through `getUnderlyingDb`, which is also
 * the `alias this`, so `Database` members can be called directly on a
 * `Miniorm`. Statements created by `prepare` are cached by their SQL text and
 * finalized by `close`, `opAssign` and the destructor.
 *
 * NOTE(review): the statement cache is an associative array, so a copy of a
 * `Miniorm` shares it with the original, yet the destructor of every copy
 * finalizes all cached statements. Prefer passing instances by `ref`.
 */
struct Miniorm {
    /// Prepared statements keyed by their SQL text.
    private Statement[string] cachedStmt;
    /// True means that all queries are logged.
    private bool log_;

    /// The wrapped database connection.
    private Database db;
    alias getUnderlyingDb this;

    /// Returns: a reference to the wrapped database connection.
    ref Database getUnderlyingDb() return  {
        return db;
    }

    /// Wrap an already opened database.
    this(Database db) {
        this.db = db;
    }

    /// Open the database at `path` with the sqlite open `flags`.
    this(string path, int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) {
        this(Database(path, flags));
    }

    ~this() {
        cleanupCache;
    }

    /// Start a RAII handled transaction.
    Transaction transaction() {
        return Transaction(db);
    }

    /** Returns: a handle to a prepared statement for `sql`.
     *
     * The statement is prepared on the first call for a given SQL text and
     * reused from the cache on later calls.
     */
    RefCntStatement prepare(string sql) {
        if (auto v = sql in cachedStmt) {
            return RefCntStatement(*v);
        }
        auto r = db.prepare(sql);
        cachedStmt[sql] = r;
        return RefCntStatement(cachedStmt[sql]);
    }

    /// Toggle logging.
    void log(bool v) nothrow {
        this.log_ = v;
    }

    /// Returns: True if logging is activated
    private bool isLog() {
        return log_;
    }

    /// Finalize and forget every cached statement.
    private void cleanupCache() {
        foreach (ref s; cachedStmt.byValue)
            s.finalize;
        cachedStmt = null;
    }

    /** Take over the database of `rhs`.
     *
     * The statement cache of this instance is finalized first. Neither the
     * cache nor the logging flag of `rhs` is copied.
     */
    void opAssign(ref typeof(this) rhs) {
        cleanupCache;
        db = rhs.db;
    }

    /// Execute raw `sql` via `Database.run`, forwarding `dg`.
    void run(string sql, bool delegate(ResultRange) dg = null) {
        if (isLog)
            logger.trace(sql);
        db.run(sql, dg);
    }

    /// Finalize the cached statements and close the database.
    void close() {
        cleanupCache;
        db.close();
    }

    /** Execute the count query `v`.
     *
     * Returns: the first column of the first result row as a `size_t`.
     */
    size_t run(T)(Count!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        return db.executeCheck(sql).front.front.as!size_t;
    }

    /** Execute the select query `v`.
     *
     * Returns: a lazy range converting each result row to a `T`. Column `i`
     * of the row is assigned to the `i`:th column field of `T`.
     */
    auto run(T)(Select!T v) {
        import std.algorithm : map;
        import std.format : format;
        import std.range : inputRangeObject;

        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);

        auto result = db.executeCheck(sql);

        // Convert one result row to a `T`.
        static T qconv(typeof(result.front) e) {
            import miniorm.schema : fieldToCol;

            T ret;
            // Generates one assignment per column of `T`:
            //  - DATETIME columns are parsed with fromSqLiteDateTime,
            //  - static arrays are copied from the column read as ubyte[],
            //    truncated to the shorter of the two lengths,
            //  - enums and everything else are read with `peek!ET`.
            static string rr() {
                // explicit import; `join` is not provided by the module imports
                import std.array : join;

                string[] res;
                res ~= "import std.traits : isStaticArray, OriginalType;";
                res ~= "import miniorm.api : fromSqLiteDateTime;";
                foreach (i, a; fieldToCol!("", T)()) {
                    res ~= `{`;
                    if (a.columnType == "DATETIME") {
                        res ~= `{ ret.%1$s = fromSqLiteDateTime(e.peek!string(%2$d)); }`.format(a.identifier,
                                i);
                    } else {
                        res ~= q{alias ET = typeof(ret.%s);}.format(a.identifier);
                        res ~= q{static if (isStaticArray!ET)};
                        res ~= `
                            {
                                import std.algorithm : min;
                                auto ubval = e[%2$d].as!(ubyte[]);
                                auto etval = cast(typeof(ET.init[]))ubval;
                                auto ln = min(ret.%1$s.length, etval.length);
                                ret.%1$s[0..ln] = etval[0..ln];
                            }
                            `.format(a.identifier, i);
                        res ~= q{else static if (is(ET == enum))};
                        res ~= format(q{ret.%1$s = cast(ET) e.peek!ET(%2$d);}, a.identifier, i);
                        res ~= q{else};
                        res ~= format(q{ret.%1$s = e.peek!ET(%2$d);}, a.identifier, i);
                    }
                    res ~= `}`;
                }
                return res.join("\n");
            }

            mixin(rr());
            return ret;
        }

        return result.map!qconv;
    }

    /// Execute the delete query `v`.
    void run(T)(Delete!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        db.run(sql);
    }

    /** Execute the insert query `v` for the values in `arr`.
     *
     * `all` selects between one statement per value and a single aggregated
     * statement, see `AggregateInsert`.
     */
    void run(AggregateInsert all = AggregateInsert.no, T0, T1)(Insert!T0 v, T1[] arr...)
            if (!isInputRange!T1) {
        procInsert!all(v, arr);
    }

    /** Execute the insert query `v` for every element of the input range `rng`.
     *
     * `all` selects between one statement per element and a single
     * aggregated statement, see `AggregateInsert`.
     */
    void run(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T v, R rng)
            if (isInputRange!R) {
        procInsert!all(v, rng);
    }

    /** Bind the elements of `rng` to the insert query `q` and execute it.
     *
     * When `q` is an insert-or-replace the primary key fields are bound too,
     * otherwise they are skipped. An aggregated insert of an empty range does
     * nothing.
     */
    private void procInsert(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T q, R rng)
            if ((all && hasLength!R) || !all) {
        import std.algorithm : among;

        // generate code for binding values in a struct to a prepared
        // statement.
        // Expects an external variable "n" to exist that keeps track of the
        // index. This is requied when the binding is for multiple values.
        // Expects the statement to be named "stmt".
        // Expects the variable to read values from to be named "v".
        // Indexing start from 1 according to the sqlite manual.
        static string genBinding(T)(bool replace) {
            import miniorm.schema : fieldToCol;

            string s;
            foreach (i, v; fieldToCol!("", T)) {
                if (!replace && v.isPrimaryKey)
                    continue;
                if (v.columnType == "DATETIME")
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ".toUTC.toSqliteDateTime);";
                else
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ");";
                s ~= "++n;";
            }
            return s;
        }

        alias T = ElementType!R;

        static if (all == AggregateInsert.yes) {
            // Inserting zero rows is a no-op. Returning early also avoids
            // building `q.values(0)`, which presumably renders an INSERT
            // without any VALUES tuple and thus invalid SQL.
            if (rng.length == 0)
                return;
        }

        const replace = q.query.opt == InsertOpt.InsertOrReplace;

        static if (all == AggregateInsert.yes)
            q = q.values(rng.length);
        else
            q = q.values(1);

        const sql = q.toSql.toString;

        auto stmt = prepare(sql);

        static if (all == AggregateInsert.yes) {
            // one statement; `n` keeps counting across all elements
            int n;
            foreach (v; rng) {
                if (replace) {
                    mixin(genBinding!T(true));
                } else {
                    mixin(genBinding!T(false));
                }
            }
            if (isLog)
                logger.trace(sql, " -> ", rng);
            stmt.get.execute();
            stmt.get.reset();
        } else {
            // one execution per element; `n` restarts for every element
            foreach (v; rng) {
                int n;
                if (replace) {
                    mixin(genBinding!T(true));
                } else {
                    mixin(genBinding!T(false));
                }
                if (isLog)
                    logger.trace(sql, " -> ", v);
                stmt.get.execute();
                stmt.get.reset();
            }
        }
    }
}
246 
/** Whether one aggregated insert or multiple single-row inserts should be
 * generated.
 *
 * no: one statement per element.
 * ---
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * ---
 *
 * yes: a single statement with one value tuple per element.
 * ---
 * INSERT INTO foo ('v0') VALUES (?),(?),(?)
 * ---
 */
enum AggregateInsert {
    /// Execute one INSERT per element.
    no = 0,
    /// Execute a single INSERT containing all elements.
    yes = 1,
}
265 
266 version (unittest) {
267     import miniorm.schema;
268 
269     import std.conv : text, to;
270     import std.range;
271     import std.algorithm;
272     import std.datetime;
273     import std.array;
274     import std.stdio;
275 
276     import unit_threaded.assertions;
277 }
278 
@("shall operate on a database allocted in std.experimental.allocators without any errors")
unittest {
    import std.experimental.allocator;
    import std.experimental.allocator.mallocator;
    import std.experimental.allocator.building_blocks.scoped_allocator;

    struct One {
        ulong id;
        string text;
    }

    // TODO: fix this
    //Microrm* db;
    //ScopedAllocator!Mallocator scalloc;
    //db = scalloc.make!Microrm(":memory:");
    //scope (exit) {
    //    db.close;
    //    scalloc.dispose(db);
    //}

    // TODO: replace the one below with the above code.
    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // plain insert of ten rows
    auto plainRows = iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i)));
    db.run(insert!One.insert, plainRows);
    db.run(count!One).shouldEqual(10);

    auto fetched = db.run(select!One).array;
    fetched.length.shouldEqual(10);
    assert(fetched.all!(a => a.id < 100));
    db.getUnderlyingDb.lastInsertRowid.shouldEqual(fetched[$ - 1].id);

    // empty the table, then insert-or-replace 499 rows
    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);
    auto replaceRows = iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i)));
    db.run(insertOrReplace!One, replaceRows);
    fetched = db.run(select!One).array;
    fetched.length.shouldEqual(499);
    assert(fetched.all!(a => a.id >= 100));
    db.lastInsertRowid.shouldEqual(fetched[$ - 1].id);
}
318 
@("shall insert and extract datetime from the table")
unittest {
    import core.thread : Thread;
    import std.datetime : Clock;

    struct One {
        ulong id;
        SysTime time;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // take a reference point and wait so the stored time is strictly later
    const before = Clock.currTime;
    Thread.sleep(1.dur!"msecs");

    db.run(insert!One.insert, One(0, Clock.currTime));

    auto rows = db.run(select!One).array;
    rows.length.shouldEqual(1);
    rows[0].time.shouldBeGreaterThan(before);
}
341 
unittest {
    struct One {
        ulong id;
        string text;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    db.run(count!One).shouldEqual(0);

    // aggregated insert: a single INSERT statement for all ten rows
    db.run!(AggregateInsert.yes)(insert!One.insert, iota(0, 10).map!(i => One(i * 100,
            "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    auto rows = db.run(select!One).array;
    assert(rows.length == 10);
    assert(rows.all!(a => a.id < 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    // aggregated insert-or-replace of 499 rows
    db.run!(AggregateInsert.yes)(insertOrReplace!One, iota(0, 499).map!(i => One((i + 1) * 100,
            "hello" ~ text(i))));
    rows = db.run(select!One).array;
    assert(rows.length == 499);
    assert(rows.all!(a => a.id >= 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);
}
374 
@("shall convert the database type to the enum when retrieving via select")
unittest {
    static struct Foo {
        enum MyEnum : string {
            foo = "batman",
            bar = "robin",
        }

        ulong id;
        MyEnum enum_;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Foo);

    auto expected = Foo.MyEnum.bar;
    db.run(insert!Foo.insert, Foo(0, expected));

    auto found = db.run(select!Foo).array;
    found.length.shouldEqual(1);
    found[0].enum_.shouldEqual(expected);
}
396 
unittest {
    struct Limit {
        int min, max;
    }

    struct Limits {
        Limit volt, curr;
    }

    struct Settings {
        ulong id;
        Limits limits;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    assert(db.run(count!Settings) == 0);

    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 12), Limit(-10, 10))));
    assert(db.run(count!Settings) == 1);

    // id 10 is replaced while 11 and 12 are added
    auto updates = [
        Settings(10, Limits(Limit(0, 2), Limit(-3, 3))),
        Settings(11, Limits(Limit(0, 11), Limit(-11, 11))),
        Settings(12, Limits(Limit(0, 12), Limit(-12, 12))),
    ];
    foreach (s; updates)
        db.run(insertOrReplace!Settings, s);

    assert(db.run(count!Settings) == 3);
    assert(db.run(count!Settings.where(`"limits.volt.max" = 2`)) == 1);
    assert(db.run(count!Settings.where(`"limits.volt.max" > 10`)) == 2);
    db.run(delete_!Settings.where(`"limits.volt.max" < 10`));
    assert(db.run(count!Settings) == 2);
}
427 
unittest {
    struct Settings {
        ulong id;
        int[5] data;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);

    int[5] payload = [1, 2, 3, 4, 5];
    db.run(insert!Settings.insert, Settings(0, payload));
    assert(db.run(count!Settings) == 1);

    // the static array survives the round trip through the database
    const row = db.run(select!Settings).front;
    assert(row.data == payload);
}
443 
/** Parse a datetime string read from the database into a UTC `SysTime`.
 *
 * Accepted layout: `YYYY-MM-DD HH:MM:SS` (the separator may also be `T`),
 * optionally followed by `.` and a fractional part. Up to three leading
 * digits of the fractional part are read as an integer number of
 * milliseconds (`.5` and `.005` are both 5 ms, `.123` is 123 ms); any further
 * characters are ignored.
 *
 * NOTE: SQLite itself treats the digits as a decimal fraction of a second,
 * so the two interpretations only agree when exactly three digits are
 * present.
 *
 * Params:
 *  raw_dt = the string read from the database.
 *
 * Returns: the parsed time in UTC, or the current time in UTC if `raw_dt`
 * could not be parsed (the error message is logged at trace level).
 */
SysTime fromSqLiteDateTime(string raw_dt) {
    import std.ascii : isDigit;
    import std.datetime : DateTime, UTC, Clock;
    import std.format : formattedRead;
    import std.string : indexOf;

    try {
        // Split off the optional fractional part so that values without one
        // are parsed too.
        auto base = raw_dt;
        string frac;
        const dot = raw_dt.indexOf('.');
        if (dot >= 0) {
            base = raw_dt[0 .. dot];
            frac = raw_dt[dot + 1 .. $];
        }

        // ISO 8601 style "YYYY-MM-DDTHH:MM:SS".
        if (base.length > 10 && base[10] == 'T')
            base = base[0 .. 10] ~ ' ' ~ base[11 .. $];

        int year, month, day, hour, minute, second;
        formattedRead(base, "%s-%s-%s %s:%s:%s", year, month, day, hour, minute, second);

        // At most three leading digits; more digits than millisecond
        // precision would otherwise overflow a second and be rejected.
        int msecs;
        foreach (i, c; frac) {
            if (i == 3 || !c.isDigit)
                break;
            msecs = msecs * 10 + (c - '0');
        }

        auto dt = DateTime(year, month, day, hour, minute, second);
        return SysTime(dt, msecs.dur!"msecs", UTC());
    } catch (Exception e) {
        logger.trace(e.msg);
        return Clock.currTime(UTC());
    }
}
459 
/** Format `ts` as `YYYY-MM-DD HH:MM:SS.SSS`.
 *
 * The fields are taken from `ts` as-is, in its own time zone. Callers that
 * want UTC (the zone `fromSqLiteDateTime` assumes) must convert first, e.g.
 * with `ts.toUTC`.
 *
 * The milliseconds are zero-padded to three digits. The text then matches
 * SQLite's `HH:MM:SS.SSS` time-string format, and two values compared as
 * text sort chronologically.
 */
string toSqliteDateTime(SysTime ts) {
    import std.format : format;

    return format("%04d-%02d-%02d %02d:%02d:%02d.%03d", ts.year,
            cast(ushort) ts.month, ts.day, ts.hour, ts.minute, ts.second,
            ts.fracSecs.total!"msecs");
}
467 
/// Thrown by `spinSql` when the query did not succeed before the timeout.
class SpinSqlTimeout : Exception {
    @safe pure nothrow this(string msg, string file = __FILE__, int line = __LINE__) {
        super(msg, file, line);
    }
}
473 
/** Execute an SQL query until it succeeds or `timeout` has elapsed.
 *
 * The exception message of every failed attempt is passed to `logFn`. Each
 * failed attempt is followed by a random sleep in `[minTime, maxTime)`
 * before the next attempt.
 *
 * Params:
 *  query = callable executing the query; its return value is returned.
 *  logFn = called with the message of each exception thrown by `query`.
 *  timeout = how long to keep retrying.
 *  minTime = lower bound of the sleep between attempts (negative is
 *      treated as zero).
 *  maxTime = exclusive upper bound of the sleep between attempts. When it
 *      is not greater than `minTime` the sleep is exactly `minTime`.
 *
 * Returns: the value returned by the first successful call to `query`.
 * Throws: `SpinSqlTimeout` when `timeout` elapses without a successful call.
 *
 * Note: a permanent error (e.g. a syntax error in the query) is retried until
 * the timeout expires.
 */
auto spinSql(alias query, alias logFn = logger.warning)(Duration timeout,
        Duration minTime = 50.dur!"msecs", Duration maxTime = 150.dur!"msecs") {
    import core.thread : Thread;
    import std.datetime.stopwatch : StopWatch, AutoStart;
    import std.exception : collectException;
    import std.random : uniform;

    const sw = StopWatch(AutoStart.yes);

    // Sleep bounds in milliseconds. `uniform` throws on an empty interval,
    // and that throw would happen inside the catch handler below and abort
    // the retry loop. Only pick randomly when there is something to pick
    // from.
    long lo = minTime.total!"msecs";
    if (lo < 0)
        lo = 0;
    const long hi = maxTime.total!"msecs";

    while (sw.peek < timeout) {
        try {
            return query();
        } catch (Exception e) {
            logFn(e.msg).collectException;
            // even though the database have a builtin sleep it still result in too much spam.
            () @trusted {
                const long ms = hi > lo ? uniform(lo, hi) : lo;
                Thread.sleep(ms.dur!"msecs");
            }();
        }
    }

    throw new SpinSqlTimeout(null);
}
501 
/** Execute an SQL query until it succeeds, without any timeout.
 *
 * Note: If there are any errors in the query it will go into an infinite loop.
 */
auto spinSql(alias query, alias logFn = logger.warning)() nothrow {
    for (;;) {
        try
            return spinSql!(query, logFn)(Duration.max);
        catch (Exception) {
            // keep retrying
        }
    }
}
510 
/** RAII handling of a transaction.
 *
 * The constructor begins the transaction. The destructor rolls it back
 * unless `commit` or `rollback` has been called.
 *
 * NOTE(review): the struct is copyable and the destructor of every copy
 * still in the `rollback` state will roll back; do not copy an active
 * transaction.
 */
struct Transaction {
    Database db;

    // can only do a rollback/commit if it has been constructed and thus
    // executed begin.
    enum State {
        /// Default-initialized; nothing to roll back.
        none,
        /// Begun; the destructor rolls back unless finished.
        rollback,
        /// Committed or rolled back.
        done,
    }

    State st;

    /** Begin a transaction on the database wrapped by `db`.
     *
     * `db` is taken by `ref` so that no copy of the `Miniorm` is made. A copy
     * shares the statement cache of `db`, and its destructor would finalize
     * that cache.
     */
    this(ref Miniorm db) {
        this(db.db);
    }

    /// Begin a transaction on the database wrapped by the rvalue `db`.
    this(Miniorm db) {
        this(db.db);
    }

    /// Begin a transaction on `db`; `db.begin` is retried until it succeeds.
    this(Database db) {
        this.db = db;
        spinSql!(() { db.begin; });
        st = State.rollback;
    }

    ~this() {
        scope (exit)
            st = State.done;
        if (st == State.rollback) {
            db.rollback;
        }
    }

    /// Commit the transaction.
    void commit() {
        db.commit;
        st = State.done;
    }

    /// Roll back the transaction if it has begun and is not finished.
    void rollback() {
        scope (exit)
            st = State.done;
        if (st == State.rollback) {
            db.rollback;
        }
    }
}
556 
/** Reference-counted handle to a `Statement` that the handle does not own.
 *
 * When the last handle is destroyed, the statement's bindings are cleared
 * and it is reset. Exceptions thrown while doing so are ignored. The
 * statement itself is never finalized here.
 */
struct RefCntStatement {
    import std.typecons : RefCounted, RefCountedAutoInitialize;

    /// Shared state: a pointer to the referenced statement.
    static struct Payload {
        Statement* stmt;

        this(Statement* stmt) {
            this.stmt = stmt;
        }

        ~this() nothrow {
            if (stmt !is null) {
                try {
                    (*stmt).clearBindings;
                    (*stmt).reset;
                } catch (Exception) {
                    // best effort; the destructor must not throw
                }
                stmt = null;
            }
        }
    }

    RefCounted!(Payload, RefCountedAutoInitialize.no) rc;

    /// Create a handle referencing `stmt`.
    this(ref Statement stmt) @trusted {
        rc = Payload(&stmt);
    }

    /// Returns: the referenced statement.
    ref Statement get() {
        Statement* target = rc.refCountedPayload.stmt;
        return *target;
    }
}