1 /**
2 Copyright: Copyright (c) 2017, Oleg Butko. All rights reserved.
3 Copyright: Copyright (c) 2018-2019, Joakim Brännström. All rights reserved.
4 License: MIT
5 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
6 Author: Oleg Butko (deviator)
7 */
8 module miniorm.api;
9 
10 import core.time : dur;
11 import logger = std.experimental.logger;
12 import std.array : Appender;
13 import std.datetime : SysTime, Duration;
14 import std.range;
15 
16 import miniorm.exception;
17 import miniorm.queries;
18 
19 import d2sqlite3;
20 
21 version (unittest) {
22     import std.algorithm : map;
23     import unit_threaded.assertions;
24 }
25 
/** Small ORM-style wrapper around a d2sqlite3 `Database`.
 *
 * Query objects (`Count`, `Select`, `Delete`, `Insert`) are converted to SQL,
 * prepared through a per-instance cache of prepared statements and executed.
 * The wrapped `Database` is reachable through `alias this`.
 */
struct Miniorm {
    /// Prepared statements keyed by the SQL text they were prepared from.
    private LentCntStatement[string] cachedStmt;
    /// `prepare` purges the statements that are not lent out once the cache
    /// holds more entries than this.
    private size_t cacheSize = 128;
    /// True means that all queries are logged.
    private bool log_;

    /// The wrapped database handle.
    private Database db;
    alias getUnderlyingDb this;

    /// Returns: a reference to the wrapped `Database`.
    ref Database getUnderlyingDb() return  {
        return db;
    }

    /// Wrap an existing database handle.
    this(Database db) {
        this.db = db;
    }

    /// Construct a `Database` from `path` and `flags` and wrap it.
    this(string path, int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) {
        this(Database(path, flags));
    }

    /// Finalizes every cached statement (see `cleanupCache`).
    ~this() {
        cleanupCache;
    }

    /// Start a RAII handled transaction.
    Transaction transaction() {
        return Transaction(db);
    }

    /// Set the number of cached statements above which `prepare` starts to
    /// purge statements that are not lent to a user.
    void prepareCacheSize(size_t s) {
        cacheSize = s;
    }

    /** Return a prepared statement for `sql`, reusing a cached one if present.
     *
     * The returned handle counts as "lent"; a lent statement is never purged
     * from the cache by this function.
     */
    RefCntStatement prepare(string sql) {
        // The purge is done before the lookup/insert below, thus the cache
        // may temporarily hold more than `cacheSize` statements.
        if (cachedStmt.length > cacheSize) {
            // Collect the keys first; entries are not removed while iterating.
            auto keys = appender!(string[])();
            foreach (p; cachedStmt.byKeyValue) {
                // the statement is currently not lent to any user.
                if (p.value.count == 0) {
                    keys.put(p.key);
                }
            }

            foreach (k; keys.data) {
                cachedStmt[k].stmt.finalize;
                cachedStmt.remove(k);
            }
        }

        if (auto v = sql in cachedStmt) {
            return RefCntStatement(*v);
        }
        auto r = db.prepare(sql);
        cachedStmt[sql] = LentCntStatement(r);
        // The handle refers to the entry stored in the cache, not to a copy.
        return RefCntStatement(cachedStmt[sql]);
    }

    /// Toggle logging.
    void log(bool v) nothrow {
        this.log_ = v;
    }

    /// Returns: True if logging is activated
    private bool isLog() {
        return log_;
    }

    /// Finalize all cached statements, lent or not, and drop the cache.
    private void cleanupCache() {
        foreach (ref s; cachedStmt.byValue) {
            s.stmt.finalize;
        }
        cachedStmt = null;
    }

    /// Drop this instance's statement cache and take over the database handle
    /// of `rhs`. Only `db` is copied; `log_` and `cacheSize` are left as is.
    void opAssign(ref typeof(this) rhs) {
        cleanupCache;
        db = rhs.db;
    }

    /// Run raw SQL directly on the database, bypassing the statement cache.
    /// `dg` is forwarded unchanged to `Database.run`.
    void run(string sql, bool delegate(ResultRange) dg = null) {
        if (isLog) {
            logger.trace(sql);
        }
        db.run(sql, dg);
    }

    /// Finalize all cached statements and close the database.
    void close() {
        cleanupCache;
        db.close();
    }

    /// Run a count query.
    /// Returns: the first column of the first result row converted to `size_t`.
    size_t run(T, Args...)(Count!T v, auto ref Args args) {
        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }
        auto stmt = prepare(sql);
        return executeCheck(stmt, sql, v.binds, args).front.front.as!size_t;
    }

    /// Run a select query.
    /// Returns: a range that converts each result row to a `T`.
    auto run(T, Args...)(Select!T v, auto ref Args args) {
        import std.algorithm : map;
        import std.format : format;
        import std.range : inputRangeObject;

        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }

        auto stmt = prepare(sql);
        auto result = executeCheck(stmt, sql, v.binds, args);

        // Convert one result row `e` to a `T`.
        static T qconv(typeof(result.front) e) {
            import std.algorithm : min;
            import std.conv : to;
            import std.traits : isStaticArray, OriginalType;
            import miniorm.api : fromSqLiteDateTime;
            import miniorm.schema : fieldToCol;

            T ret;
            // Generates the code that is mixed in below: one `{ ... }` block
            // per column, assigning column `i` of `e` to the matching field
            // of `ret`.
            static string rr() {
                string[] res;
                foreach (i, a; fieldToCol!("", T)()) {
                    res ~= `{`;
                    if (a.columnType == "DATETIME") {
                        // read as text and parsed by fromSqLiteDateTime.
                        res ~= `{ ret.%1$s = fromSqLiteDateTime(e.peek!string(%2$d)); }`.format(a.identifier,
                                i);
                    } else {
                        res ~= q{alias ET = typeof(ret.%s);}.format(a.identifier);
                        // static arrays are read as a ubyte[] blob; at most
                        // the length of the field is copied.
                        res ~= q{static if (isStaticArray!ET)};
                        res ~= `
                            {
                                auto ubval = e[%2$d].as!(ubyte[]);
                                auto etval = cast(typeof(ET.init[]))ubval;
                                auto ln = min(ret.%1$s.length, etval.length);
                                ret.%1$s[0..ln] = etval[0..ln];
                            }
                            `.format(a.identifier, i);
                        res ~= q{else static if (is(ET == enum))};
                        res ~= format(q{ret.%1$s = cast(ET) e.peek!ET(%2$d).to!ET;},
                                a.identifier, i);
                        res ~= q{else};
                        res ~= format(q{ret.%1$s = e.peek!ET(%2$d);}, a.identifier, i);
                    }
                    res ~= `}`;
                }
                return res.join("\n");
            }

            mixin(rr());
            return ret;
        }

        // The statement handle is stored in the returned range together with
        // the result.
        return ResultRange2!(typeof(result))(stmt, result).map!qconv;
    }

    /// Run a delete query.
    void run(T, Args...)(Delete!T v, auto ref Args args) {
        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }
        auto stmt = prepare(sql);
        executeCheck(stmt, sql, v.binds, args);
    }

    /// Insert the values given as variadic arguments (or as an array).
    void run(T0, T1)(Insert!T0 v, T1[] arr...) if (!isInputRange!T1) {
        procInsert(v, arr);
    }

    /// Insert every element of the input range `rng`.
    void run(T, R)(Insert!T v, R rng) if (isInputRange!R) {
        procInsert(v, rng);
    }

    /// Bind and execute the insert query `q` once per element in `rng`,
    /// reusing one prepared statement.
    private void procInsert(T, R)(Insert!T q, R rng) {
        import std.algorithm : among;

        // generate code for binding values in a struct to a prepared
        // statement.
        // Expects an external variable "n" to exist that keeps track of the
        // index. This is required when the binding is for multiple values.
        // Expects the statement to be named "stmt".
        // Expects the variable to read values from to be named "v".
        // Indexing start from 1 according to the sqlite manual.
        // Primary key columns are only bound when `replace` is true.
        static string genBinding(T)(bool replace) {
            import miniorm.schema : fieldToCol;

            string s;
            foreach (i, v; fieldToCol!("", T)) {
                if (!replace && v.isPrimaryKey)
                    continue;
                if (v.columnType == "DATETIME")
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ".toUTC.toSqliteDateTime);";
                else
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ");";
                s ~= "++n;";
            }
            return s;
        }

        // From here on T is the element type of the range.
        alias T = ElementType!R;

        const replace = q.query.opt == InsertOpt.InsertOrReplace;

        // NOTE(review): presumably requests placeholders for a single row
        // per statement - confirm in miniorm.queries.
        q = q.values(1);

        const sql = q.toSql.toString;

        auto stmt = prepare(sql);

        foreach (v; rng) {
            // bind index, incremented by the mixed in code.
            int n;
            if (replace) {
                mixin(genBinding!T(true));
            } else {
                mixin(genBinding!T(false));
            }
            if (isLog) {
                logger.trace(sql, " -> ", v);
            }
            stmt.get.execute();
            // make the statement ready for the next element.
            stmt.get.reset();
        }
    }
}
256 
/** Whether one aggregated insert or multiple inserts should be generated.
 *
 * no:
 * ---
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * ---
 *
 * yes:
 * ---
 * INSERT INTO foo ('v0') VALUES (?) (?) (?)
 * ---
 */
enum AggregateInsert {
    /// One INSERT statement per value.
    no,
    /// One INSERT statement covering all values.
    yes
}
275 
276 version (unittest) {
277     import miniorm.schema;
278 
279     import std.conv : text, to;
280     import std.range;
281     import std.algorithm;
282     import std.datetime;
283     import std.array;
284     import std.stdio;
285 
286     import unit_threaded.assertions;
287 }
288 
@("shall operate on a database allocted in std.experimental.allocators without any errors")
unittest {
    struct One {
        ulong id;
        string text;
    }

    // TODO: allocate the database via std.experimental.allocator instead of
    // on the stack. Untested sketch of the intended setup:
    //
    //   import std.experimental.allocator;
    //   import std.experimental.allocator.mallocator;
    //   import std.experimental.allocator.building_blocks.scoped_allocator;
    //   ScopedAllocator!Mallocator scalloc;
    //   Miniorm* db = scalloc.make!Miniorm(":memory:");
    //   scope (exit) {
    //       db.close;
    //       scalloc.dispose(db);
    //   }

    // TODO: replace the stack allocated database with the sketch above.
    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // Plain insert. The ids passed in are multiples of 100 but every stored
    // id is expected to be below 100.
    auto firstBatch = iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i)));
    db.run(insert!One.insert, firstBatch);
    db.run(count!One).shouldEqual(10);

    auto rows = db.run(select!One).array;
    rows.length.shouldEqual(10);
    assert(rows.all!(r => r.id < 100));
    db.getUnderlyingDb.lastInsertRowid.shouldEqual(rows[$ - 1].id);

    // Start over with an empty table.
    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    // Insert-or-replace. This time the ids passed in are the ones stored.
    auto secondBatch = iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i)));
    db.run(insertOrReplace!One, secondBatch);
    rows = db.run(select!One).array;
    rows.length.shouldEqual(499);
    assert(rows.all!(r => r.id >= 100));
    db.lastInsertRowid.shouldEqual(rows[$ - 1].id);
}
327 
@("shall insert and extract datetime from the table")
unittest {
    import core.thread : Thread;
    import std.datetime : Clock;

    struct One {
        ulong id;
        SysTime time;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // Reference point taken strictly before the timestamp that is stored.
    const before = Clock.currTime;
    Thread.sleep(dur!"msecs"(1));

    auto row = One(0, Clock.currTime);
    db.run(insert!One.insert, row);

    auto stored = db.run(select!One).array;
    stored.length.shouldEqual(1);
    stored[0].time.shouldBeGreaterThan(before);
}
350 
// Insert, select and delete round trip using plain `assert` for the checks on
// the selected rows.
unittest {
    struct One {
        ulong id;
        string text;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    db.run(count!One).shouldEqual(0);

    // Plain insert of ten rows.
    auto firstBatch = iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i)));
    db.run(insert!One.insert, firstBatch);
    db.run(count!One).shouldEqual(10);

    auto rows = db.run(select!One).array;
    assert(rows.length == 10);
    assert(rows.all!(r => r.id < 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);

    // Empty the table again.
    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    // Insert-or-replace of 499 rows with explicit ids.
    auto secondBatch = iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i)));
    db.run(insertOrReplace!One, secondBatch);
    rows = db.run(select!One).array;
    assert(rows.length == 499);
    assert(rows.all!(r => r.id >= 100));
    assert(db.lastInsertRowid == rows[$ - 1].id);
}
381 
@("shall convert the database type to the enum when retrieving via select")
unittest {
    static struct Foo {
        enum MyEnum : string {
            foo = "batman",
            bar = "robin",
        }

        ulong id;
        MyEnum enum_;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Foo);

    // Store a single row holding the enum member `bar`.
    auto stored = Foo(0, Foo.MyEnum.bar);
    db.run(insert!Foo.insert, stored);

    // The selected row shall carry the very same enum member.
    auto rows = db.run(select!Foo).array;
    rows.length.shouldEqual(1);
    rows[0].enum_.shouldEqual(Foo.MyEnum.bar);
}
403 
// Nested structs are addressed with dotted column names; exercises count and
// delete together with where/and conditions and bound parameters.
unittest {
    struct Limit {
        int min;
        int max;
    }

    struct Limits {
        Limit volt;
        Limit curr;
    }

    struct Settings {
        ulong id;
        Limits limits;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);

    const initially = db.run(count!Settings);
    assert(initially == 0);

    auto original = Settings(10, Limits(Limit(0, 12), Limit(-10, 10)));
    db.run(insertOrReplace!Settings, original);
    const afterFirst = db.run(count!Settings);
    assert(afterFirst == 1);

    // id 10 is used a second time, followed by two new ids.
    auto replacement = Settings(10, Limits(Limit(0, 2), Limit(-3, 3)));
    auto eleven = Settings(11, Limits(Limit(0, 11), Limit(-11, 11)));
    auto twelve = Settings(12, Limits(Limit(0, 12), Limit(-12, 12)));
    db.run(insertOrReplace!Settings, replacement);
    db.run(insertOrReplace!Settings, eleven);
    db.run(insertOrReplace!Settings, twelve);

    const total = db.run(count!Settings);
    assert(total == 3);

    const equalTwo = db.run(count!Settings.where(`"limits.volt.max" = :nr`, Bind("nr")), 2);
    assert(equalTwo == 1);

    const aboveTen = db.run(count!Settings.where(`"limits.volt.max" > :nr`, Bind("nr")), 10);
    assert(aboveTen == 2);

    db.run(count!Settings.where(`"limits.volt.max" > :nr`, Bind("nr"))
            .and(`"limits.volt.max" < :topnr`, Bind("topnr")), 1, 12).shouldEqual(2);

    db.run(delete_!Settings.where(`"limits.volt.max" < :nr`, Bind("nr")), 10);
    const remaining = db.run(count!Settings);
    assert(remaining == 2);
}
437 
// A static array field shall survive an insert followed by a select.
unittest {
    struct Settings {
        ulong id;
        int[5] data;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);

    auto stored = Settings(0, [1, 2, 3, 4, 5]);
    db.run(insert!Settings.insert, stored);

    const rowCount = db.run(count!Settings);
    assert(rowCount == 1);

    auto loaded = db.run(select!Settings).front;
    assert(loaded.data == [1, 2, 3, 4, 5]);
}
453 
/** Parse a timestamp in the format written by `toSqliteDateTime`.
 *
 * The expected layout is `year-month-day hour:minute:second.milliseconds`.
 * The result is expressed in UTC.
 *
 * Params:
 *  raw_dt = the textual timestamp.
 *
 * Returns: the parsed time. If parsing or construction of the time throws,
 * the message is logged at trace level and the current UTC time is returned
 * instead.
 */
SysTime fromSqLiteDateTime(string raw_dt) {
    import std.datetime : DateTime, UTC, Clock;
    import std.format : formattedRead;

    int yr, mon, dom, hh, mm, ss, ms;

    try {
        formattedRead(raw_dt, "%s-%s-%s %s:%s:%s.%s", yr, mon, dom, hh, mm, ss, ms);
        const stamp = DateTime(yr, mon, dom, hh, mm, ss);
        return SysTime(stamp, dur!"msecs"(ms), UTC());
    } catch (Exception e) {
        // best effort: never fail on a malformed timestamp.
        logger.trace(e.msg);
    }

    return Clock.currTime(UTC());
}
469 
/** Convert a time to the textual form stored in the database.
 *
 * To ensure consistency the time is always converted to UTC.
 *
 * The layout is `YYYY-MM-DD HH:MM:SS.mmm`. The milliseconds are zero padded
 * to three digits so that the text is a correct decimal fraction of a second
 * (5 ms is written as `.005`, not `.5`) and so that timestamps sort
 * lexicographically. `fromSqLiteDateTime` reads both the padded and the
 * unpadded form as the same number of milliseconds.
 *
 * Params:
 *  ts_ = the time to convert.
 *
 * Returns: the formatted UTC timestamp.
 */
string toSqliteDateTime(SysTime ts_) {
    import std.format : format;

    auto ts = ts_.toUTC;
    return format("%04s-%02s-%02s %02s:%02s:%02s.%03s", ts.year,
            cast(ushort) ts.month, ts.day, ts.hour, ts.minute, ts.second,
            ts.fracSecs.total!"msecs");
}
479 
/// Thrown by `spinSql` when the query did not succeed before the timeout.
class SpinSqlTimeout : Exception {
    /// Forwards the message and the source location to `Exception`.
    this(string msg, string file = __FILE__, int line = __LINE__) pure nothrow @safe {
        super(msg, file, line);
    }
}
485 
/// A log function that accepts any arguments and discards them. Can be passed
/// as `logFn` to `spinSql` to silence its logging.
void silentLog(Args...)(auto ref Args args) {
    // intentionally empty.
}
488 
/** Execute an SQL query until it succeeds or `timeout` has passed.
 *
 * The query is retried every time it throws an `Exception`. Between two
 * attempts the thread sleeps a random time in the interval
 * [`minTime`, `maxTime`). If `minTime` is not less than `maxTime` the sleep is
 * exactly `minTime`.
 *
 * Note: a query that always fails, e.g. because of an error in the SQL, is
 * retried until the timeout is reached.
 *
 * Params:
 *  query = callable that executes the query. Its return value is returned.
 *  logFn = function used to log the message of a failed attempt.
 *  timeout = how long to keep retrying.
 *  minTime = lower bound of the sleep between two attempts.
 *  maxTime = upper bound of the sleep between two attempts. A failure is
 *            logged at most once every `maxTime * 10` to not spam the log.
 *  file = source file of the caller, used in log messages and the exception.
 *  line = source line of the caller, used in log messages and the exception.
 *
 * Returns: whatever `query` returns.
 *
 * Throws: `SpinSqlTimeout` if the query has not succeeded within `timeout`.
 */
auto spinSql(alias query, alias logFn = logger.warning)(Duration timeout, Duration minTime = 50.dur!"msecs",
        Duration maxTime = 150.dur!"msecs", const string file = __FILE__, const size_t line = __LINE__) {
    import core.thread : Thread;
    import std.datetime : Clock;
    import std.exception : collectException;
    import std.format : format;
    import std.random : uniform;

    const stopAfter = Clock.currTime + timeout;
    // do not spam the log
    const logInterval = maxTime * 10;
    auto nextLogMsg = Clock.currTime + logInterval;

    const minMs = minTime.total!"msecs";
    const maxMs = maxTime.total!"msecs";

    while (Clock.currTime < stopAfter) {
        try {
            return query();
        } catch (Exception e) {
            if (Clock.currTime > nextLogMsg) {
                const location = format!" [%s:%s]"(file, line);
                logFn(e.msg, location).collectException;
                nextLogMsg += logInterval;
            }
            // uniform throws on an empty interval. That exception would
            // escape from this handler and abort the retry loop.
            const sleepMs = minMs < maxMs ? uniform(minMs, maxMs) : minMs;
            // even though the database have a builtin sleep it still result in too much spam.
            () @trusted { Thread.sleep(sleepMs.dur!"msecs"); }();
        }
    }

    // report the location of the caller, not of this helper.
    throw new SpinSqlTimeout(format!"query did not succeed within %s [%s:%s]"(timeout,
            file, line), file, cast(int) line);
}
523 
/** Execute an SQL query until it succeeds, without any upper time limit.
 *
 * Repeatedly calls the overload of `spinSql` that takes a timeout, with a
 * timeout of 365 days and a sleep of 50-150 ms between attempts. Every
 * exception that escapes from it is ignored and the call is made again.
 *
 * Note: If there are any errors in the query it will go into an infinite loop.
 */
auto spinSql(alias query, alias logFn = logger.warning)(const string file = __FILE__,
        const size_t line = __LINE__) nothrow {
    for (;;) {
        try {
            return spinSql!(query, logFn)(dur!"days"(365), dur!"msecs"(50),
                    dur!"msecs"(150), file, line);
        } catch (Exception e) {
            // deliberately ignored; try again.
        }
    }
}
534 
/** RAII handling of a transaction.
 *
 * The constructors execute `begin` on the database. Unless `commit` is
 * called the transaction is rolled back by `rollback` or by the destructor.
 */
struct Transaction {
    /// The database the transaction is executed on.
    Database db;

    // can only do a rollback/commit if it has been constructed and thus
    // executed begin.
    enum State {
        /// Default constructed; `begin` has not been executed.
        none,
        /// `begin` has been executed; a rollback is pending.
        rollback,
        /// Committed or rolled back; nothing more to do.
        done,
    }

    /// Current state of the transaction.
    State st;

    /** Begin a transaction on the database wrapped by `db`.
     *
     * `db` is taken by reference on purpose. A by-value copy would run the
     * destructor of `Miniorm` when this constructor returns, which finalizes
     * the prepared statements in the cache that the copy shares with the
     * instance of the caller.
     */
    this(ref Miniorm db) {
        this(db.db);
    }

    /// Begin a transaction on `db`. `begin` is retried until it succeeds.
    this(Database db) {
        this.db = db;
        spinSql!(() { db.begin; });
        st = State.rollback;
    }

    /// Roll back the transaction if it has neither been committed nor rolled
    /// back.
    ~this() {
        rollback;
    }

    /// Commit the transaction. The state is only changed if the commit
    /// succeeds, thus a failed commit is rolled back by the destructor.
    void commit() {
        db.commit;
        st = State.done;
    }

    /// Roll back the transaction if a rollback is pending. The transaction
    /// is marked as done even if the rollback throws.
    void rollback() {
        scope (exit)
            st = State.done;
        if (st == State.rollback) {
            db.rollback;
        }
    }
}
580 
/// A prepared statement is lent to the user. The refcnt takes care of
/// resetting the statement when the user is done with it.
struct RefCntStatement {
    import std.exception : collectException;
    import std.typecons : RefCounted, RefCountedAutoInitialize, refCounted;

    /// The reference counted part. Its lifetime brackets one "lend" of the
    /// statement: `count` is incremented on construction and decremented on
    /// destruction.
    static struct Payload {
        /// The lent statement; null when there is nothing to give back.
        LentCntStatement* stmt;

        /// Register one more user of `stmt`.
        this(LentCntStatement* stmt) {
            this.stmt = stmt;
            stmt.count++;
        }

        /// Clear the bindings, reset the statement and unregister the user.
        ~this() nothrow {
            if (stmt is null)
                return;

            try {
                (*stmt).stmt.clearBindings;
                (*stmt).stmt.reset;
            } catch (Exception e) {
                // ignored; a destructor must not throw. The count is
                // decremented even if the cleanup failed.
            }
            stmt.count--;
            stmt = null;
        }
    }

    /// Shared between all copies of this handle.
    RefCounted!(Payload, RefCountedAutoInitialize.no) rc;

    /// Lend `stmt`. Only the address of `stmt` is stored, so `stmt` must
    /// outlive all copies of the handle.
    this(ref LentCntStatement stmt) @trusted {
        rc = Payload(&stmt);
    }

    /// Returns: a reference to the underlying prepared statement.
    ref Statement get() {
        return rc.refCountedPayload.stmt.stmt;
    }
}
619 
/** An input range over `result` that carries the statement handle along.
 *
 * All range primitives are forwarded to `result`. `stmt` is only stored.
 * NOTE(review): presumably it is stored to keep the statement lent for as
 * long as the results are consumed - confirm.
 */
struct ResultRange2(T) {
    /// Handle to the statement that produced `result`.
    RefCntStatement stmt;
    /// The wrapped range of result rows.
    T result;

    /// Returns: true when there are no more rows.
    bool empty() {
        return result.empty;
    }

    /// Returns: the current row. The range must not be empty.
    auto front() {
        assert(!empty, "Can't get front of an empty range");
        return result.front;
    }

    /// Advance to the next row. The range must not be empty.
    void popFront() {
        assert(!empty, "Can't pop front of an empty range");
        result.popFront;
    }
}
638 
/// A cached prepared statement together with the number of users it is
/// currently lent to. It can't be finalized if the counter > 0.
private struct LentCntStatement {
    /// The prepared statement.
    Statement stmt;
    /// Number of users that currently hold the statement.
    long count;
}
644 
@("shall remove all statements that are not lent to a user when the cache is full")
unittest {
    struct Settings {
        ulong id;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    // purge as soon as the cache holds more than one statement.
    db.prepareCacheSize = 1;

    { // reuse statement
        auto s0 = db.prepare("select * from Settings");
        auto s1 = db.prepare("select * from Settings");

        // the same SQL text gives one cache entry that is lent twice.
        db.cachedStmt.length.shouldEqual(1);
        db.cachedStmt["select * from Settings"].count.shouldEqual(2);
    }
    // both handles went out of scope; the statement stays cached but is no
    // longer lent.
    db.cachedStmt.length.shouldEqual(1);
    db.cachedStmt["select * from Settings"].count.shouldEqual(0);

    { // a lent statement is not removed when the cache is full
        auto s0 = db.prepare("select * from Settings");
        // the purge check is done before the new statement is added and the
        // cache does not yet exceed the limit, thus nothing is removed.
        auto s1 = db.prepare("select id from Settings");

        db.cachedStmt.length.shouldEqual(2);
        ("select * from Settings" in db.cachedStmt).shouldBeTrue;
        ("select id from Settings" in db.cachedStmt).shouldBeTrue;
    }
    // giving the statements back does not in itself shrink the cache.
    db.cachedStmt.length.shouldEqual(2);

    { // statements not lent to a user is removed when the cache is full
        // the cache holds two unlent statements, which is more than the
        // limit. Both are purged before this one is prepared anew.
        auto s0 = db.prepare("select * from Settings");

        db.cachedStmt.length.shouldEqual(1);
        ("select * from Settings" in db.cachedStmt).shouldBeTrue;
        ("select id from Settings" in db.cachedStmt).shouldBeFalse;
    }
}