1 /++
2 A sum type for modern D.
3 
4 This module provides [SumType], an alternative to `std.variant.Algebraic` with
5 [match|improved pattern-matching], full attribute correctness (`pure`, `@safe`,
6 `@nogc`, and `nothrow` are inferred whenever possible), and no dependency on
7 runtime type information (`TypeInfo`).
8 
9 License: MIT
10 Authors: Paul Backus, Atila Neves
11 +/
12 module sumtype;
13 
/// $(H3 Basic usage)
@safe unittest {
    import std.math: approxEqual;

    struct Fahrenheit { double degrees; }
    struct Celsius { double degrees; }
    struct Kelvin { double degrees; }

    alias Temperature = SumType!(Fahrenheit, Celsius, Kelvin);

    // Construct from any of the member types.
    Temperature t1 = Fahrenheit(98.6);
    Temperature t2 = Celsius(100);
    Temperature t3 = Kelvin(273.15);

    // Use pattern matching to access the value.
    pure @safe @nogc nothrow
    Fahrenheit toFahrenheit(Temperature t)
    {
        return Fahrenheit(
            t.match!(
                (Fahrenheit f) => f.degrees,
                (Celsius c) => c.degrees * 9.0/5 + 32,
                (Kelvin k) => k.degrees * 9.0/5 - 459.67
            )
        );
    }

    assert(toFahrenheit(t1).degrees.approxEqual(98.6));
    assert(toFahrenheit(t2).degrees.approxEqual(212));
    assert(toFahrenheit(t3).degrees.approxEqual(32));

    // Use ref to modify the value in place.
    pure @safe @nogc nothrow
    void freeze(ref Temperature t)
    {
        t.match!(
            (ref Fahrenheit f) => f.degrees = 32,
            (ref Celsius c) => c.degrees = 0,
            (ref Kelvin k) => k.degrees = 273.15
        );
    }

    freeze(t1);
    assert(toFahrenheit(t1).degrees.approxEqual(32));

    // Use a catch-all handler to give a default result.
    pure @safe @nogc nothrow
    bool isFahrenheit(Temperature t)
    {
        return t.match!(
            (Fahrenheit f) => true,
            _ => false
        );
    }

    assert(isFahrenheit(t1));
    assert(!isFahrenheit(t2));
    assert(!isFahrenheit(t3));
}
74 
/** $(H3 Introspection-based matching)
 *
 * In the `length` and `horiz` functions below, the handlers for `match` do not
 * specify the types of their arguments. Instead, matching is done based on how
 * the argument is used in the body of the handler: any type with `x` and `y`
 * properties will be matched by the `rect` handlers, and any type with `r` and
 * `theta` properties will be matched by the `polar` handlers.
 */
@safe unittest {
    import std.math: approxEqual, cos, PI, sqrt;

    struct Rectangular { double x, y; }
    struct Polar { double r, theta; }
    alias Vector = SumType!(Rectangular, Polar);

    // Magnitude of a vector in either representation.
    pure @safe @nogc nothrow
    double length(Vector vec)
    {
        return vec.match!(
            rect => sqrt(rect.x^^2 + rect.y^^2),
            polar => polar.r
        );
    }

    // Horizontal (x) component of a vector in either representation.
    pure @safe @nogc nothrow
    double horiz(Vector vec)
    {
        return vec.match!(
            rect => rect.x,
            polar => polar.r * cos(polar.theta)
        );
    }

    Vector diagonal = Rectangular(1, 1);
    Vector tilted = Polar(1, PI/4);

    assert(length(diagonal).approxEqual(sqrt(2.0)));
    assert(length(tilted).approxEqual(1));
    assert(horiz(diagonal).approxEqual(1));
    assert(horiz(tilted).approxEqual(sqrt(0.5)));
}
116 
/** $(H3 Arithmetic expression evaluator)
 *
 * This example makes use of the special placeholder type `This` to define a
 * [https://en.wikipedia.org/wiki/Recursive_data_type|recursive data type]: an
 * [https://en.wikipedia.org/wiki/Abstract_syntax_tree|abstract syntax tree] for
 * representing simple arithmetic expressions.
 */
@safe unittest {
    import std.functional: partial;
    import std.traits: EnumMembers;
    import std.typecons: Tuple;

    enum Op : string
    {
        Plus  = "+",
        Minus = "-",
        Times = "*",
        Div   = "/"
    }

    // An expression is either
    //  - a number,
    //  - a variable, or
    //  - a binary operation combining two sub-expressions.
    alias Expr = SumType!(
        double,
        string,
        Tuple!(Op, "op", This*, "lhs", This*, "rhs")
    );

    // Shorthand for Tuple!(Op, "op", Expr*, "lhs", Expr*, "rhs"),
    // the Tuple type above with Expr substituted for This.
    alias BinOp = Expr.Types[2];

    // Factory function for number expressions
    pure @safe
    Expr* num(double value)
    {
        return new Expr(value);
    }

    // Factory function for variable expressions
    pure @safe
    Expr* var(string name)
    {
        return new Expr(name);
    }

    // Factory function for binary operation expressions
    pure @safe
    Expr* binOp(Op op, Expr* lhs, Expr* rhs)
    {
        return new Expr(BinOp(op, lhs, rhs));
    }

    // Convenience wrappers for creating BinOp expressions
    alias sum  = partial!(binOp, Op.Plus);
    alias diff = partial!(binOp, Op.Minus);
    alias prod = partial!(binOp, Op.Times);
    alias quot = partial!(binOp, Op.Div);

    // Evaluate expr, looking up variables in env
    pure @safe nothrow
    double eval(Expr expr, double[string] env)
    {
        return expr.match!(
            (double num) => num,
            (string var) => env[var],
            (BinOp bop) {
                double lhs = eval(*bop.lhs, env);
                double rhs = eval(*bop.rhs, env);
                final switch(bop.op) {
                    static foreach(op; EnumMembers!Op) {
                        case op:
                            return mixin("lhs" ~ op ~ "rhs");
                    }
                }
            }
        );
    }

    // Return a "pretty-printed" representation of expr
    @safe
    string pprint(Expr expr)
    {
        import std.format;

        return expr.match!(
            (double num) => "%g".format(num),
            (string var) => var,
            (BinOp bop) => "(%s %s %s)".format(
                pprint(*bop.lhs),
                // "%s" on an enum member prints its name ("Plus"), so cast
                // to the base string type to print the operator symbol ("+").
                cast(string) bop.op,
                pprint(*bop.rhs)
            )
        );
    }

    Expr* myExpr = sum(var("a"), prod(num(2), var("b")));
    double[string] myEnv = ["a":3, "b":4, "c":7];

    assert(eval(*myExpr, myEnv) == 11);
    assert(pprint(*myExpr) == "(a + (2 * b))");
}
221 
222 /// `This` placeholder, for use in self-referential types.
223 public import std.variant: This;
224 
225 import std.meta: NoDuplicates;
226 
/**
 * A tagged union that can hold a single value from any of a specified set of
 * types.
 *
 * The value in a `SumType` can be operated on using [match|pattern matching].
 *
 * The special type `This` can be used as a placeholder to create
 * self-referential types, just like with `Algebraic`. See the
 * [sumtype#arithmetic-expression-evaluator|"Arithmetic expression evaluator" example] for
 * usage.
 *
 * A `SumType` is initialized by default to hold the `.init` value of its
 * first member type, just like a regular union. The version identifier
 * `SumTypeNoDefaultCtor` can be used to disable this behavior.
 *
 * To avoid ambiguity, duplicate types are not allowed (but see the
 * [sumtype#basic-usage|"basic usage" example] for a workaround).
 *
 * Bugs:
 *   Types with `@disable`d `opEquals` overloads cannot be members of a
 *   `SumType`. Types whose `opEquals` cannot be called on `const` values are
 *   likewise unsupported, because `SumType.opEquals` compares the held values
 *   through `const` references (see dlang issue 19458).
 *
 * See_Also: `std.variant.Algebraic`
 */
struct SumType(TypeArgs...)
	if (is(NoDuplicates!TypeArgs == TypeArgs) && TypeArgs.length > 0)
{
	import std.meta: AliasSeq, Filter, anySatisfy, allSatisfy;
	import std.traits: hasElaborateCopyConstructor, hasElaborateDestructor, isAssignable, isCopyable;
	import std.typecons: ReplaceType;

	/// The types a `SumType` can hold.
	alias Types = AliasSeq!(ReplaceType!(This, typeof(this), TypeArgs));

private:

	// The tag stores an index into Types, so the largest value it ever has
	// to represent is Types.length - 1 (TypeArgs.length > 0 is guaranteed by
	// the template constraint, so this cannot underflow).
	enum bool canHoldTag(T) = Types.length - 1 <= T.max;
	alias unsignedInts = AliasSeq!(ubyte, ushort, uint, ulong);

	// The narrowest unsigned integer type able to hold every tag value.
	alias Tag = Filter!(canHoldTag, unsignedInts)[0];

	// Overlapping storage for every member type; `tag` records which member
	// is active. One constructor per member type lets a Storage be built
	// directly from a value of that type.
	union Storage
	{
		Types values;

		static foreach (i, T; Types) {
			@trusted
			this()(auto ref T val)
			{
				import std.functional: forward;

				// Non-copyable values can only be moved in.
				static if (isCopyable!T) {
					values[i] = val;
				} else {
					values[i] = forward!val;
				}
			}

			static if (isCopyable!T) {
				@trusted
				this()(auto ref const(T) val) const
				{
					values[i] = val;
				}

				@trusted
				this()(auto ref immutable(T) val) immutable
				{
					values[i] = val;
				}
			} else {
				@disable this(const(T) val) const;
				@disable this(immutable(T) val) immutable;
			}
		}
	}

	Tag tag;
	Storage storage;

	// Returns a reference to the stored value viewed as a T. Reading a union
	// member is only valid when it is the active one, hence @trusted plus the
	// tag check (active when assertions are enabled).
	@trusted
	ref inout(T) trustedGet(T)() inout
	{
		import std.meta: staticIndexOf;

		enum tid = staticIndexOf!(T, Types);
		assert(tag == tid);
		return storage.values[tid];
	}

public:

	static foreach (i, T; Types) {
		/// Constructs a `SumType` holding a specific value.
		this()(auto ref T val)
		{
			import std.functional: forward;

			static if (isCopyable!T) {
				storage = Storage(val);
			} else {
				storage = Storage(forward!val);
			}

			tag = i;
		}

		static if (isCopyable!T) {
			/// ditto
			this()(auto ref const(T) val) const
			{
				storage = const(Storage)(val);
				tag = i;
			}

			/// ditto
			this()(auto ref immutable(T) val) immutable
			{
				storage = immutable(Storage)(val);
				tag = i;
			}
		} else {
			@disable this(const(T) val) const;
			@disable this(immutable(T) val) immutable;
		}
	}

	version(SumTypeNoDefaultCtor) {
		@disable this();
	}

	static foreach (i, T; Types) {
		static if (isAssignable!T) {
			/// Assigns a value to a `SumType`.
			void opAssign()(auto ref T rhs)
			{
				import std.functional: forward;

				// Run the destructor of the currently held value (if its type
				// has one) before the storage is overwritten.
				this.match!((ref value) {
					static if (hasElaborateDestructor!(typeof(value))) {
						destroy(value);
					}
				});

				storage = Storage(forward!rhs);
				tag = i;
			}
		}
	}

	/**
	 * Compares two `SumType`s for equality.
	 *
	 * Two `SumType`s are equal if they are the same kind of `SumType`, they
	 * contain values of the same type, and those values are equal.
	 */
	bool opEquals(const SumType rhs) const {
		return this.match!((ref value) {
			return rhs.match!((ref rhsValue) {
				static if (is(typeof(value) == typeof(rhsValue))) {
					return value == rhsValue;
				} else {
					return false;
				}
			});
		});
	}

	// Workaround for dlang issue 19407
	static if (__traits(compiles, anySatisfy!(hasElaborateDestructor, Types))) {
		// If possible, include the destructor only when it's needed
		private enum includeDtor = anySatisfy!(hasElaborateDestructor, Types);
	} else {
		// If we can't tell, always include it, even when it does nothing
		private enum includeDtor = true;
	}

	static if (includeDtor) {
		/// Calls the destructor of the `SumType`'s current value.
		~this()
		{
			this.match!((ref value) {
				static if (hasElaborateDestructor!(typeof(value))) {
					destroy(value);
				}
			});
		}
	}

	static if (allSatisfy!(isCopyable, Types)) {
		static if (anySatisfy!(hasElaborateCopyConstructor, Types)) {
			/// Calls the postblit of the `SumType`'s current value.
			this(this)
			{
				this.match!((ref value) {
					static if (hasElaborateCopyConstructor!(typeof(value))) {
						value.__xpostblit;
					}
				});
			}
		}
	} else {
		// A SumType is copyable only if every member type is.
		@disable this(this);
	}

	static if (allSatisfy!(isCopyable, Types)) {
		/// Returns a string representation of a `SumType`'s value.
		string toString(this T)() {
			import std.conv: text;
			return this.match!((auto ref value) {
				return value.text;
			});
		}
	}
}
442 
// Construction
@safe unittest {
	alias IntOrFloat = SumType!(int, float);

	assert(__traits(compiles, IntOrFloat(42)));
	assert(__traits(compiles, IntOrFloat(3.14)));
}

// Assignment
@safe unittest {
	alias IntOrFloat = SumType!(int, float);

	auto target = IntOrFloat(42);

	assert(__traits(compiles, target = 3.14));
}

// Self assignment
@safe unittest {
	alias IntOrFloat = SumType!(int, float);

	auto source = IntOrFloat(42);
	auto target = IntOrFloat(3.14);

	assert(__traits(compiles, target = source));
}
469 
// Equality
@safe unittest {
	alias IntOrFloat = SumType!(int, float);

	auto int123 = IntOrFloat(123);
	auto int123Again = IntOrFloat(123);
	auto int456 = IntOrFloat(456);
	auto float123 = IntOrFloat(123.0);
	auto float456 = IntOrFloat(456.0);

	// Same member type, same value
	assert(int123 == int123Again);
	// Same member type, different value
	assert(int123 != int456);
	// Different member types never compare equal
	assert(int123 != float123);
	assert(int123 != float456);
}

// Imported types
@safe unittest {
	import std.typecons: Tuple;

	assert(__traits(compiles, {
		alias Pair = SumType!(Tuple!(int, int));
	}));
}

// const and immutable types
@safe unittest {
	assert(__traits(compiles, {
		alias Qualified = SumType!(const(int[]), immutable(float[]));
	}));
}

// Recursive types
@safe unittest {
	alias Link = SumType!(This*);

	assert(is(Link.Types[0] == Link*));
}

// Allowed types
@safe unittest {
	import std.meta: AliasSeq;

	alias Mixed = SumType!(int, float, This*);

	assert(is(Mixed.Types == AliasSeq!(int, float, Mixed*)));
}

// Works alongside Algebraic
@safe unittest {
	import std.variant;

	alias Alg = Algebraic!(This*);

	assert(is(Alg.AllowedTypes[0] == Alg*));
}
525 
// Types with destructors and postblits
//
// `copies` counts live copies of an initialized Test: the postblit increments
// it and the destructor decrements it, so it must be back to 0 whenever no
// SumType holds a copy.
@safe unittest {
	int copies;

	struct Test
	{
		bool initialized = false;

		this(this) { copies++; }
		~this() { if (initialized) copies--; }
	}

	alias MySum = SumType!(int, Test);

	Test t = Test(true);

	// Construct from a Test, then let the SumType go out of scope.
	{
		MySum x = t;
		assert(copies == 1);
	}
	assert(copies == 0);

	// Holding an int never touches the counter.
	{
		MySum x = 456;
		assert(copies == 0);
	}
	assert(copies == 0);

	// Assigning an int over a Test destroys the held Test.
	{
		MySum x = t;
		assert(copies == 1);
		x = 456;
		assert(copies == 0);
	}
	assert(copies == 0);

	// Assigning a Test over an int copies it; leaving scope destroys it.
	{
		MySum x = 456;
		assert(copies == 0);
		x = t;
		assert(copies == 1);
	}
	assert(copies == 0);
}
568 
// Doesn't destroy reference types
@safe unittest {
	bool dtorRan;

	class C
	{
		~this()
		{
			dtorRan = true;
		}
	}

	struct S
	{
		~this() {}
	}

	alias MySum = SumType!(S, C);

	C obj = new C();

	// Dropping a SumType that holds a class reference must not destroy the object.
	{
		MySum holder = obj;
		dtorRan = false;
	}
	assert(!dtorRan);

	// Neither must overwriting that reference with another member.
	{
		MySum holder = obj;
		dtorRan = false;
		holder = S();
		assert(!dtorRan);
	}
}

// Types with @disable this()
@safe unittest {
	struct NoDefault
	{
		@disable this();
	}

	alias MySum = SumType!(NoDefault, int);

	// The first member cannot be default-constructed, so neither can the SumType.
	assert(!__traits(compiles, MySum()));
	assert(__traits(compiles, MySum(42)));
}

// const SumTypes
@safe unittest {
	alias ConstArraySum = const(SumType!(int[]));

	assert(__traits(compiles,
		ConstArraySum([1, 2, 3])
	));
}
622 
// Equality of const SumTypes
@safe unittest {
	alias IntSum = SumType!int;

	assert(__traits(compiles,
		const(IntSum)(123) == const(IntSum)(456)
	));
}

// Compares reference types using value equality
@safe unittest {
	struct Field {}
	struct Struct { Field[] fields; }
	alias MySum = SumType!Struct;

	// Two separately allocated arrays with equal contents
	auto lhs = MySum(Struct([Field()]));
	auto rhs = MySum(Struct([Field()]));

	assert(lhs == rhs);
}

// toString
@safe unittest {
	import std.conv: text;

	static struct Int { int i; }
	static struct Double { double d; }
	alias Sum = SumType!(Int, Double);

	auto heldInt = Sum(Int(42));
	auto heldDouble = Sum(Double(33.3));
	auto heldConstInt = const(Sum)(Int(42));

	assert(heldInt.text == Int(42).text, heldInt.text);
	assert(heldDouble.text == Double(33.3).text, heldDouble.text);
	assert(heldConstInt.text == (const(Int)(42)).text, heldConstInt.text);
}

// Github issue #16
@safe unittest {
	alias Node = SumType!(This[], string);

	// The recursive opEquals instantiations are inferred @system, so the
	// comparison is performed inside a @trusted lambda.
	immutable same = (() @trusted =>
		Node([Node([Node("x")])]) == Node([Node([Node("x")])])
	)();
	assert(same);
}

// Github issue #16 with const
@safe unittest {
	alias Node = SumType!(const(This)[], string);

	// The recursive opEquals instantiations are inferred @system, so the
	// comparison is performed inside a @trusted lambda.
	immutable same = (() @trusted =>
		Node([Node([Node("x")])]) == Node([Node([Node("x")])])
	)();
	assert(same);
}
680 
// Stale pointers
//
// Checks that assigning a smaller member overwrites pointer bits left behind
// by a larger one. NOTE(review): disabled via version(none); presumably this
// does not hold yet because assignment only writes the new member's bytes.
version(none) {
	@system unittest {
		alias MySum = SumType!(ubyte, void*[2]);

		MySum x = [null, cast(void*) 0x12345678];
		void** p = &x.trustedGet!(void*[2])[1];
		x = ubyte(123);

		assert(*p != cast(void*) 0x12345678);
	}
}
695 
// Exception-safe assignment
@safe unittest {
	struct A
	{
		int value = 123;
	}

	struct B
	{
		int value = 456;
		this(this) { throw new Exception("oops"); }
	}

	alias MySum = SumType!(A, B);

	MySum x;
	try {
		x = B();
	} catch (Exception) {}

	// Whether or not the copy threw, x must hold a valid, intact value.
	immutable keptA = x.tag == 0 && x.trustedGet!A.value == 123;
	immutable gotB = x.tag == 1 && x.trustedGet!B.value == 456;
	assert(keptA || gotB);
}

// Types with @disable this(this)
@safe unittest {
	import std.algorithm.mutation: move;

	struct NoCopy
	{
		@disable this(this);
	}

	alias MySum = SumType!NoCopy;

	NoCopy named = NoCopy();

	MySum source = NoCopy();
	MySum target = NoCopy();

	// Rvalues can be moved in; lvalues would need a copy.
	assert(__traits(compiles, SumType!NoCopy(NoCopy())));
	assert(!__traits(compiles, SumType!NoCopy(named)));

	// The same rule applies to assignment, from members and from SumTypes.
	assert(__traits(compiles, target = NoCopy()));
	assert(__traits(compiles, target = move(source)));
	assert(!__traits(compiles, target = named));
	assert(!__traits(compiles, target = source));
}

// Github issue #22
@safe unittest {
	import std.typecons;

	assert(__traits(compiles, {
		static struct A {
			SumType!(Nullable!int) a = Nullable!int.init;
		}
	}));
}
756 
// Known bugs; both need a fix for dlang issue 19458.
version(none) {
	// Types with disabled opEquals
	@safe unittest {
		struct S
		{
			@disable bool opEquals(const S rhs) const;
		}

		assert(__traits(compiles, SumType!S(S())));
	}

	// Types with a non-const opEquals
	@safe unittest {
		struct S
		{
			int i;
			bool opEquals(S rhs) { return i == rhs.i; }
		}

		assert(__traits(compiles, SumType!S(S(123))));
	}
}
782 
/**
 * Calls a type-appropriate function with the value held in a [SumType].
 *
 * For each possible type the [SumType] can hold, the given handlers are
 * checked, in order, to see whether they accept a single argument of that type.
 * The first one that does is chosen as the match for that type.
 *
 * Implicit conversions are not taken into account, except between
 * differently-qualified versions of the same type. For example, a handler that
 * accepts a `long` will not match the type `int`, but a handler that accepts a
 * `const(int)[]` will match the type `immutable(int)[]`.
 *
 * Every type must have a matching handler, and every handler must match at
 * least one type. This is enforced at compile time.
 *
 * Handlers may be functions, delegates, or objects with opCall overloads. If a
 * function with more than one overload is given as a handler, all of the
 * overloads are considered as potential matches.
 *
 * Templated handlers are also accepted, and will match any type for which they
 * can be [implicitly instantiated](https://dlang.org/glossary.html#ifti). See
 * [sumtype#introspection-based-matching|"Introspection-based matching"] for an
 * example of templated handler usage.
 *
 * Returns:
 *   The value returned from the handler that matches the currently-held type.
 *
 * See_Also: `std.variant.visit`
 */
template match(handlers...)
{
	import std.typecons: Yes;

	/**
	 * The actual `match` function.
	 *
	 * Params:
	 *   self = A [SumType] object
	 */
	auto match(Self)(auto ref Self self)
		if (is(Self : SumType!TypeArgs, TypeArgs...))
	{
		// Exhaustive mode: a member type with no handler is a compile error.
		return matchImpl!(Yes.exhaustive, handlers)(self);
	}
}
828 
/**
 * Attempts to call a type-appropriate function with the value held in a
 * [SumType], and throws on failure.
 *
 * Matches are chosen using the same rules as [match], but are not required to
 * be exhaustive—in other words, a type is allowed to have no matching handler.
 * If a type without a handler is encountered at runtime, a [MatchException]
 * is thrown.
 *
 * Returns:
 *   The value returned from the handler that matches the currently-held type,
 *   if a handler was given for that type.
 *
 * Throws:
 *   [MatchException], if the currently-held type has no matching handler.
 *
 * See_Also: `std.variant.tryVisit`
 */
template tryMatch(handlers...)
{
	import std.typecons: No;

	/**
	 * The actual `tryMatch` function.
	 *
	 * Params:
	 *   self = A [SumType] object
	 */
	auto tryMatch(Self)(auto ref Self self)
		if (is(Self : SumType!TypeArgs, TypeArgs...))
	{
		// Non-exhaustive mode: an unhandled member type throws at runtime.
		return matchImpl!(No.exhaustive, handlers)(self);
	}
}
863 
/// Thrown by [tryMatch] when an unhandled type is encountered.
class MatchException : Exception
{
	/**
	 * Params:
	 *   msg  = the error message
	 *   file = source file recorded in the exception (defaults to the construction site)
	 *   line = source line recorded in the exception (defaults to the construction site)
	 */
	this(string msg, string file = __FILE__, size_t line = __LINE__)
		pure nothrow @nogc @safe
	{
		super(msg, file, line);
	}
}
873 
/**
 * Checks whether a handler can match a given type.
 *
 * See the documentation for [match] for a full explanation of how matches are
 * chosen.
 */
template canMatch(alias handler, T)
{
	private bool canMatchImpl()
	{
		import std.traits: hasMember, isCallable, isSomeFunction, Parameters;

		// Include overloads even when called from outside of matchImpl
		alias realHandler = handlerWithOverloads!handler;

		// immutable recursively overrides all other qualifiers, so the
		// right-hand side is true if and only if the two types are the
		// same when qualifiers are ignored.
		enum sameUpToQuals(A, B) = is(immutable(A) == immutable(B));

		bool result = false;

		// First requirement: the handler accepts a single argument of type T.
		static if (is(typeof((T arg) { realHandler(arg); }(T.init)))) {
			// Regular handlers
			static if (isCallable!realHandler) {
				// Functions and delegates
				static if (isSomeFunction!realHandler) {
					// A `(...)` variadic function accepts one argument but has
					// an empty parameter list; don't index into it.
					static if (Parameters!realHandler.length > 0) {
						static if (sameUpToQuals!(T, Parameters!realHandler[0])) {
							result = true;
						}
					}
				// Objects with overloaded opCall
				} else static if (hasMember!(typeof(realHandler), "opCall")) {
					static foreach (overload; __traits(getOverloads, typeof(realHandler), "opCall")) {
						// Other overloads may take no parameters (e.g. a nullary
						// opCall next to a unary one); they can never match T,
						// and indexing their empty parameter list would not compile.
						static if (Parameters!overload.length > 0) {
							static if (sameUpToQuals!(T, Parameters!overload[0])) {
								result = true;
							}
						}
					}
				}
			// Generic handlers
			} else {
				result = true;
			}
		}

		return result;
	}

	/// True if `handler` is a potential match for `T`, otherwise false.
	enum bool canMatch = canMatchImpl;
}
924 
// Includes all overloads of the given handler
@safe unittest {
	static struct Overloaded
	{
		static void handle(int n) {}
		static void handle(double d) {}
	}

	// Both overloads must be visible through the single alias.
	assert(canMatch!(Overloaded.handle, int));
	assert(canMatch!(Overloaded.handle, double));
}
936 
937 import std.traits: isFunction;
938 
// Every overload of `fun` (all functions with the same name in the same
// parent scope), as an AliasSeq.
private template FunctionOverloads(alias fun)
	if (isFunction!fun)
{
	import std.meta: AliasSeq;

	enum name = __traits(identifier, fun);

	alias FunctionOverloads = AliasSeq!(
		__traits(getOverloads, __traits(parent, fun), name)
	);
}
952 
// A single handler value whose opCall overload set mirrors the overloads of
// `fun`: each opCall forwards its arguments to the corresponding overload.
private template overloadHandler(alias fun)
	if (isFunction!fun)
{
	struct OverloadHandler
	{
		import std.traits: Parameters, ReturnType;

		static foreach (target; FunctionOverloads!fun) {
			ReturnType!target opCall(Parameters!target params)
			{
				return target(params);
			}
		}
	}

	enum overloadHandler = OverloadHandler.init;
}
971 
// Resolves to `handler` itself unless it is a function with several
// overloads, in which case it resolves to a handler covering all of them.
private template handlerWithOverloads(alias handler)
{
	// Delegates and function pointers can't have overloads, so anything that
	// is not a function (or has a single overload) is used as-is.
	static if (!isFunction!handler || FunctionOverloads!handler.length <= 1) {
		alias handlerWithOverloads = handler;
	} else {
		alias handlerWithOverloads = overloadHandler!handler;
	}
}
982 
983 import std.typecons: Flag;
984 
private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
{
	/*
	 * Dispatches on `self.tag` to the handler selected for the active member.
	 *
	 * For each member type, the first handler (in the order given, with
	 * function overload sets expanded by handlerWithOverloads) for which
	 * `canMatch` holds is selected at compile time. When a member type has no
	 * handler, compilation fails if `exhaustive` is set; otherwise a
	 * MatchException is thrown if that member is active at runtime. A handler
	 * that is not selected for any member type is a compile-time error.
	 */
	auto matchImpl(Self)(auto ref Self self)
		if (is(Self : SumType!TypeArgs, TypeArgs...))
	{
		import std.algorithm.searching: canFind;
		import std.meta: staticMap;

		alias Types = self.Types;
		enum noMatch = size_t.max;

		alias allHandlers = staticMap!(handlerWithOverloads, handlers);

		// Entry tid is the index in allHandlers of the first handler that can
		// match Types[tid], or noMatch if there is none.
		pure size_t[Types.length] getHandlerIndices()
		{
			size_t[Types.length] indices;
			indices[] = noMatch;

			static foreach (tid, T; Types) {
				static foreach (hid, handler; allHandlers) {
					// Use the member type as seen through `self` (trustedGet is
					// inout), so the SumType's own qualifiers are respected.
					static if (canMatch!(handler, typeof(self.trustedGet!T()))) {
						if (indices[tid] == noMatch) {
							indices[tid] = hid;
						}
					}
				}
			}

			return indices;
		}

		enum handlerIndices = getHandlerIndices;

		// Check for unreachable handlers. Done before dispatch so these
		// compile-time checks don't sit after the final `assert(false)`.
		static foreach (hid, handler; allHandlers) {
			static assert(handlerIndices[].canFind(hid),
				"handler `" ~ __traits(identifier, handler) ~ "` " ~
				"of type `" ~ typeof(handler).stringof ~ "` " ~
				"never matches"
			);
		}

		final switch (self.tag) {
			static foreach (tid, T; Types) {
				case tid:
					static if (handlerIndices[tid] != noMatch) {
						return allHandlers[handlerIndices[tid]](self.trustedGet!T);
					} else {
						static if(exhaustive) {
							static assert(false,
								"No matching handler for type `" ~ T.stringof ~ "`");
						} else {
							throw new MatchException(
								"No matching handler for type `" ~ T.stringof ~ "`");
						}
					}
			}
		}

		assert(false); // unreached: every case above returns or throws
	}
}
1048 
// Matching
@safe unittest {
	alias MySum = SumType!(int, float);

	auto holdsInt = MySum(42);
	auto holdsFloat = MySum(3.14);

	assert(holdsInt.match!((int v) => true, (float v) => false));
	assert(holdsFloat.match!((int v) => false, (float v) => true));
}

// Missing handlers
@safe unittest {
	alias MySum = SumType!(int, float);

	auto holdsInt = MySum(42);

	// float has no handler in either call.
	assert(!__traits(compiles, holdsInt.match!((int n) => true)));
	assert(!__traits(compiles, holdsInt.match!()));
}

// No implicit conversion
@safe unittest {
	alias MySum = SumType!(int, float);

	auto holdsInt = MySum(42);

	// A long handler must not pick up int values.
	assert(!__traits(compiles,
		holdsInt.match!((long v) => true, (float v) => false)
	));
}
1080 
// Handlers with qualified parameters
@safe unittest {
	alias MySum = SumType!(int[], float[]);

	auto ints = MySum([1, 2, 3]);
	auto floats = MySum([1.0, 2.0, 3.0]);

	assert(ints.match!((const(int[]) v) => true, (const(float[]) v) => false));
	assert(floats.match!((const(int[]) v) => false, (const(float[]) v) => true));
}

// Handlers for qualified types
@safe unittest {
	alias MySum = SumType!(immutable(int[]), immutable(float[]));

	auto ints = MySum([1, 2, 3]);

	// Fully qualified parameters
	assert(ints.match!((immutable(int[]) v) => true, (immutable(float[]) v) => false));
	assert(ints.match!((const(int[]) v) => true, (const(float[]) v) => false));
	// Tail-qualified parameters
	assert(ints.match!((immutable(int)[] v) => true, (immutable(float)[] v) => false));
	assert(ints.match!((const(int)[] v) => true, (const(float)[] v) => false));
	// Generic parameters
	assert(ints.match!((immutable v) => true));
	assert(ints.match!((const v) => true));
	// Unqualified parameters would drop immutability, so they must not match
	assert(!__traits(compiles,
		ints.match!((int[] v) => true, (float[] v) => false)
	));
}
1111 
// Delegate handlers
@safe unittest {
	alias MySum = SumType!(int, float);

	int answer = 42;
	auto holdsInt = MySum(42);
	auto holdsFloat = MySum(3.14);

	// The handlers capture `answer`, so they are delegates.
	assert(holdsInt.match!((int v) => v == answer, (float v) => v == answer));
	assert(!holdsFloat.match!((int v) => v == answer, (float v) => v == answer));
}

// Generic handler
@safe unittest {
	import std.math: approxEqual;

	alias MySum = SumType!(int, float);

	auto holdsInt = MySum(42);
	auto holdsFloat = MySum(3.14);

	assert(holdsInt.match!(v => v*2) == 84);
	assert(holdsFloat.match!(v => v*2).approxEqual(6.28));
}

// Fallback to generic handler
@safe unittest {
	import std.conv: to;

	alias MySum = SumType!(int, float, string);

	auto holdsInt = MySum(42);
	auto holdsString = MySum("42");

	// The string handler comes first; everything else falls back to v*2.
	assert(holdsInt.match!((string v) => v.to!int, v => v*2) == 84);
	assert(holdsString.match!((string v) => v.to!int, v => v*2) == 42);
}
1149 
// Multiple non-overlapping generic handlers
@safe unittest {
	import std.math: approxEqual;

	alias MySum = SumType!(int, float, int[], char[]);

	auto holdsInt = MySum(42);
	auto holdsFloat = MySum(3.14);
	auto holdsInts = MySum([1, 2, 3]);
	auto holdsChars = MySum(['a', 'b', 'c']);

	// Numbers compile with v*2; arrays only with v.length.
	assert(holdsInt.match!(v => v*2, v => v.length) == 84);
	assert(holdsFloat.match!(v => v*2, v => v.length).approxEqual(6.28));
	assert(holdsChars.match!(v => v*2, v => v.length) == 3);
	assert(holdsInts.match!(v => v*2, v => v.length) == 3);
}

// Structural matching
@safe unittest {
	struct S1 { int x; }
	struct S2 { int y; }
	alias MySum = SumType!(S1, S2);

	auto first = MySum(S1(0));
	auto second = MySum(S2(0));

	// Each handler only compiles for the type with the field it uses.
	assert(first.match!(hasX => hasX.x + 1, hasY => hasY.y - 1) == 1);
	assert(second.match!(hasX => hasX.x + 1, hasY => hasY.y - 1) == -1);
}
1179 
// Separate opCall handlers: callable struct instances can be used as
// handlers, and each struct below accepts a single member type.
@safe unittest {
	struct IntHandler
	{
		bool opCall(int) { return true; }
	}

	struct FloatHandler
	{
		bool opCall(float) { return false; }
	}

	alias MySum = SumType!(int, float);

	IntHandler onInt;
	FloatHandler onFloat;
	auto x = MySum(42);
	auto y = MySum(3.14);

	assert(x.match!(onInt, onFloat));
	assert(!y.match!(onInt, onFloat));
}
1208 
// Compound opCall handler: one callable struct whose opCall overloads cover
// every member type is enough on its own.
@safe unittest {
	struct CompoundHandler
	{
		bool opCall(int) { return true; }
		bool opCall(float) { return false; }
	}

	alias MySum = SumType!(int, float);

	CompoundHandler both;
	auto x = MySum(42);
	auto y = MySum(3.14);

	assert(x.match!both);
	assert(!y.match!both);
}
1233 
// Ordered matching: when more than one handler accepts a value, the one
// listed first is used.
@safe unittest {
	alias MySum = SumType!(int, float);

	auto x = MySum(42);

	// Both handlers accept an `int`; the earlier, `int`-specific one wins.
	immutable matchedSpecific = x.match!((int v) => true, v => false);
	assert(matchedSpecific);
}
1242 
// Non-exhaustive matching: tryMatch accepts handlers that do not cover every
// member type, and throws MatchException when the held value has no handler.
@system unittest {
	import std.exception: assertThrown, assertNotThrown;

	alias MySum = SumType!(int, float);

	auto handled = MySum(42);
	auto unhandled = MySum(3.14);

	assertNotThrown!MatchException(handled.tryMatch!((int n) => true));
	assertThrown!MatchException(unhandled.tryMatch!((int n) => true));
}
1255 
// Non-exhaustive matching in @safe code: a tryMatch call must be accepted
// inside a @safe function.
@safe unittest {
	SumType!(int, float) value;

	// The handler list deliberately leaves `float` uncovered.
	assert(__traits(compiles, value.tryMatch!((int n) => n + 1)));
}
1267 
// Handlers with ref parameters: a handler that takes its argument by `ref`
// modifies the value stored inside the SumType in place.
@safe unittest {
	import std.math: approxEqual;

	alias Value = SumType!(long, double);

	auto value = Value(3.14);

	value.match!(
		(long) {},
		// Doubles the stored double through the reference.
		(ref double d) { d *= 2; }
	);

	// NOTE(review): trustedGet is defined elsewhere in this module; it is
	// presumably an unchecked accessor, so the caller must know the held type.
	assert(value.trustedGet!double.approxEqual(6.28));
}
1284 
// Unreachable handlers: `match` must be rejected at compile time when one of
// its handlers can never be selected for any member type.
@safe unittest {
	alias MySum = SumType!(int, string);

	MySum s;

	// Every member type is already claimed by an earlier handler, so the
	// `double` handler is never selected.
	assert(!__traits(compiles,
		s.match!(
			(int _) => 0,
			(string _) => 1,
			(double _) => 2
		)
	));

	// The generic handler comes first and accepts every member type, so the
	// `int` handler after it is never reached.
	assert(!__traits(compiles,
		s.match!(
			_ => 0,
			(int _) => 1
		)
	));
}
1306 
// Unsafe handlers: a `match` whose handler returns the address of its `ref`
// parameter must be rejected in @safe code but accepted in @system code.
unittest {
	SumType!(int, char*) x;

	// The `&n` handler is the only unsafe operation here; inside a @safe
	// lambda the whole call must fail to compile.
	assert(!__traits(compiles, () @safe {
		x.match!(
			(ref int n) => &n,
			_ => null,
		);
	}));

	// The same handlers compile when the enclosing lambda is @system.
	assert(__traits(compiles, () @system {
		return x.match!(
			(ref int n) => &n,
			_ => null
		);
	}));
}
1325 
// Overloaded handlers: passing an overload set to `match` dispatches each
// member type to the overload that accepts it.
@safe unittest {
	static struct OverloadSet
	{
		static string fun(int) { return "int"; }
		static string fun(double) { return "double"; }
	}

	alias MySum = SumType!(int, double);

	MySum fromInt = 42;
	MySum fromDouble = 3.14;

	assert(fromInt.match!(OverloadSet.fun) == "int");
	assert(fromDouble.match!(OverloadSet.fun) == "double");
}
1342 
// Overload sets that include SumType arguments: the overloads taking a
// SumType forward back into the same overload set via `match`, so a nested
// SumType is unwrapped down to its innermost value.
@safe unittest {
	alias Inner = SumType!(int, double);
	alias Outer = SumType!(Inner, string);

	static struct OverloadSet
	{
		// Marked explicitly; presumably inference cannot see through the
		// overloads calling each other recursively.
		@safe:
		static string fun(int i) { return "int"; }
		static string fun(double d) { return "double"; }
		static string fun(string s) { return "string"; }
		static string fun(Inner inner) { return inner.match!fun; }
		static string fun(Outer outer) { return outer.match!fun; }
	}

	Outer wrappedInt = Inner(42);
	Outer wrappedDouble = Inner(3.14);
	Outer plainString = "foo";

	assert(OverloadSet.fun(wrappedInt) == "int");
	assert(OverloadSet.fun(wrappedDouble) == "double");
	assert(OverloadSet.fun(plainString) == "string");
}