1 module dpmatch.parser;
import pegged.grammar;
import std.stdio, std.format, std.array;
import std.algorithm, std.string, std.traits;
import dpmatch.util;
6 
/// Pegged grammar for a dpmatch pattern list such as
/// `| Cons(x, xs) -> <{ ... }>`.
///
/// Each alternative is `"|" Pattern "->" PatternHandler`. A handler is D code
/// written between `<{` and `}>`; the angle brackets are discarded and the
/// braces are kept in the captured text. `VariantPatternArgs` is not referenced
/// by any other rule yet.
///
/// `VariantPatternName` rejects the reserved word `match` only as a whole word.
/// The lookahead `![a-zA-Z0-9_]` keeps identifiers that merely start with it
/// (e.g. `matched`) valid.
enum DPMatchGrammar = `
DPMATCH:
  PatternList < PatternListElement+
  PatternListElement < "|" Pattern "->" PatternHandler

  Pattern < VariantPattern
  VariantPattern < VariantPatternName VariantPatternBindings
  VariantPatternBindings <  (:"(" VariantPatternName ("," VariantPatternName)* :")")?
  VariantPatternArgs < "()" / :"(" VariantPattern ("," VariantPattern)* :")"
  VariantPatternName <~ !(Keyword ![a-zA-Z0-9_]) [a-zA-Z_] [a-zA-Z0-9_]*

  PatternHandler <~ :"<""{" (!"}>" .)* "}":">"

  Keyword <~ "match"
`;

// Generates the `DPMATCH` parser from the grammar above (Pegged).
mixin(grammar(DPMatchGrammar));
24 
/// Tag returned by `ASTElement.etype` to identify the concrete AST class.
enum ASTElemType {
  tPatternList,            /// `PatternList`
  tPatternListElement,     /// `PatternListElement`
  tPattern,                /// generic pattern tag
  tVariantPattern,         /// `VariantPattern`
  tVariantPatternBindings, /// `VariantPatternBindings`
  tVariantPatternArgs,     /// `VariantPatternArgs`
  tVariantPatternName,     /// `VariantPatternName`
  tPatternHandler,         /// `PatternHandler`
}
35 
/// Common interface of every node produced by `buildAST`.
interface ASTElement {
  /// Returns: the tag identifying this node's concrete class.
  ASTElemType etype() const;
}
39 
/// AST node holding the source text of a pattern's handler.
class PatternHandler : ASTElement {
  /// Handler code exactly as captured by the `PatternHandler` grammar rule.
  string code;

  ASTElemType etype() const {
    return ASTElemType.tPatternHandler;
  }

  this(string code) {
    this.code = code;
  }
}
50 
/// AST node for an identifier: a constructor name or a binding name.
class VariantPatternName : ASTElement {
  /// The identifier text.
  string name;

  this(string name) {
    this.name = name;
  }

  /// Returns: the identifier text.
  string getNameStr() {
    return name;
  }

  ASTElemType etype() const {
    return ASTElemType.tVariantPatternName;
  }
}
65 
/// AST node for the parenthesised binding names following a constructor,
/// e.g. `x, xs` in `Cons(x, xs)`.
class VariantPatternBindings : ASTElement {
  /// Binding names in source order (possibly empty).
  VariantPatternName[] bindings;

  this(VariantPatternName[] bindings) {
    this.bindings = bindings;
  }

  /// Returns: the binding names as plain strings, in source order.
  string[] bindingsStr() {
    string[] names;
    foreach (binding; bindings) {
      names ~= binding.getNameStr();
    }
    return names;
  }

  ASTElemType etype() const {
    return ASTElemType.tVariantPatternBindings;
  }
}
80 
/// AST node for a parenthesised list of nested variant patterns.
class VariantPatternArgs : ASTElement {
  /// Nested patterns in source order.
  VariantPattern[] args;

  ASTElemType etype() const {
    return ASTElemType.tVariantPatternArgs;
  }

  this(VariantPattern[] args) {
    this.args = args;
  }
}
91 
/// Tag returned by `Pattern.ptype` to identify the kind of pattern.
enum PatternType {
  pVariantPattern, /// `VariantPattern`
}
95 
/// An AST node that can appear on the left-hand side of `->`.
interface Pattern : ASTElement {
  /// Returns: the tag identifying the kind of pattern.
  PatternType ptype() const;
}
99 
/// A constructor pattern such as `Cons(x, xs)`: a constructor name plus the
/// names bound to its fields.
class VariantPattern : Pattern {
  /// Constructor name.
  VariantPatternName vp_name;
  /// Names bound to the constructor's fields.
  VariantPatternBindings bindings;

  this(VariantPatternName vp_name, VariantPatternBindings bindings) {
    this.vp_name = vp_name;
    this.bindings = bindings;
  }

  PatternType ptype() const {
    return PatternType.pVariantPattern;
  }

  ASTElemType etype() const {
    return ASTElemType.tVariantPattern;
  }

  /// Returns: the constructor name as a string.
  string getNameStr() {
    return vp_name.getNameStr();
  }
}
121 
/// One `| Pattern -> Handler` alternative of a pattern list.
class PatternListElement : ASTElement {
  /// Left-hand side of `->`.
  Pattern pattern;
  /// Code run when `pattern` matches.
  PatternHandler handler;

  ASTElemType etype() const {
    return ASTElemType.tPatternListElement;
  }

  this(Pattern pattern, PatternHandler handler) {
    this.pattern = pattern;
    this.handler = handler;
  }
}
134 
/// Root AST node: all alternatives of a match, in source order.
class PatternList : ASTElement {
  /// The alternatives in source order.
  PatternListElement[] list;

  this(PatternListElement[] list) {
    this.list = list;
  }

  /// Returns: the alternatives (the stored array itself, not a copy).
  PatternListElement[] getElems() {
    return list;
  }

  ASTElemType etype() const {
    return ASTElemType.tPatternList;
  }
}
149 
/// Builds `child` with `buildAST` and downcasts the result to `T`.
///
/// Throws: `Error` mentioning `parentName` when the built node is not a `T`.
private T buildASTAs(T)(ParseTree child, string parentName) {
  T node = cast(T)buildAST(child);
  if (node is null) {
    throw new Error("Error in %s!".format(parentName));
  }
  return node;
}

/// Converts a parse tree produced by the `DPMATCH` grammar into AST objects.
///
/// Params:
///   p = a `DPMATCH` parse tree, or a subtree rooted at one of its rules.
/// Returns: the AST node corresponding to `p`.
/// Throws: `Error` when the root parse was unsuccessful, or when a child does
///   not build into the node type its parent rule expects. Rule names not
///   handled below are not expected (the `final switch` has no default).
ASTElement buildAST(ParseTree p) {
  /*
  if (!__ctfe) {
    writeln("p.name : ", p.name);
  }
  */

  final switch (p.name) {
  case "DPMATCH":
    // A failed parse can leave a partial tree (e.g. an element without its
    // handler); walking it would only end in an opaque out-of-range error.
    if (!p.successful || p.children.length == 0) {
      throw new Error("Error in %s: the pattern list could not be parsed!".format(p.name));
    }
    return buildAST(p.children[0]);
  case "DPMATCH.PatternList":
    PatternListElement[] list;
    foreach (child; p.children) {
      list ~= buildASTAs!PatternListElement(child, p.name);
    }
    return new PatternList(list);
  case "DPMATCH.PatternListElement":
    Pattern pattern = buildASTAs!Pattern(p.children[0], p.name);
    PatternHandler handler = buildASTAs!PatternHandler(p.children[1], p.name);
    return new PatternListElement(pattern, handler);
  case "DPMATCH.Pattern":
    return buildAST(p.children[0]);
  case "DPMATCH.VariantPattern":
    // Defensive: use an empty binding list if the bindings node is absent.
    VariantPatternBindings bindings = p.children.length == 2
      ? buildASTAs!VariantPatternBindings(p.children[1], p.name)
      : new VariantPatternBindings([]);
    VariantPatternName vpn = buildASTAs!VariantPatternName(p.children[0], p.name);
    return new VariantPattern(vpn, bindings);
  case "DPMATCH.VariantPatternBindings":
    VariantPatternName[] bindings;
    foreach (child; p.children) {
      bindings ~= buildASTAs!VariantPatternName(child, p.name);
    }
    return new VariantPatternBindings(bindings);
  case "DPMATCH.VariantPatternName":
    return new VariantPatternName(p.matches[0]);
  case "DPMATCH.VariantPatternArgs":
    VariantPattern[] args;
    foreach (child; p.children) {
      args ~= buildASTAs!VariantPattern(child, p.name);
    }
    return new VariantPatternArgs(args);
  case "DPMATCH.PatternHandler":
    // The rule is fused (`<~`), so the whole handler text is one match.
    return new PatternHandler(p.matches[0]);
  }
}
224 
/// Generates D source that matches an ADT value against a pattern list.
///
/// The result is the source of an immediately invoked function literal. For
/// each pattern, it checks whether the matched value casts to the pattern's
/// constructor type. If so, it calls the handler with the constructor's
/// original members (`getOriginalMembers`) bound to the pattern's binding
/// names, and returns the handler's result when it is not `void`. The caller is
/// expected to mix the string in.
///
/// Params:
///   DADTTypeType = type whose members (`__traits(allMembers)`) are the
///     constructor names the pattern list must cover, each exactly once.
///   __INTERNAL_PATTERN_MATCH_ARGUMENT = the matched value; its type decides
///     whether constructor names are instantiated with template arguments.
///   node = AST from `buildAST`; must be a `PatternList`.
///   INTERNAL_PATTERN_MATCH_ARGUMENT_NAME = identifier of the matched value as
///     it must appear in the generated code.
/// Returns: the generated source string.
/// Throws: `Error` if `node` is not a `PatternList`, if a pattern is not a
///   variant pattern, if a constructor has no pattern, or if a pattern names
///   an unknown constructor or repeats one.
string compileForDADT(DADTTypeType, alias __INTERNAL_PATTERN_MATCH_ARGUMENT)(
    const ASTElement node, string INTERNAL_PATTERN_MATCH_ARGUMENT_NAME) {
  if (node.etype != ASTElemType.tPatternList) {
    throw new Error("compileForDADT accept only PatternList");
  }
  PatternList list = cast(PatternList)node;
  PatternListElement[] elems = list.getElems();
  string elems_code;

  // Constructor names of the ADT and constructor names used by the patterns,
  // both sorted so the two lists can be compared directly.
  string[] dadttypes = [__traits(allMembers, DADTTypeType)];
  dadttypes.sort!"a<b";
  string[] pattern_types;

  foreach (PatternListElement elem; elems) {
    if (elem.pattern.ptype != PatternType.pVariantPattern) {
      throw new Error("Error: compileForDADT only support pVariantPattern currently.");
    }
    VariantPattern vp = cast(VariantPattern)elem.pattern;
    pattern_types ~= vp.getNameStr();
  }
  pattern_types.sort!"a<b";

  if (pattern_types != dadttypes) {
    string[] missing_cases = dadttypes.filter!(dadttype => !pattern_types.canFind(dadttype)).array;
    if (missing_cases.length > 0) {
      string cases_msg = missing_cases.map!(_case => "  %s".format(_case)).join("\n");
      throw new Error(
          "This pattern match is not exhaustive.\nHere is an example of a case that is not matched:\n%s".format(
          cases_msg));
    }

    // Every constructor is covered, so the mismatch comes from patterns that
    // name an unknown constructor or repeat one. Duplicates are adjacent
    // after sorting. Reporting these avoids an empty "not matched" list.
    string[] extra_cases;
    foreach (i, pattern_type; pattern_types) {
      if (!dadttypes.canFind(pattern_type) || (i > 0 && pattern_types[i - 1] == pattern_type)) {
        extra_cases ~= pattern_type;
      }
    }
    string extras_msg = extra_cases.map!(_case => "  %s".format(_case)).join("\n");
    throw new Error(
        "This pattern match has unknown or duplicated cases:\n%s".format(extras_msg));
  }

  // code generation
  foreach (PatternListElement elem; elems) {
    VariantPattern vp = cast(VariantPattern)elem.pattern;
    VariantPatternBindings bindings = vp.bindings;
    VariantPatternName vpn = vp.vp_name;

    string constructor_name = vpn.getNameStr();

    // NOTE(review): if this resolves to std.traits.isTemplate, it is false for
    // template *instances* such as `List!int`; confirm which definition is in
    // scope and that generic ADTs take the first branch.
    static if (isTemplate!(typeof(__INTERNAL_PATTERN_MATCH_ARGUMENT))) {
      string[] constructor_args;

      foreach (arg; TemplateArgsOf!(typeof(__INTERNAL_PATTERN_MATCH_ARGUMENT))) {
        constructor_args ~= arg.stringof;
      }

      const constructor_str = "%s!(%s)".format(constructor_name, constructor_args.join(", "));
    } else {
      const constructor_str = constructor_name;
    }

    string[] handler_args = bindings.bindingsStr().map!(arg => "\"%s\"".format(arg)).array;
    string handler_body = elem.handler.code;

    // The handler body is joined by concatenation rather than passed through
    // `format`, so handlers containing `%` (e.g. `x % 2`, `writefln("%d", x)`)
    // are not misread as format specifiers.
    elems_code ~= `
  if ((cast(#{constructor_str}#)#{INTERNAL_PATTERN_MATCH_ARGUMENT_NAME}#) !is null) {
    auto __INTERNAL_VALUE_CASTED = cast(#{constructor_str}#)#{INTERNAL_PATTERN_MATCH_ARGUMENT_NAME}#;
    import std.algorithm, std.array, std.string;
    import dpmatch.util;
    enum original_members = getOriginalMembers!(#{constructor_str}#);
    enum string[] handler_args = [#{handler_args}#];
    enum binding_args = {
      string[] binding_args;
      foreach (i, member; original_members) {
        binding_args ~= "typeof(%s) %s".format("__INTERNAL_VALUE_CASTED." ~ member, handler_args[i]);
      }
      return binding_args;
    }();
    enum call_args = original_members.map!(member => "__INTERNAL_VALUE_CASTED." ~ member).array.join(", ");

    mixin("enum __INTERNAL_BINDING = (" ~ binding_args.join(", ") ~ ") " ~ q{#{handler_body}#} ~ ";");
    mixin("enum __INTERNAL_BINDING_CALL = \"__INTERNAL_BINDING(%s)\";".format(call_args));
    import std.traits;
    static if (is(ReturnType!(__INTERNAL_BINDING) == void)) {
      mixin("%s;".format(__INTERNAL_BINDING_CALL));
      return;
    } else {
      mixin("return %s;".format(__INTERNAL_BINDING_CALL));
    }
  }
`.patternReplaceWithTable([
        "constructor_str": constructor_str,
        "INTERNAL_PATTERN_MATCH_ARGUMENT_NAME": INTERNAL_PATTERN_MATCH_ARGUMENT_NAME,
        "handler_body": handler_body,
        "handler_args": handler_args.join(", ")
        ]);
  }

  return `() {
%s
  throw new Exception("Should not reach here");
}();`.format(elems_code);
}
323 
/// Builds, at compile time, the source of a pattern-match expression over
/// `__INTERNAL_PATTERN_MATCH_ARGUMENT`.
///
/// `def` is parsed with the `DPMATCH` grammar and compiled against the
/// constructors of `DADTTypeType`. The result is meant to be mixed in.
string patternMatchADT(alias __INTERNAL_PATTERN_MATCH_ARGUMENT, DADTTypeType, string def)() {
  enum parsed = DPMATCH(def);
  enum string generated = compileForDADT!(DADTTypeType, __INTERNAL_PATTERN_MATCH_ARGUMENT)(
      buildAST(parsed), __INTERNAL_PATTERN_MATCH_ARGUMENT.stringof);
  return generated;
}
330 
/// Like `patternMatchADT`, but the generated source is a `return` statement
/// returning the match result.
string patternMatchADTReturn(alias __INTERNAL_PATTERN_MATCH_ARGUMENT, DADTTypeType, string def)() {
  // Same generated expression as patternMatchADT; only the prefix differs.
  enum string matchExpr = patternMatchADT!(__INTERNAL_PATTERN_MATCH_ARGUMENT, DADTTypeType, def)();
  return "return " ~ matchExpr;
}
337 
/// Like `patternMatchADT`, but the generated source declares
/// `auto <target> = <match result>`.
string patternMatchADTBind(alias __INTERNAL_PATTERN_MATCH_ARGUMENT,
    DADTTypeType, string def, string target)() {
  // Same generated expression as patternMatchADT; only the declaration differs.
  enum string matchExpr = patternMatchADT!(__INTERNAL_PATTERN_MATCH_ARGUMENT, DADTTypeType, def)();
  return "auto " ~ target ~ " = " ~ matchExpr;
}