1 module taggedalgebraic.visit;
2 
3 import taggedalgebraic.taggedalgebraic;
4 import taggedalgebraic.taggedunion;
5 
6 import std.meta : anySatisfy;
7 import std.traits : EnumMembers, isInstanceOf;
8 
/** Dispatches the value contained in a `TaggedUnion` or `TaggedAlgebraic` to a
	set of visitors.

	A visitor can have one of three forms:

	$(UL
		$(LI function or delegate taking a single typed parameter)
		$(LI function or delegate taking no parameters)
		$(LI function or delegate template taking any single parameter)
	)

	For every field of the union a visitor is selected as follows: fields of
	a unit type (such as `Void`) are passed to the visitor taking no
	parameters; any other field is passed to the typed visitor whose
	parameter type is exactly the field's type or, if there is none, to the
	generic (templated) visitor.

	All checks happen at compile time. The call does not compile if some
	field type is left unhandled, if two visitors conflict for the same type,
	or if a visitor does not match any of the field types. See `tryVisit` for
	a variant that reports unhandled types at runtime instead.

	Returns:
		The value returned by the selected visitor. Since all cases return
		from the same `auto` function, the visitors must have compatible
		return types.
*/
template visit(VISITORS...) if (VISITORS.length > 0)
{
	auto visit(TU)(auto ref TU tu) if (isInstanceOf!(TaggedUnion, TU))
	{
		// Instantiated only for the static assertions it performs on the
		// visitor set as a whole.
		alias val = validateHandlers!(TU, VISITORS);

		final switch (tu.kind)
		{
			static foreach (k; EnumMembers!(TU.Kind))
			{
		case k:
				{
					// Unit-typed fields are represented as `void`, which makes
					// selectHandler pick a parameterless visitor for them.
					static if (isUnitType!(TU.FieldTypes[k]))
						alias T = void;
					else
						alias T = TU.FieldTypes[k];
					alias h = selectHandler!(T, VISITORS);
					static if (is(typeof(h) == typeof(null)))
						static assert(false, "No visitor defined for type " ~ T.stringof);
					else static if (is(typeof(h) == string))
						static assert(false, h); // invalid visitor set; `h` holds the reason
					else static if (is(T == void))
						return h();
					else
						return h(tu.value!k);
				}
			}
		}
	}

	/// ditto
	auto visit(U)(auto ref TaggedAlgebraic!U ta)
	{
		// Forward to the TaggedUnion held by the TaggedAlgebraic.
		return visit(ta.get!(TaggedUnion!U));
	}
}
57 
58 ///
unittest
{
	static if (__VERSION__ >= 2081)
	{
		import std.conv : to;

		union U
		{
			int number;
			string text;
		}

		alias TU = TaggedUnion!U;

		// One typed visitor per field type.
		auto tagged = TU.number(42);
		tagged.visit!(
			(int n) { assert(n == 42); },
			(string s) { assert(false); }
		);

		// A single generic visitor covering every field type.
		assert(tagged.visit!((v) => to!int(v)) == 42);

		tagged.setText("43");
		assert(tagged.visit!((v) => to!int(v)) == 43);
	}
}
83 
unittest
{
	// repeat test from TaggedUnion, this time visiting a TaggedAlgebraic
	union U
	{
		Void none;
		int count;
		float length;
	}

	TaggedAlgebraic!U ta;

	// accepted visitor sets
	static assert(is(typeof(ta.visit!((int) {}, (float) {}, () {}))));
	static assert(is(typeof(ta.visit!((_) {}, () {}))));
	static assert(is(typeof(ta.visit!((_) {}, (float) {}, () {}))));
	static assert(is(typeof(ta.visit!((float) {}, (_) {}, () {}))));

	// incomplete visitor sets
	static assert(!is(typeof(ta.visit!((_) {})))); // missing void handler
	static assert(!is(typeof(ta.visit!(() {})))); // missing value handler

	// Note: the parameter must be named here. A lone identifier such as
	// `(string) {}` declares a generic lambda whose parameter is called
	// `string`, which would test a duplicate generic handler instead.
	static assert(!is(typeof(ta.visit!((_) {}, () {}, (string s) {})))); // invalid typed handler
	static assert(!is(typeof(ta.visit!((int) {}, (float) {}, () {}, () {})))); // duplicate void handler
	static assert(!is(typeof(ta.visit!((_) {}, () {}, (_) {})))); // duplicate generic handler
	static assert(!is(typeof(ta.visit!((int) {}, (float) {}, (float) {}, () {})))); // duplicate typed handler

	// TODO: error out for superfluous generic handlers
	//static assert(!is(typeof(ta.visit!((int) {}, (float) {}, () {}, (_) {})))); // superfluous generic handler
}
113 
unittest
{
	union U
	{
		Void none;
		int count;
		float length;
	}

	TaggedUnion!U tu;

	// accepted visitor sets
	static assert(is(typeof(tu.visit!((int) {}, (float) {}, () {}))));
	static assert(is(typeof(tu.visit!((_) {}, () {}))));
	static assert(is(typeof(tu.visit!((_) {}, (float) {}, () {}))));
	static assert(is(typeof(tu.visit!((float) {}, (_) {}, () {}))));

	// incomplete visitor sets
	static assert(!is(typeof(tu.visit!((_) {})))); // missing void handler
	static assert(!is(typeof(tu.visit!(() {})))); // missing value handler

	// Note: the parameter must be named here. A lone identifier such as
	// `(string) {}` declares a generic lambda whose parameter is called
	// `string`, which would test a duplicate generic handler instead.
	static assert(!is(typeof(tu.visit!((_) {}, () {}, (string s) {})))); // invalid typed handler
	static assert(!is(typeof(tu.visit!((int) {}, (float) {}, () {}, () {})))); // duplicate void handler
	static assert(!is(typeof(tu.visit!((_) {}, () {}, (_) {})))); // duplicate generic handler
	static assert(!is(typeof(tu.visit!((int) {}, (float) {}, (float) {}, () {})))); // duplicate typed handler

	// TODO: error out for superfluous generic handlers
	//static assert(!is(typeof(tu.visit!((int) {}, (float) {}, () {}, (_) {})))); // superfluous generic handler
}
142 
unittest
{
	// The generic visitor below only compiles for `int`. It must never be
	// selected for `C`, which is already covered by a typed visitor.
	class C
	{
	}

	union U
	{
		int i;
		C c;
	}

	TaggedUnion!U tagged;
	tagged.visit!(
		(C obj) => obj !is null,
		(other) {
			static assert(is(typeof(other) == int));
			return other != 0;
		}
	);
}
163 
/** Variant of `visit` that reports unhandled types at runtime.

	Where `visit` refuses to compile if some field type has no matching
	visitor, `tryVisit` compiles anyway and throws an `Exception` when the
	value held by `tu` turns out to be of such a type.
*/
template tryVisit(VISITORS...) if (VISITORS.length > 0)
{
	auto tryVisit(TU)(auto ref TU tu) if (isInstanceOf!(TaggedUnion, TU))
	{
		final switch (tu.kind)
		{
			static foreach (fieldKind; EnumMembers!(TU.Kind))
			{
		case fieldKind:
				{
					// Unit-typed fields are handed to selectHandler as `void`
					// so that a parameterless visitor is chosen for them.
					static if (!isUnitType!(TU.FieldTypes[fieldKind]))
						alias FieldT = TU.FieldTypes[fieldKind];
					else
						alias FieldT = void;

					alias handler = selectHandler!(FieldT, VISITORS);

					static if (is(typeof(handler) == string))
						static assert(false, handler); // malformed visitor set
					else static if (is(typeof(handler) == typeof(null)))
						throw new Exception("Type " ~ FieldT.stringof ~ " not handled by any visitor.");
					else static if (!is(FieldT == void))
						return handler(tu.value!fieldKind);
					else
						return handler();
				}
			}
		}
	}

	/// ditto
	auto tryVisit(U)(auto ref TaggedAlgebraic!U ta)
	{
		// Delegate to the TaggedUnion held by the TaggedAlgebraic.
		return tryVisit(ta.get!(TaggedUnion!U));
	}
}
202 
203 ///
unittest
{
	import std.exception : assertThrown;

	union U
	{
		int number;
		string text;
	}

	alias TU = TaggedUnion!U;

	auto tagged = TU.number(42);

	// The stored `int` has a matching visitor...
	tagged.tryVisit!((int n) { assert(n == 42); });

	// ...but with only a `string` visitor the call throws at runtime.
	assertThrown(tagged.tryVisit!((string s) { assert(false); }));
}
220 
221 // repeat from TaggedUnion
unittest
{
	import std.exception : assertThrown;

	union U
	{
		int number;
		string text;
	}

	alias TA = TaggedAlgebraic!U;

	auto algebraic = TA(42);

	// handled type: the visitor runs
	algebraic.tryVisit!((int n) { assert(n == 42); });

	// unhandled type: an exception is thrown instead of a compile error
	assertThrown(algebraic.tryVisit!((string s) { assert(false); }));
}
238 
/* Compile-time checks on a complete visitor set for the union type `TU`.

	Each entry of `VISITORS` that is a type must be a function type, and every
	visitor must match at least one of `TU.FieldTypes` (see `matchesType`).
	Violations are reported through `static assert`.
*/
private template validateHandlers(TU, VISITORS...)
{
	import std.traits : isSomeFunction;

	alias Types = TU.FieldTypes;

	static foreach (int i; 0 .. VISITORS.length)
	{
		// Non-type symbols (function literals, lambda templates) pass here;
		// only types that are not function types are rejected.
		static assert(!(is(VISITORS[i]) && !isSomeFunction!(VISITORS[i])),
			"Visitor at index " ~ i.stringof ~ " must be a function/delegate literal: "
			~ VISITORS[i].stringof);

		static assert(anySatisfy!(matchesType!(VISITORS[i]), Types),
			"Visitor at index " ~ i.stringof ~ " does not match any type of "
			~ TU.FieldTypes.stringof);
	}
}
255 
/* Curried predicate: `matchesType!fun!T` is true when visitor `fun` can
	handle a field of type `T`.

	A non-template function matches either a unit type (if it takes no
	parameters) or exactly `T` (if it takes a single parameter of that type).
	A template visitor never matches unit types; otherwise it matches when it
	can be instantiated for `T` as a function whose single parameter is `T`.
*/
private template matchesType(alias fun)
{
	import std.traits : ParameterTypeTuple, isSomeFunction;

	template matchesType(T)
	{
		static if (isSomeFunction!fun)
		{
			alias Params = ParameterTypeTuple!fun;
			enum matchesType = (Params.length == 0 && isUnitType!T)
				|| (Params.length == 1 && is(T == Params[0]));
		}
		else static if (isUnitType!T)
		{
			// generic visitors are not used for unit types
			enum matchesType = false;
		}
		else static if (!__traits(compiles, fun!T))
		{
			// the template cannot be instantiated for T at all
			enum matchesType = false;
		}
		else static if (!isSomeFunction!(fun!T))
		{
			enum matchesType = false;
		}
		else
		{
			alias Params = ParameterTypeTuple!(fun!T);
			enum matchesType = Params.length == 1 && is(T == Params[0]);
		}
	}
}
289 
unittest
{
	class C
	{
	}

	// typed visitor: matches exactly `C`
	alias typedMatches = matchesType!((C c) => true);
	static assert(typedMatches!C);
	static assert(!typedMatches!int);

	// generic visitor that refuses to compile for `C`
	alias genericMatches = matchesType!((c) { static assert(!is(typeof(c) == C)); });
	static assert(genericMatches!int);
	static assert(!genericMatches!C);
}
303 
/* Selects the visitor from `VISITORS` that handles a value of type `T`.

	`T` is `void` for unit-typed fields; those are handled by a visitor that
	takes no parameters. For any other `T`, a non-template visitor taking
	exactly one parameter of type `T` ("typed" visitor) takes precedence over
	a templated ("generic") visitor.

	Evaluates to one of:
	- an alias to the selected visitor,
	- a `string` describing why the visitor set is invalid for `T`,
	- `null` if no visitor can handle `T`.
*/
private template selectHandler(T, VISITORS...)
{
	import std.traits : ParameterTypeTuple, isSomeFunction;

	// Index of the non-template visitor accepting `T` (or taking no argument
	// when `T` is `void`), -1 if there is none, or an error message string.
	template typedIndex(int i, int matched_index = -1)
	{
		static if (i < VISITORS.length)
		{
			alias fun = VISITORS[i];
			static if (isSomeFunction!fun)
			{
				alias Params = ParameterTypeTuple!fun;
				static if (Params.length > 1)
					enum typedIndex = "Visitor at index " ~ i.stringof
						~ " must not take more than one parameter.";
				else static if ((Params.length == 0 && is(T == void))
						|| (Params.length == 1 && is(T == Params[0])))
				{
					// `.stringof` renders the index as digits; concatenating
					// the int directly would append a raw character instead.
					static if (matched_index >= 0)
						enum typedIndex = "Visitor at index " ~ i.stringof
							~ " conflicts with visitor at index " ~ matched_index.stringof ~ ".";
					else
						enum typedIndex = typedIndex!(i + 1, i);
				}
				else
					enum typedIndex = typedIndex!(i + 1, matched_index);
			}
			else
				enum typedIndex = typedIndex!(i + 1, matched_index);
		}
		else
			enum typedIndex = matched_index;
	}

	// Index of the template visitor usable for `T`, -1 if there is none, or
	// an error message string.
	template genericIndex(int i, int matched_index = -1)
	{
		static if (i < VISITORS.length)
		{
			alias fun = VISITORS[i];
			static if (!isSomeFunction!fun)
			{
				static if (__traits(compiles, fun!T) && isSomeFunction!(fun!T))
				{
					static if (ParameterTypeTuple!(fun!T).length == 1)
					{
						static if (matched_index >= 0)
							enum genericIndex = "Only one generic visitor allowed";
						else
							enum genericIndex = genericIndex!(i + 1, i);
					}
					else
						enum genericIndex = "Generic visitor at index "
							~ i.stringof ~ " must have a single parameter.";
				}
				else
					enum genericIndex = "Visitor at index " ~ i.stringof ~ " (or its template instantiation with type "
						~ T.stringof ~ ") must be a valid function or delegate.";
			}
			else
				enum genericIndex = genericIndex!(i + 1, matched_index);
		}
		else
			enum genericIndex = matched_index;
	}

	enum typed_index = typedIndex!0;
	static if (is(T == void))
		enum generic_index = -1; // generic visitors never handle unit types
	else
		enum generic_index = genericIndex!0;

	// Problems with the generic visitor are only relevant when no typed
	// visitor handles `T`: a generic visitor that fails to compile for `T`
	// is fine as long as a typed visitor covers `T`.
	static if (is(typeof(typed_index) == string))
		enum selectHandler = typed_index;
	else static if (typed_index >= 0)
		alias selectHandler = VISITORS[typed_index];
	else static if (is(typeof(generic_index) == string))
		enum selectHandler = generic_index;
	else static if (generic_index >= 0)
		alias selectHandler = VISITORS[generic_index];
	else
		enum selectHandler = null;
}