1 /**
2 Copyright: Copyright (c) 2017, Oleg Butko. All rights reserved.
3 Copyright: Copyright (c) 2018-2019, Joakim Brännström. All rights reserved.
4 License: MIT
5 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
6 Author: Oleg Butko (deviator)
7 */
8 module miniorm.api;
9 
10 import core.time : dur;
11 import logger = std.experimental.logger;
12 import std.array : Appender;
13 import std.datetime : SysTime, Duration;
14 import std.range;
15 
16 import miniorm.exception;
17 import miniorm.queries;
18 
19 import d2sqlite3;
20 
21 version (unittest) {
22     import std.algorithm : map;
23     import unit_threaded.assertions;
24 }
25 
/** Thin wrapper around a d2sqlite3 `Database`.
 *
 * Adds a cache of prepared statements, optional tracing of the executed SQL
 * and `run` overloads for the query builders `Count`, `Select`, `Delete` and
 * `Insert`.
 *
 * NOTE(review): there is no postblit or copy constructor. A copy shares
 * `cachedStmt` with the original (associative arrays are reference types),
 * and the destructor of any copy finalizes every cached statement, leaving
 * the remaining copies with finalized statements. Avoid passing it by value.
 */
struct Miniorm {
    /// Prepared statements keyed by their SQL text, each with the number of
    /// `RefCntStatement`s currently borrowing it.
    private LentCntStatement[string] cachedStmt;
    /// `prepare` evicts statements that are not lent out once the cache holds
    /// more than this many entries.
    private size_t cacheSize = 128;
    /// True means that all queries are logged.
    private bool log_;

    /// The wrapped database connection.
    private Database db;
    /// Members not found on `Miniorm` are looked up on the wrapped `Database`.
    alias getUnderlyingDb this;

    /// Returns: a reference to the wrapped database connection.
    ref Database getUnderlyingDb() return  {
        return db;
    }

    /// Wrap an already opened database.
    this(Database db) {
        this.db = db;
    }

    /// Open the database at `path` using the SQLite open `flags`.
    this(string path, int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) {
        this(Database(path, flags));
    }

    /// Finalize all cached statements. The database is not explicitly closed
    /// here; use `close` for that.
    ~this() {
        cleanupCache;
    }

    /// Start a RAII handled transaction.
    ///
    /// Returns: a `Transaction` that has already begun and rolls back when it
    /// is destroyed unless it has been committed.
    Transaction transaction() {
        return Transaction(db);
    }

    /// Set the number of cached statements above which `prepare` evicts the
    /// statements that are not lent out.
    void prepareCacheSize(size_t s) {
        cacheSize = s;
    }

    /** Get a prepared statement for `sql`, reusing a cached one when possible.
     *
     * Before the lookup, if the cache holds more than `cacheSize` entries,
     * every statement that is not lent out (borrow count 0) is finalized and
     * removed. Lent statements are kept, so the cache may stay above the
     * limit.
     *
     * Returns: a handle borrowing the cached statement. When the last copy of
     * the handle is destroyed the statement's bindings are cleared, it is
     * reset and the borrow count is decremented.
     */
    RefCntStatement prepare(string sql) {
        // `>` and the insertion below happen in that order, so the cache can
        // hold up to cacheSize + 1 entries before eviction kicks in.
        if (cachedStmt.length > cacheSize) {
            auto keys = appender!(string[])();
            foreach (p; cachedStmt.byKeyValue) {
                // the statement is currently not lent to any user.
                if (p.value.count == 0) {
                    keys.put(p.key);
                }
            }

            // Removal is done in a second pass because an associative array
            // must not be modified while it is being iterated.
            foreach (k; keys.data) {
                cachedStmt[k].stmt.finalize;
                cachedStmt.remove(k);
            }
        }

        if (auto v = sql in cachedStmt) {
            return RefCntStatement(*v);
        }
        auto r = db.prepare(sql);
        cachedStmt[sql] = LentCntStatement(r);
        // The returned handle keeps a pointer to this cache entry.
        return RefCntStatement(cachedStmt[sql]);
    }

    /// Toggle logging.
    void log(bool v) nothrow {
        this.log_ = v;
    }

    /// Returns: True if logging is activated
    private bool isLog() {
        return log_;
    }

    /// Finalize every cached statement, lent or not, and empty the cache.
    /// NOTE(review): statements still borrowed by a live RefCntStatement are
    /// finalized as well; the borrower's destructor later resets a finalized
    /// statement — confirm d2sqlite3 tolerates that.
    private void cleanupCache() {
        foreach (ref s; cachedStmt.byValue) {
            s.stmt.finalize;
        }
        cachedStmt = null;
    }

    /// Take over the database of `rhs` after finalizing this instance's cache.
    /// NOTE(review): `log_` and `cacheSize` are not copied from `rhs` and the
    /// cache starts out empty — confirm that is intended.
    void opAssign(ref typeof(this) rhs) {
        cleanupCache;
        db = rhs.db;
    }

    /// Execute `sql` directly, bypassing the statement cache. `dg` is
    /// forwarded unchanged to `Database.run`.
    void run(string sql, bool delegate(ResultRange) dg = null) {
        if (isLog) {
            logger.trace(sql);
        }
        db.run(sql, dg);
    }

    /// Finalize all cached statements and close the database.
    void close() {
        cleanupCache;
        db.close();
    }

    /** Execute a COUNT query.
     *
     * `v.binds` and `args` are handed to `executeCheck` (not defined in this
     * file view) together with the cached statement.
     *
     * Returns: the first column of the first result row as `size_t`.
     */
    size_t run(T, Args...)(Count!T v, auto ref Args args) {
        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }
        auto stmt = prepare(sql);
        return executeCheck(stmt, sql, v.binds, args).front.front.as!size_t;
    }

    /** Execute a SELECT query and convert every row to a `T`.
     *
     * Returns: a lazy range of `T`. The range holds the lent statement, so
     * the statement stays borrowed until the range (and all its copies) are
     * destroyed.
     */
    auto run(T, Args...)(Select!T v, auto ref Args args) {
        import std.algorithm : map;
        import std.format : format;
        import std.range : inputRangeObject;

        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }

        auto stmt = prepare(sql);
        auto result = executeCheck(stmt, sql, v.binds, args);

        // Convert one result row to a `T`. The conversion code is generated
        // at compile time: one assignment per column returned by fieldToCol,
        // field number `i` being read from column `i` of the row.
        static T qconv(typeof(result.front) e) {
            import std.algorithm : min;
            import std.conv : to;
            import std.traits : isStaticArray, OriginalType;
            import miniorm.api : fromSqLiteDateTime;
            import miniorm.schema : fieldToCol;

            T ret;
            // Builds the D source for the per-column assignments:
            //  - DATETIME columns are parsed with fromSqLiteDateTime,
            //  - static arrays are read as ubyte[] and copied element-wise,
            //    truncated to the shorter of the two lengths,
            //  - enums are peeked as the enum type and passed through `to`,
            //  - everything else is read with `peek`.
            static string rr() {
                string[] res;
                foreach (i, a; fieldToCol!("", T)()) {
                    res ~= `{`;
                    if (a.columnType == "DATETIME") {
                        res ~= `{ ret.%1$s = fromSqLiteDateTime(e.peek!string(%2$d)); }`.format(a.identifier,
                                i);
                    } else {
                        res ~= q{alias ET = typeof(ret.%s);}.format(a.identifier);
                        res ~= q{static if (isStaticArray!ET)};
                        res ~= `
                            {
                                auto ubval = e[%2$d].as!(ubyte[]);
                                auto etval = cast(typeof(ET.init[]))ubval;
                                auto ln = min(ret.%1$s.length, etval.length);
                                ret.%1$s[0..ln] = etval[0..ln];
                            }
                            `.format(a.identifier, i);
                        res ~= q{else static if (is(ET == enum))};
                        res ~= format(q{ret.%1$s = cast(ET) e.peek!ET(%2$d).to!ET;},
                                a.identifier, i);
                        res ~= q{else};
                        res ~= format(q{ret.%1$s = e.peek!ET(%2$d);}, a.identifier, i);
                    }
                    res ~= `}`;
                }
                return res.join("\n");
            }

            mixin(rr());
            return ret;
        }

        return ResultRange2!(typeof(result))(stmt, result).map!qconv;
    }

    /// Execute a DELETE query with `v.binds` and `args` bound.
    void run(T, Args...)(Delete!T v, auto ref Args args) {
        const sql = v.toSql.toString;
        if (isLog) {
            logger.trace(sql);
        }
        auto stmt = prepare(sql);
        executeCheck(stmt, sql, v.binds, args);
    }

    /// Insert every value in `arr`, one row per value.
    void run(T0, T1)(Insert!T0 v, T1[] arr...) if (!isInputRange!T1) {
        procInsert(v, arr);
    }

    /// Insert every element of the input range `rng`, one row per element.
    void run(T, R)(Insert!T v, R rng) if (isInputRange!R) {
        procInsert(v, rng);
    }

    /** Insert each element of `rng` by re-executing one single-row INSERT.
     *
     * For `InsertOpt.InsertOrReplace` every field is bound, including primary
     * keys; otherwise primary key columns are skipped. DATETIME fields are
     * bound as UTC text produced by `toSqliteDateTime`.
     *
     * NOTE(review): no transaction is opened here; callers that want the rows
     * inserted atomically have to open one themselves.
     */
    private void procInsert(T, R)(Insert!T q, R rng) {
        import std.algorithm : among;

        // generate code for binding values in a struct to a prepared
        // statement.
        // Expects an external variable "n" to exist that keeps track of the
        // index. This is required when the binding is for multiple values.
        // Expects the statement to be named "stmt".
        // Expects the variable to read values from to be named "v".
        // Indexing start from 1 according to the sqlite manual.
        static string genBinding(T)(bool replace) {
            import miniorm.schema : fieldToCol;

            string s;
            foreach (i, v; fieldToCol!("", T)) {
                if (!replace && v.isPrimaryKey)
                    continue;
                if (v.columnType == "DATETIME")
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ".toUTC.toSqliteDateTime);";
                else
                    s ~= "stmt.get.bind(n+1, v." ~ v.identifier ~ ");";
                s ~= "++n;";
            }
            return s;
        }

        // The bound fields are those of the range's element type.
        alias T = ElementType!R;

        const replace = q.query.opt == InsertOpt.InsertOrReplace;

        // Prepare a statement with a single VALUES tuple; it is reused per row.
        q = q.values(1);

        const sql = q.toSql.toString;

        auto stmt = prepare(sql);

        foreach (v; rng) {
            int n;
            // Two branches because genBinding needs a compile time argument
            // while `replace` is only known at run time.
            if (replace) {
                mixin(genBinding!T(true));
            } else {
                mixin(genBinding!T(false));
            }
            if (isLog) {
                logger.trace(sql, " -> ", v);
            }
            stmt.get.execute();
            stmt.get.reset();
        }
    }
}
256 
/** Whether a single aggregated INSERT or one INSERT per row should be generated.
 *
 * no:
 * ---
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * ---
 *
 * yes:
 * ---
 * INSERT INTO foo ('v0') VALUES (?), (?), (?)
 * ---
 */
enum AggregateInsert {
    /// One statement per row.
    no = 0,
    /// All rows in one statement with several VALUES tuples.
    yes = 1,
}
275 
276 version (unittest) {
277     import miniorm.schema;
278 
279     import std.conv : text, to;
280     import std.range;
281     import std.algorithm;
282     import std.datetime;
283     import std.array;
284     import std.stdio;
285 
286     import unit_threaded.assertions;
287 }
288 
@("shall operate on a database allocated in std.experimental.allocator without any errors")
unittest {
    struct One {
        ulong id;
        string text;
    }

    // TODO: fix this
    //import std.experimental.allocator;
    //import std.experimental.allocator.mallocator;
    //import std.experimental.allocator.building_blocks.scoped_allocator;
    //Miniorm* db;
    //ScopedAllocator!Mallocator scalloc;
    //db = scalloc.make!Miniorm(":memory:");
    //scope (exit) {
    //    db.close;
    //    scalloc.dispose(db);
    //}

    // TODO: replace the one below with the above code.
    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    // A plain insert does not bind primary key columns (presumably `id`), so
    // the ids given here are not stored and every id read back is below 100.
    db.run(insert!One.insert, iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    auto ones = db.run(select!One).array;
    ones.length.shouldEqual(10);
    assert(ones.all!(a => a.id < 100));
    db.getUnderlyingDb.lastInsertRowid.shouldEqual(ones[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);
    // insertOrReplace binds every field, so the ids 100, 200, ... are kept.
    db.run(insertOrReplace!One, iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i))));
    ones = db.run(select!One).array;
    ones.length.shouldEqual(499);
    assert(ones.all!(a => a.id >= 100));
    db.lastInsertRowid.shouldEqual(ones[$ - 1].id);
}
327 
@("shall insert and extract datetime from the table")
unittest {
    import core.thread : Thread;
    import std.datetime : Clock;

    struct One {
        ulong id;
        SysTime time;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    // Reference point; the stored timestamp is taken at least 1 ms later so
    // it must compare strictly greater after the round trip.
    const before = Clock.currTime;
    Thread.sleep(1.dur!"msecs");

    db.run(insert!One.insert, One(0, Clock.currTime));

    auto rows = db.run(select!One).array;
    rows.length.shouldEqual(1);
    rows[0].time.shouldBeGreaterThan(before);
}
350 
@("shall insert, count, select and delete rows of a struct with a string field")
unittest {
    struct One {
        ulong id;
        string text;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    db.run(count!One).shouldEqual(0);
    // A plain insert does not bind primary key columns (presumably `id`), so
    // the ids read back are assigned by the database and stay below 100.
    db.run(insert!One.insert, iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    auto ones = db.run(select!One).array;
    ones.length.shouldEqual(10);
    assert(ones.all!(a => a.id < 100));
    db.lastInsertRowid.shouldEqual(ones[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    // insertOrReplace binds every field, so the ids 100, 200, ... are kept.
    db.run(insertOrReplace!One, iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i))));
    ones = db.run(select!One).array;
    ones.length.shouldEqual(499);
    assert(ones.all!(a => a.id >= 100));
    db.lastInsertRowid.shouldEqual(ones[$ - 1].id);
}
381 
@("shall convert the database type to the enum when retrieving via select")
unittest {
    static struct Foo {
        enum MyEnum : string {
            foo = "batman",
            bar = "robin",
        }

        ulong id;
        MyEnum enum_;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Foo);
    db.run(insert!Foo.insert, Foo(0, Foo.MyEnum.bar));

    // The single stored row must come back with the same enum member.
    auto fetched = db.run(select!Foo).array;
    fetched.length.shouldEqual(1);
    fetched[0].enum_.shouldEqual(Foo.MyEnum.bar);
}
403 
@("shall count and delete rows filtered on nested struct fields")
unittest {
    struct Limit {
        int min, max;
    }

    struct Limits {
        Limit volt, curr;
    }

    struct Settings {
        ulong id;
        Limits limits;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    db.run(count!Settings).shouldEqual(0);
    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 12), Limit(-10, 10))));
    db.run(count!Settings).shouldEqual(1);

    // id 10 is replaced, ids 11 and 12 are new rows.
    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 2), Limit(-3, 3))));
    db.run(insertOrReplace!Settings, Settings(11, Limits(Limit(0, 11), Limit(-11, 11))));
    db.run(insertOrReplace!Settings, Settings(12, Limits(Limit(0, 12), Limit(-12, 12))));

    // Nested fields are addressed by their dotted column name.
    db.run(count!Settings).shouldEqual(3);
    db.run(count!Settings.where(`"limits.volt.max" = :nr`, Bind("nr")), 2).shouldEqual(1);
    db.run(count!Settings.where(`"limits.volt.max" > :nr`, Bind("nr")), 10).shouldEqual(2);
    db.run(count!Settings.where(`"limits.volt.max" > :nr`, Bind("nr"))
            .and(`"limits.volt.max" < :topnr`, Bind("topnr")), 1, 12).shouldEqual(2);

    db.run(delete_!Settings.where(`"limits.volt.max" < :nr`, Bind("nr")), 10);
    db.run(count!Settings).shouldEqual(2);
}
437 
unittest {
    struct Settings {
        ulong id;
        int[5] data;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    db.run(insert!Settings.insert, Settings(0, [1, 2, 3, 4, 5]));
    assert(db.run(count!Settings) == 1);

    // The static array is read back from the row's raw bytes element-wise.
    const expected = [1, 2, 3, 4, 5];
    auto first = db.run(select!Settings).front;
    assert(first.data == expected);
}
453 
/** Parse a timestamp as stored by `toSqliteDateTime`.
 *
 * Accepts `YYYY-MM-DD HH:MM:SS`, optionally followed by `.F` where the digits
 * after the dot are read as an integer number of milliseconds. The fractional
 * part is optional because SQLite's own `datetime()` / `CURRENT_TIMESTAMP`
 * produce values without it.
 *
 * Returns: the parsed time in UTC, or the current UTC time if `raw_dt` can not
 * be parsed (the error is logged at trace level).
 */
SysTime fromSqLiteDateTime(string raw_dt) {
    import std.datetime : DateTime, UTC, Clock;
    import std.format : formattedRead;
    import std.string : indexOf;

    try {
        int year, month, day, hour, minute, second, msecs;

        // Split off the optional fractional part so that a missing ".fff"
        // does not make formattedRead throw on the literal '.'.
        string datePart = raw_dt;
        string fracPart;
        const dot = raw_dt.indexOf('.');
        if (dot >= 0) {
            datePart = raw_dt[0 .. dot];
            fracPart = raw_dt[dot + 1 .. $];
        }

        formattedRead(datePart, "%s-%s-%s %s:%s:%s", year, month, day, hour, minute, second);
        if (fracPart.length != 0)
            formattedRead(fracPart, "%s", msecs);

        auto dt = DateTime(year, month, day, hour, minute, second);
        return SysTime(dt, msecs.dur!"msecs", UTC());
    } catch (Exception e) {
        logger.trace(e.msg);
        return Clock.currTime(UTC());
    }
}
469 
/** Format `ts_` as an SQLite datetime string.
 *
 * To ensure consistency the time is always converted to UTC. The result has
 * the form `YYYY-MM-DD HH:MM:SS.SSS`. The milliseconds are zero padded to
 * three digits: unpadded, 5 ms would be written as ".5", which SQLite's date
 * functions read as half a second and which sorts after ".100" as text.
 */
string toSqliteDateTime(SysTime ts_) {
    import std.format;

    auto ts = ts_.toUTC;
    return format("%04s-%02s-%02s %02s:%02s:%02s.%03s", ts.year,
            cast(ushort) ts.month, ts.day, ts.hour, ts.minute, ts.second,
            ts.fracSecs.total!"msecs");
}
479 
/// Thrown by `spinSql` when the query has not succeeded before the timeout.
class SpinSqlTimeout : Exception {
    /** Params:
     *  msg = description of the failure.
     *  file = source file reported for the exception (defaults to the caller).
     *  line = source line reported for the exception (defaults to the caller).
     */
    @safe pure nothrow this(string msg, string file = __FILE__, int line = __LINE__) {
        super(msg, file, line);
    }
}
485 
/** Execute an SQL query until it succeeds or `timeout` elapses.
 *
 * Every failed attempt is reported through `logFn` together with the call
 * site, then the caller sleeps for a random time in `[minTime, maxTime]`
 * before retrying.
 *
 * Note: errors in the query itself are retried like any other failure, so
 * with `timeout == Duration.max` a broken query loops forever.
 *
 * Params:
 *  query = callable executed on each attempt.
 *  logFn = called with the exception message and the call-site location.
 *  timeout = total time budget for all attempts.
 *  minTime = lower bound of the sleep between attempts.
 *  maxTime = upper bound of the sleep between attempts. The bounds may be
 *      equal and are swapped if given in the wrong order; negative values
 *      are treated as zero.
 *  file = call site, used in log messages and the timeout exception.
 *  line = call site, used in log messages and the timeout exception.
 *
 * Returns: the result of the first successful `query()` call.
 * Throws: SpinSqlTimeout when no attempt succeeded within `timeout`.
 */
auto spinSql(alias query, alias logFn = logger.warning)(Duration timeout, Duration minTime = 50.dur!"msecs",
        Duration maxTime = 150.dur!"msecs", const string file = __FILE__, const size_t line = __LINE__) {
    import core.thread : Thread;
    import std.algorithm : max, min;
    import std.datetime.stopwatch : StopWatch, AutoStart;
    import std.exception : collectException;
    import std.format : format;
    import std.random : uniform;

    const sw = StopWatch(AutoStart.yes);
    const location = format!" [%s:%s]"(file, line);

    // uniform's default "[)" interval throws on an empty range (minTime ==
    // maxTime). That throw would escape from the catch handler below and
    // abort the retry loop, so use ordered, non-negative bounds and a closed
    // interval.
    const minMs = minTime.total!"msecs";
    const maxMs = maxTime.total!"msecs";
    long lowMs = max(0L, min(minMs, maxMs));
    long highMs = max(lowMs, minMs, maxMs);

    while (sw.peek < timeout) {
        try {
            return query();
        } catch (Exception e) {
            logFn(e.msg, location).collectException;
            // even though the database has a builtin sleep it still results in too much spam.
            () @trusted {
                Thread.sleep(uniform!"[]"(lowMs, highMs).dur!"msecs");
            }();
        }
    }

    throw new SpinSqlTimeout(format!"query did not succeed within %s%s"(timeout,
            location), file, cast(int) line);
}
515 
/** Execute an SQL query until it succeeds, never giving up.
 *
 * Delegates to the timeout overload with an unlimited budget and a 50-150 ms
 * pause between attempts. Any exception that still escapes from it is
 * ignored and the whole operation is restarted, which is what makes this
 * overload `nothrow`.
 *
 * Returns: the result of the first successful `query()` call.
 */
auto spinSql(alias query, alias logFn = logger.warning)(const string file = __FILE__,
        const size_t line = __LINE__) nothrow {
    const shortestPause = dur!"msecs"(50);
    const longestPause = dur!"msecs"(150);

    for (;;) {
        try {
            return spinSql!(query, logFn)(Duration.max, shortestPause, longestPause, file, line);
        } catch (Exception) {
            // swallowed on purpose: restart and try again
        }
    }
}
526 
/// RAII handling of a transaction.
///
/// Constructing a `Transaction` executes `begin` on the database, retried via
/// `spinSql` until it succeeds. Unless `commit` or `rollback` has been called
/// the destructor rolls the transaction back.
struct Transaction {
    Database db;

    // can only do a rollback/commit if it has been constructed and thus
    // executed begin.
    enum State {
        none,
        rollback,
        done,
    }

    State st;

    /** Begin a transaction on the database wrapped by `db`.
     *
     * `db` is taken by reference on purpose. A by-value `Miniorm` parameter
     * is destroyed when the constructor returns, and `Miniorm`'s destructor
     * finalizes every statement in the statement cache that the copy shares
     * with the caller's instance. That would break the caller's later queries.
     */
    this(ref Miniorm db) {
        this(db.db);
    }

    /** ditto
     *
     * Overload for rvalue arguments, kept for backward compatibility.
     * NOTE(review): if the rvalue is itself a copy of a live `Miniorm`, the
     * shared statement cache is still finalized when the copy is destroyed.
     */
    this(Miniorm db) {
        this(db.db);
    }

    /// Begin a transaction on `db`, retrying until `begin` succeeds.
    this(Database db) {
        this.db = db;
        spinSql!(() { db.begin; });
        st = State.rollback;
    }

    /// Roll back unless the transaction has already been committed or rolled back.
    ~this() {
        scope (exit)
            st = State.done;
        if (st == State.rollback) {
            db.rollback;
        }
    }

    /// Commit the transaction; afterwards the destructor does not roll back.
    /// If the commit throws, the state is unchanged and the destructor still
    /// rolls back.
    void commit() {
        db.commit;
        st = State.done;
    }

    /// Roll back now if the transaction is still active. Calling it again, or
    /// after `commit`, does nothing.
    void rollback() {
        scope (exit)
            st = State.done;
        if (st == State.rollback) {
            db.rollback;
        }
    }
}
572 
/// A prepared statement is lent to the user. The refcnt takes care of
/// resetting the statement when the user is done with it.
///
/// All copies share one `Payload`. When the last copy is destroyed the
/// statement's bindings are cleared, it is reset and the borrow counter of
/// the cache entry is decremented, so `Miniorm.prepare` may evict it again.
struct RefCntStatement {
    import std.exception : collectException;
    import std.typecons : RefCounted, RefCountedAutoInitialize, refCounted;

    /// State shared by all copies; destroyed when the reference count drops to zero.
    static struct Payload {
        /// Points into the cache entry that `Miniorm.prepare` handed out.
        LentCntStatement* stmt;

        /// Borrow `stmt`: increments its lent counter.
        this(LentCntStatement* stmt) {
            this.stmt = stmt;
            stmt.count++;
        }

        /// Return the borrowed statement to the cache.
        ~this() nothrow {
            // A payload with a null pointer (Payload.init, or presumably a
            // moved-from temporary) owns nothing and must not decrement.
            if (stmt is null)
                return;

            // Best effort: failures to clear/reset are ignored so that the
            // destructor can stay nothrow.
            try {
                (*stmt).stmt.clearBindings;
                (*stmt).stmt.reset;
            } catch (Exception e) {
            }
            stmt.count--;
            stmt = null;
        }
    }

    /// Not auto-initialized: only the constructor below creates a payload.
    RefCounted!(Payload, RefCountedAutoInitialize.no) rc;

    /// Borrow `stmt`, which must outlive every copy of this handle.
    this(ref LentCntStatement stmt) @trusted {
        rc = Payload(&stmt);
    }

    /// Returns: the borrowed statement.
    /// NOTE(review): requires a handle built by the constructor above; a
    /// default-initialized RefCntStatement has no payload.
    ref Statement get() {
        return rc.refCountedPayload.stmt.stmt;
    }
}
611 
/** Input range over the rows of an executed statement.
 *
 * Keeps the lent statement (`stmt`) next to the row range (`result`). While
 * the range is alive the statement stays borrowed, so it is neither reset nor
 * evicted from the cache as the rows are consumed.
 */
struct ResultRange2(T) {
    RefCntStatement stmt;
    T result;

    /// True when there are no more rows.
    bool empty() {
        return result.empty;
    }

    /// The current row. Must not be called on an empty range.
    auto front() {
        assert(!result.empty, "Can't get front of an empty range");
        return result.front;
    }

    /// Advance to the next row. Must not be called on an empty range.
    void popFront() {
        assert(!result.empty, "Can't pop front of an empty range");
        result.popFront();
    }
}
630 
/// It is lent to a user and thus can't be finalized if the counter > 0.
///
/// One entry of the `Miniorm` statement cache. Field order matters: entries
/// are built positionally as `LentCntStatement(stmt)`.
private struct LentCntStatement {
    /// The cached prepared statement.
    Statement stmt;
    /// Number of live `RefCntStatement` payloads borrowing `stmt`;
    /// `Miniorm.prepare` only evicts entries where this is 0.
    long count;
}
636 
@("shall remove all statements that are not lent to a user when the cache is full")
unittest {
    struct Settings {
        ulong id;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    // prepare() evicts unlent statements once the cache holds more than one.
    db.prepareCacheSize = 1;

    { // reuse statement
        auto s0 = db.prepare("select * from Settings");
        auto s1 = db.prepare("select * from Settings");

        db.cachedStmt.length.shouldEqual(1);
        db.cachedStmt["select * from Settings"].count.shouldEqual(2);
    }
    // s0 and s1 are destroyed: the entry stays cached but is no longer lent.
    db.cachedStmt.length.shouldEqual(1);
    db.cachedStmt["select * from Settings"].count.shouldEqual(0);

    { // a lent statement is not removed when the cache is full
        auto s0 = db.prepare("select * from Settings");
        {
            // One entry is not above the limit, so nothing is evicted yet.
            auto s1 = db.prepare("select id from Settings");
            db.cachedStmt.length.shouldEqual(2);
        }

        // Two entries exceed the limit: the unlent "select id" statement is
        // evicted while the one lent to s0 survives.
        auto s2 = db.prepare("select id from Settings where id = 0");

        db.cachedStmt.length.shouldEqual(2);
        ("select * from Settings" in db.cachedStmt).shouldBeTrue;
        ("select id from Settings" in db.cachedStmt).shouldBeFalse;
        ("select id from Settings where id = 0" in db.cachedStmt).shouldBeTrue;
    }
    db.cachedStmt.length.shouldEqual(2);

    { // statements not lent to a user are removed when the cache is full
        auto s0 = db.prepare("select * from Settings");

        db.cachedStmt.length.shouldEqual(1);
        ("select * from Settings" in db.cachedStmt).shouldBeTrue;
        ("select id from Settings" in db.cachedStmt).shouldBeFalse;
    }
}