1 module unit_threaded.mock;
2 
3 import unit_threaded.from;
4 
/// Yields the symbol or value `T` itself; lets `__traits` results be aliased.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `member` of `T` can't be accessed from this module
/// (or doesn't exist at all).
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
7 
8 
/**
   Generates, for use with `mixin`, the body of a class that overrides the
   mockable virtual methods of `T`.

   For every accessible member of `T`, each overload `i` is considered. `i` is
   its index into `__traits(getOverloads, T, name)`. An overload is used if it
   is a virtual method and is neither `const` nor `immutable`. For each such
   overload the generated code declares:

   $(UL
     $(LI `private` aliases `<name>_<i>_parameters` and `<name>_<i>_returnType`.)
     $(LI For non-void overloads, an array `<name>_<i>_returnValues`. It is used
          as a queue of values to return and is filled by `Mock.returnValue`.)
     $(LI An `override` that appends the function name to `calledFuncs` and
          the arguments (as `tuple(...).to!string`) to `calledValues`. It then
          returns the next queued value, or `<name>_<i>_returnType.init` when
          the queue is empty.)
   )

   For `nothrow` overloads the recording is wrapped in `try`/`catch(Exception)`.
   The generated code refers to `calledFuncs`/`calledValues`, `tuple`, `to`,
   `Parameters` and `ReturnType`; these must be in scope where it is mixed in.

   Returns: the generated code, or `null` when called at run time. This
   function is only meant to be evaluated during CTFE.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // D source naming the i-th overload of memberName
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // Every per-function property below is taken from `overload`,
                    // not `member`. `member` is just the first overload, and
                    // overloads can differ in attributes, return type and arity.
                    enum attrs = functionAttributes!overload;

                    // Final overloads can't be overridden. const/immutable ones
                    // can't be overridden by a mock that records calls in its
                    // own state.
                    static if(__traits(isVirtualMethod, overload) &&
                              !(attrs & FunctionAttribute.const_) &&
                              !(attrs & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                        // The recording code may throw, so nothrow overrides wrap it in try/catch.
                        enum isNothrow = (attrs & FunctionAttribute.nothrow_) != 0;
                        enum tryIndent = isNothrow ? "    " : "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // pop the next queued value, falling back to .init
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        static if(isNothrow)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~ argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
87 
/// Parenthesised argument list `(arg0, ..., argN-1)` for generated code.
/// Only meaningful during CTFE; yields `null` at run time.
private string argNamesParens(int N) @safe pure {
    return __ctfe ? "(" ~ argNames(N) ~ ")" : null;
}
92 
/// Comma-separated names `arg0, arg1, ..., argN-1` for generated code.
/// Only meaningful during CTFE; yields `null` at run time.
private string argNames(int N) @safe pure {
    import std.conv: to;

    if(__ctfe) {
        string names;
        foreach(idx; 0 .. N) {
            if(idx > 0)
                names ~= ", ";
            names ~= "arg" ~ idx.to!string;
        }
        return names;
    }

    return null;
}
101 
/// Parameter list `(<prefix>_parameters[0] arg0, ...)` with one entry per
/// type in `T`. Used for the generated overrides. Only meaningful during
/// CTFE; yields `null` at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.array : join;
    import std.format : format;

    if(!__ctfe) return null;

    string[] params;
    foreach(idx; 0 .. T.length)
        params ~= format("%s_parameters[%s] arg%s", prefix, idx, idx);

    return "(" ~ params.join(", ") ~ ")";
}
115 
/// Space-separated D source for the attributes of `F` that a mock override
/// may repeat. Only meaningful during CTFE; yields `null` at run time.
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    static struct Spelling {
        FunctionAttribute flag;
        string text;
    }

    // Emitted in this order. const and immutable are deliberately absent:
    // the mock needs to alter its own state.
    immutable Spelling[] spellings = [
        Spelling(FunctionAttribute.pure_,    "pure"),
        Spelling(FunctionAttribute.nothrow_, "nothrow"),
        Spelling(FunctionAttribute.trusted,  "@trusted"),
        Spelling(FunctionAttribute.safe,     "@safe"),
        Spelling(FunctionAttribute.nogc,     "@nogc"),
        Spelling(FunctionAttribute.system,   "@system"),
        Spelling(FunctionAttribute.shared_,  "shared"),
    ];

    const attrs = functionAttributes!F;

    string[] words;
    foreach(spelling; spellings)
        if(attrs & spelling.flag)
            words ~= spelling.text;

    return words.join(" ");
}
140 
/**
   Call recording and verification shared by the mocks in this module.

   The mocks append the name of every call to `calledFuncs`, and its
   arguments (converted with `tuple(...).to!string`) to `calledValues`.
   Tests register expectations with `expect` and check them with `verify`,
   or do both at once with `expectCalled`.
 */
mixin template MockImplCommon() {
    bool _verified;          /// set by `verify`; verifying twice fails
    string[] expectedFuncs;  /// names of the expected calls, in order
    string[] calledFuncs;    /// names of the actual calls, in order
    string[] expectedValues; /// expected arguments per call ("" = not checked)
    string[] calledValues;   /// actual arguments per call

    /**
       Appends an expectation of a call to `funcName`.
       Params:
           values = the arguments the call must have had. If none are given,
                    the arguments are not checked.
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= "";
    }

    /**
       Expects a call to `func` with `values`, then immediately verifies all
       expectations so far. Afterwards the mock is marked unverified again,
       so it can be verified later.
     */
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Compares recorded calls against expectations, position by position.
       Calls beyond the last expectation are not checked.

       Fails (via `unit_threaded.should.fail` or by throwing
       `UnitTestException`) in any of these cases:
       the mock was already verified; an expected call did not happen;
       a different function was called in its place; or it was called
       with different arguments.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already verified", file, line);

        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " did not happen", file, line);

            if(expectedFunc != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            // an empty expectation means "any arguments"
            if(expectedValues[i] != calledValues[i] && expectedValues[i] != "")
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValues[i],
                                             " ".repeat.take(expectedFunc.length + 4).join ~
                                             "instead of the expected " ~ expectedValues[i]] ,
                                            file, line);
        }
    }
}
193 
/// True when `T` is a value of type `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
195 
/**
   Mock of the interface or class `T`, normally created with `mock!T`.

   The nested class `MockAbstract` derives from `T`. It mixes in the
   overrides generated by `implMixinStr!T` and the recording/verification
   state of `MockImplCommon`. `alias this` lets the struct be passed
   wherever a `T` is expected. If the mock hasn't been verified by the time
   the struct is destroyed, the destructor verifies it.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// The int is a dummy: structs can't have a user-defined default constructor.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    ~this() pure @safe {
        // A Mock that was never constructed (e.g. `Mock!T.init`) has no
        // implementation. Without the null check, this would dereference null.
        if(_impl !is null && !_verified) verify;
    }

    /// Queues values returned by successive calls to the first overload of `funcName`.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    // Compile-time check that `funcName` names a virtual method of T.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
249 
/// D source with one `import <module>;` line per module name, in order.
/// Only meaningful during CTFE; yields `null` at run time.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    string code;
    foreach(name; [module_] ~ Modules)
        code ~= "import " ~ name ~ ";\n";
    return code;
}
259 
/// Creates a `Mock!T` for the interface or class `T`.
Mock!T mock(T)() {
    // The argument only selects Mock's constructor; its value is unused.
    return Mock!T(0);
}
263 
264 
// An expectation without arguments accepts any arguments. There is no explicit
// `verify`: the Mock destructor verifies when `m` goes out of scope.
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo";
    fun(m);
}
280 
// Argument matching. The first scope matches. The next two differ in the int
// and the string argument respectively, and check the exact two-line failure
// message. After a failed `verify` the mock counts as verified, so its
// destructor does not fail again.
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    {
        auto m = mock!Foo;
        m.expect!"foo"(5, "foobar");
        fun(m);
    }

    {
        auto m = mock!Foo;
        m.expect!"foo"(6, "foobar");
        fun(m);
        assertExceptionMsg(m.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(6, "foobar")`);
    }

    {
        auto m = mock!Foo;
        m.expect!"foo"(5, "quux");
        fun(m);
        assertExceptionMsg(m.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(5, "quux")`);
    }
}
318 
319 
// Expecting a call that never happens makes `verify` fail.
@("mock interface negative test")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int foo(int, string) @safe pure;
    }

    auto m = mock!Foo;
    m.expect!"foo";
    m.verify.shouldThrowWithMessage("Expected nth 0 call to foo did not happen");
}
332 
// Test fixture for the class-mocking tests below. It can't be declared inside
// the unit test itself, so it lives at module scope (unittest builds only).
// It combines an abstract, a final, a const virtual and a mutable virtual method.
version(unittest)
private class Class {
    // abstract: has to be implemented by the mock
    abstract int foo(int, string) @safe pure;
    // final: not virtual, so the mock does not override it
    final int timesTwo(int i) @safe pure nothrow const { return i * 2; }
    // const: skipped by the mock generator, which needs mutable state
    int timesThree(int i) @safe pure nothrow const { return i * 3; }
    // mutable virtual: overridden and recorded by the mock
    int timesThreeMutable(int i) @safe pure nothrow { return i * 3; }
}
341 
// Mocking a class: `foo` is abstract in `Class` and implemented by the mock.
@("mock class positive test")
@safe pure unittest {

    int fun(Class f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Class;
    m.expect!"foo";
    fun(m);
}
353 
354 
// Expectations are matched in order, including repeated calls to one function.
@("mock interface multiple calls")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        int bar(int) @safe pure;
    }

    void fun(Foo f) {
        f.foo(3, "foo");
        f.bar(5);
        f.foo(4, "quux");
    }

    auto m = mock!Foo;
    m.expect!"foo"(3, "foo");
    m.expect!"bar"(5);
    m.expect!"foo"(4, "quux");
    fun(m);
    m.verify;
}
375 
// `expectCalled` checks a call after the fact (expect + verify in one step).
// It leaves the mock unverified, so the destructor verifies it once more.
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}
391 
// `returnValue` sets what the mocked function returns.
@("interface return value")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    res.shouldEqual(84);
}
409 
// Queued return values are returned one per call. Once exhausted, the return
// type's `.init` (0 for int) is returned.
@("interface return values")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    fun(m).shouldEqual(84);
    fun(m).shouldEqual(24);
    fun(m).shouldEqual(0);
}
428 
/**
   Compile-time description of the values a `mockStruct` returns, in order,
   from calls to the function named `function_`. Every value must be
   implicitly convertible to the type of the first one.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as a run-time array of the first value's type.
    static auto values() {
        alias Element = typeof(T[0]);
        Element[] result;
        foreach(value; T)
            result ~= value;
        return result;
    }
}

/// True if `T` matches `ReturnValues!U` for some `U...` (tested with `is(T : ...)`).
template isReturnValue(alias T) {
    enum isReturnValue = is(T: ReturnValues!U, U...);
}

/// True if `typeof(T)` is valid, e.g. when `T` is a value rather than a type.
template isValue(alias T) {
    enum isValue = is(typeof(T));
}
444 
445 
446 /**
447    Version of mockStruct that accepts 0 or more values of the same
448    type. Whatever function is called on it, these values will
449    be returned one by one. The limitation is that if more than one
450    function is called on the mock, they all return the same type
451  */
452 auto mockStruct(T...)(auto ref T returns) {
453 
454     struct Mock {
455 
456         MockImpl* _impl;
457         alias _impl this;
458 
459         static struct MockImpl {
460 
461             static if(T.length > 0) {
462                 alias FirstType = typeof(returns[0]);
463                 private FirstType[] _returnValues;
464             }
465 
466             mixin MockImplCommon;
467 
468             auto opDispatch(string funcName, V...)(auto ref V values) {
469 
470                 import std.conv: to;
471                 import std.typecons: tuple;
472 
473                 calledFuncs ~= funcName;
474                 calledValues ~= tuple(values).to!string;
475 
476                 static if(T.length > 0) {
477 
478                     if(_returnValues.length == 0) return typeof(_returnValues[0]).init;
479                     auto ret = _returnValues[0];
480                     _returnValues = _returnValues[1..$];
481                     return ret;
482                 }
483             }
484         }
485     }
486 
487     Mock m;
488     m._impl = new Mock.MockImpl;
489     static if(T.length > 0) {
490         foreach(r; returns)
491             m._impl._returnValues ~= r;
492     }
493 
494     return m;
495 }
496 
497 // /**
498 //    Version of mockStruct that accepts a compile-time mapping
499 //    of function name to return values. Each template parameter
500 //    must be a value of type `ReturnValues`
501 //  */
502 
503 auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {
504 
505     struct Mock {
506         mixin MockImplCommon;
507 
508         int[string] _retIndices;
509 
510         auto opDispatch(string funcName, V...)(auto ref V values) {
511 
512             import std.conv: to;
513             import std.typecons: tuple;
514 
515             calledFuncs ~= funcName;
516             calledValues ~= tuple(values).to!string;
517 
518             foreach(retVal; T) {
519                 static if(retVal.funcName == funcName) {
520                     return retVal.values[_retIndices[funcName]++];
521                 }
522             }
523         }
524 
525         auto lefoofoo() {
526             return T[0].values[_retIndices["greet"]++];
527         }
528 
529     }
530 
531     Mock mock;
532 
533     foreach(retVal; T) {
534         mock._retIndices[retVal.funcName] = 0;
535     }
536 
537     return mock;
538 }
539 
540 
// A struct mock accepts any member call through opDispatch and records it.
@("mock struct positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar;
    }
    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}
551 
// Verifying an expectation that was never met reports the missing call.
@("mock struct negative")
@safe pure unittest {
    import unit_threaded.asserts;

    auto m = mockStruct;
    m.expect!"foobar";
    assertExceptionMsg(m.verify,
                       "    source/unit_threaded/mock.d:123 - Expected nth 0 call to foobar did not happen\n");

}
562 
563 
// Argument expectations on a struct mock.
@("mock struct values positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}
575 
// Argument mismatch. The second message line is indented by
// "foobar".length + 4 spaces.
@("mock struct values negative")
@safe pure unittest {
    import unit_threaded.asserts;

    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(3, "quux");
    fun(m);
    assertExceptionMsg(m.verify,
                       "    source/unit_threaded/mock.d:123 - foobar was called with unexpected Tuple!(int, string)(2, \"quux\")\n" ~
                       "    source/unit_threaded/mock.d:123 -           instead of the expected Tuple!(int, string)(3, \"quux\")");
}
591 
592 
// Values given to mockStruct are returned one per call; once they are
// exhausted, int.init (0) is returned.
@("struct return value")
@safe pure unittest {
    import unit_threaded.should;

    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    auto m = mockStruct(42, 12);
    fun(m).shouldEqual(84);
    fun(m).shouldEqual(24);
    fun(m).shouldEqual(0);
    m.expectCalled!"timesN";
}
607 
// expectCalled on a struct mock, after the call was made.
@("struct expectCalled")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}
618 
// With ReturnValues, each function gets its own values and return type.
@("mockStruct different return types for different functions")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));
    m.length.shouldEqual(5);
    m.greet("bar").shouldEqual("hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}
629 
// Several values per function are returned in order. expectCalled can be
// interleaved with calls, because it leaves the mock unverified.
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));
    m.length.shouldEqual(5);
    m.expectCalled!"length";
    m.length.shouldEqual(3);
    m.expectCalled!"length";

    m.greet("bar").shouldEqual("hello");
    m.expectCalled!"greet"("bar");
    m.greet("quux").shouldEqual("g'day");
    m.expectCalled!"greet"("quux");
}
645 
646 
// Compile-only check: a method returning const(ubyte)[] can be mocked.
@("const(ubyte)[] return type")
@safe pure unittest {
    interface Interface {
        const(ubyte)[] fun();
    }

    auto m = mock!Interface;
}
655 
// Compile-only check: nothrow methods can be mocked (the generated override
// wraps its call recording in try/catch). NOTE(review): @nogc is commented
// out, presumably because the generated code appends to GC-allocated arrays.
@("safe pure nothrow")
@safe pure unittest {
    interface Interface {
        int twice(int i) @safe pure nothrow /*@nogc*/;
    }
    auto m = mock!Interface;
}
663 
// Overloads get separate return-value queues, selected by overload index.
@("issue 63")
@safe pure unittest {
    import unit_threaded.should;

    interface InterfaceWithOverloads {
        int func(int) @safe pure;
        int func(string) @safe pure;
    }
    auto m = mock!InterfaceWithOverloads;
    m.returnValue!(0, "func")(3); // int overload
    m.returnValue!(1, "func")(7); // string overload
    m.expect!"func"("foo");
    m.func("foo").shouldEqual(7);
    m.verify;
}
680 
681 
/**
   Creates a struct on which calling any member function throws `E`.
   The message is "<name> was called", and `file`/`line` default to the
   call site. Calls are typed as returning `R`, so they can be used where
   a value of type `R` is expected.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    return Mock();
}
694 
// By default, calling anything on a throwStruct throws UnitTestException.
@("throwStruct default")
@safe pure unittest {
    import unit_threaded.should: shouldThrow, UnitTestException;
    auto m = throwStruct;
    m.foo.shouldThrow!UnitTestException;
    m.bar(1, "foo").shouldThrow!UnitTestException;
}
702 
// Only built with version testing_unit_threaded: throwStruct with a custom
// exception type.
version(testing_unit_threaded) {
    // basicExceptionCtors provides the (msg, file, line) constructor throwStruct uses.
    class FooException: Exception {
        import std.exception: basicExceptionCtors;
        mixin basicExceptionCtors;
    }


    @("throwStruct custom")
        @safe pure unittest {
        import unit_threaded.should: shouldThrow;

        auto m = throwStruct!FooException;
        m.foo.shouldThrow!FooException;
        m.bar(1, "foo").shouldThrow!FooException;
    }
}
719 
720 
// With R = int the calls are int expressions. The message names the function.
@("throwStruct return value type")
@safe pure unittest {
    import unit_threaded.asserts;
    import unit_threaded.should: UnitTestException;
    auto m = throwStruct!(UnitTestException, int);
    int i;
    assertExceptionMsg(i = m.foo,
                       "    source/unit_threaded/mock.d:123 - foo was called");
    assertExceptionMsg(i = m.bar,
                       "    source/unit_threaded/mock.d:123 - bar was called");
}
732 
// A mutable virtual method of a class can be mocked, with expectations
// and return values.
@("issue 68")
@safe pure unittest {
    import unit_threaded.should;

    int fun(Class f) {
        // f.timesThreeMutable is mocked to return 42, no matter what's passed in
        return f.timesThreeMutable(2);
    }

    auto m = mock!Class;
    m.expect!"timesThreeMutable"(2);
    m.returnValue!("timesThreeMutable")(42);
    fun(m).shouldEqual(42);
}
747 
// Overloads that differ in arity, on both an interface and a class. Return
// values are queued per overload index.
@("issue69")
unittest {
    import unit_threaded.should;

    static interface InterfaceWithOverloadedFuncs {
        string over();
        string over(string str);
    }

    static class ClassWithOverloadedFuncs {
        string over() { return "oops"; }
        string over(string str) { return "oopsie"; }
    }

    auto iMock = mock!InterfaceWithOverloadedFuncs;
    iMock.returnValue!(0, "over")("bar");
    iMock.returnValue!(1, "over")("baz");
    iMock.over.shouldEqual("bar");
    iMock.over("zing").shouldEqual("baz");

    auto cMock = mock!ClassWithOverloadedFuncs;
    cMock.returnValue!(0, "over")("bar");
    cMock.returnValue!(1, "over")("baz");
    cMock.over.shouldEqual("bar");
    cMock.over("zing").shouldEqual("baz");
}