1 /**
2 Copyright: Copyright (c) 2021, Joakim Brännström. All rights reserved.
3 License: MPL-2
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
6 This Source Code Form is subject to the terms of the Mozilla Public License,
7 v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
8 one at http://mozilla.org/MPL/2.0/.
9 
10 Q-learning algorithm for training the schema generator.
11 
Each mutation subtype has a state in the range [0,1000] (`SchemaQ.MinState` to
`SchemaQ.MaxState`). It determines the probability that a mutant of that kind
is part of the intermediate schema generation.

The state is updated with feedback about whether the schema compiled
successfully and executed the test suite OK.
18 */
19 module dextool.plugin.mutate.backend.analyze.schema_ml;
20 
21 import std.algorithm : joiner, clamp, min, max, filter;
22 import std.random : uniform01, MinstdRand0, unpredictableSeed;
23 import std.range : only;
24 
25 import dextool.plugin.mutate.backend.type : Mutation;
26 import dextool.plugin.mutate.backend.database.type : SchemaStatus;
27 
28 @safe:
29 
/** Q-learning table of the probability to use a mutant.
 *
 * Every (path, mutation kind) pair has an integer state in [MinState,
 * MaxState]. The probability to use a mutant of that kind at that path is
 * `state / MaxState`. A pair without a stored state is treated as `MaxState`.
 */
struct SchemaQ {
    import std.traits : EnumMembers;
    import my.hash;
    import my.path : Path;

    static immutable MinState = 0;
    static immutable MaxState = 1000;
    static immutable LearnRate = 0.01;

    /// Returns: the mutation kinds to punish/reward for a schema status, one entry per occurrence.
    alias StatusData = Mutation.Kind[]delegate(SchemaStatus);

    /// Random generator used by `use` and `scatterTick`.
    MinstdRand0 rnd0;
    /// State per path checksum and mutation kind.
    int[Mutation.Kind][Checksum64] state;
    /// Memoized checksums of the paths.
    Checksum64[Path] pathCache;

    /// Returns: an empty table with an unpredictably seeded generator.
    static auto make() {
        return SchemaQ(MinstdRand0(unpredictableSeed));
    }

    this(MinstdRand0 rnd) {
        this.rnd0 = rnd;
    }

    this(typeof(state) st) {
        rnd0 = MinstdRand0(unpredictableSeed);
        state = st;
    }

    /** Deep copy of the state table.
     *
     * The copy gets a freshly seeded random generator and an empty path cache.
     */
    SchemaQ dup() {
        typeof(state) st;
        foreach (a; state.byKeyValue) {
            int[Mutation.Kind] v;
            foreach (b; a.value.byKeyValue) {
                v[b.key] = b.value;
            }
            st[a.key] = v;
        }
        return SchemaQ(st);
    }

    /// Add an empty state for `p` if it doesn't exist.
    void addIfNew(const Path p) {
        const ch = checksum(p);
        if (ch !in state) {
            state[ch] = (int[Mutation.Kind]).init;
        }
    }

    /** Update the state of the mutation kinds at `path` with the result of schemas.
     *
     * Each kind returned for `SchemaStatus.broken` counts as one punishment.
     * Each kind returned for `SchemaStatus.ok` or `SchemaStatus.allKilled`
     * counts as one reward. Only the net count `n` matters. A kind with `n == 0`
     * keeps its state. Otherwise the state is scaled by `1 + LearnRate * n` and
     * moves by at least one step. Finally every involved kind is clamped to
     * [MinState, MaxState].
     *
     * Params:
     *  path = path the mutants are located at.
     *  data = the kinds of mutants per schema status.
     */
    void update(const Path path, scope StatusData data) {
        import std.math : round;

        addIfNew(path);
        const ch = checksum(path);

        // Net number of rewards minus punishments per kind. Counting with
        // integers, instead of accumulating a floating point factor, makes a
        // balanced outcome exactly zero. It is then never mistaken for a
        // change because of rounding errors.
        int[Mutation.Kind] votes;

        // punish
        foreach (k; data(SchemaStatus.broken))
            votes.update(k, () => -1, (ref int x) { --x; });
        // reward
        foreach (k; only(data(SchemaStatus.ok), data(SchemaStatus.allKilled)).joiner)
            votes.update(k, () => 1, (ref int x) { ++x; });

        // apply change
        foreach (v; votes.byKeyValue.filter!(a => a.value != 0)) {
            const factor = 1.0 + LearnRate * v.value;
            const reward = v.value > 0;
            state[ch].update(v.key, () => cast(int) round(MaxState * factor), (ref int x) {
                // guarantee at least one step even when rounding would hide the change
                if (reward)
                    x = max(x + 1, cast(int) round(x * factor));
                else
                    x = min(x - 1, cast(int) round(x * factor));
            });
        }

        // Limit the probability to at most P(1). This also stores MaxState
        // for kinds with a zero net count that had no state yet.
        foreach (k; votes.byKey) {
            state[ch].update(k, () => MaxState, (ref int x) {
                x = clamp(x, MinState, MaxState);
            });
        }
    }

    /** Give every state that is zero a 5% chance to become 1.
     *
     * This lets kinds with zero probability heal themselves over time.
     */
    void scatterTick() {
        foreach (p; state.byKeyValue) {
            foreach (k; p.value.byKeyValue.filter!(a => a.value == 0 && uniform01(rnd0) < 0.05)) {
                state[p.key][k.key] = 1;
            }
        }
    }

    /** Roll the dice to see if the mutant should be used.
     *
     * Params:
     *  p = path the mutant is located at.
     *  k = kind of mutant
     *  threshold = the mutant's probability (state / MaxState) must be at
     *  least the threshold, otherwise the roll automatically fails.
     *
     * Return: true if the roll is positive, use the mutant.
     */
    bool use(const Path p, const Mutation.Kind k, const double threshold) {
        const s = getState(p, k) / cast(double) MaxState;
        return s >= threshold && uniform01(rnd0) < s;
    }

    /// Returns: true if the probability of success is zero.
    bool isZero(const Path p, const Mutation.Kind k) {
        return getState(p, k) == 0;
    }

    /// Returns: the (memoized) checksum of the string form of `p`.
    private Checksum64 checksum(const Path p) {
        return pathCache.require(p, makeChecksum64(cast(const(ubyte)[]) p.toString));
    }

    /** Returns: the state of `k` at `p`.
     *
     * An unknown kind at a known path is stored as MaxState. An unknown path
     * is not added to `state`.
     */
    private int getState(const Path p, const Mutation.Kind k) {
        if (auto st = checksum(p) in state)
            return (*st).require(k, MaxState);
        return MaxState;
    }
}
156 
/// A mutation kind at a path together with the checksums of its context.
struct Feature {
    import my.hash;

    // The path is extremely important because it allows the tool to clear out old data.
    Checksum64 path;

    Mutation.Kind kind;

    Checksum64[] context;

    /// Hash that combines the kind, the path and each context checksum, in that order.
    size_t toHash() @safe pure nothrow const @nogc {
        size_t h = hashOf(cast(size_t) kind);
        h = hashOf(path.c0, h);
        foreach (const ref c; context)
            h = hashOf(c.c0, h);
        return h;
    }
}
175 
@("shall update the table")
unittest {
    import std.random : MinstdRand0;
    import my.path : Path;

    SchemaQ q;
    q.rnd0 = MinstdRand0(42);
    const foo = Path("foo");

    // first round: rorLE broke a schema, rorLT was part of an OK schema
    Mutation.Kind[] firstRound(SchemaStatus s) {
        return s == SchemaStatus.broken ? [Mutation.Kind.rorLE]
            : s == SchemaStatus.ok ? [Mutation.Kind.rorLT] : null;
    }

    // second round: both broke a schema, rorLT also took part in an all killed schema
    Mutation.Kind[] secondRound(SchemaStatus s) {
        return s == SchemaStatus.broken ? [Mutation.Kind.rorLE, Mutation.Kind.rorLT]
            : s == SchemaStatus.allKilled ? [Mutation.Kind.rorLT] : null;
    }

    q.update(foo, &firstRound);
    const ch = q.pathCache[foo];
    // punished once: 1000 * 0.99
    assert(q.state[ch][Mutation.Kind.rorLE] == 990);
    // rewarded once: 1010 is clamped to the max
    assert(q.state[ch][Mutation.Kind.rorLT] == SchemaQ.MaxState);

    q.update(foo, &secondRound);
    // punished once more: 990 * 0.99 = 980.1, rounded to 980
    assert(q.state[ch][Mutation.Kind.rorLE] == 980);
    // rorLT was punished (broken) and rewarded (allKilled) once each thus its state is unchanged.
    assert(q.state[ch][Mutation.Kind.rorLT] == SchemaQ.MaxState);
}
211 
/** Learns the size (number of mutants) of the schemas to generate.
 *
 * The size shrinks when schemas smaller than the current size are broken. It
 * grows when schemas larger than the current size are OK or all killed. The
 * result is always kept within [minSize, maxSize].
 */
struct SchemaSizeQ {
    static immutable LearnRate = 0.01;

    // Returns: an array of the nr of mutants schemas matching the condition.
    alias StatusData = long[]delegate(SchemaStatus);

    MinstdRand0 rnd0;
    long minSize;
    long maxSize;
    /// The learned schema size.
    long currentSize;

    /** Returns: a learner with an unpredictably seeded generator.
     *
     * NOTE(review): `currentSize` is left at its default of 0. Presumably the
     * caller assigns it afterwards; confirm.
     */
    static auto make(const long minSize, const long maxSize) {
        return SchemaSizeQ(MinstdRand0(unpredictableSeed), minSize, maxSize);
    }

    /** Update `currentSize` with the outcome of the latest schemas.
     *
     * Params:
     *  data = number of mutants in each schema with the given status.
     *  totalMutants = total number of mutants. It is used to normalize the
     *      schema sizes. When it is zero or negative only the fixed one step
     *      adjustments are applied.
     */
    void update(scope StatusData data, const long totalMutants) {
        double newValue = currentSize;
        // Commit the limited result even if `data` throws.
        scope (exit)
            currentSize = toSize(newValue);

        // Fraction of all mutants that `v` represents. The guard against a
        // non-positive total avoids inf/NaN, which would corrupt newValue.
        double ratio(const long v) {
            return totalMutants > 0 ? cast(double) v / cast(double) totalMutants : 0.0;
        }

        double adjust = 1.0;
        // ensure there is at least some change even though there are rounding
        // errors or some schemas are small.
        long fixed;
        // shrink for broken schemas that were smaller than the current size
        foreach (const v; data(SchemaStatus.broken).filter!(a => a < currentSize)) {
            adjust -= LearnRate * ratio(v);
            fixed--;
        }
        // grow for successful schemas that were larger than the current size
        foreach (const v; only(data(SchemaStatus.allKilled), data(SchemaStatus.ok)).joiner.filter!(
                a => a > currentSize)) {
            adjust += LearnRate * ratio(v);
            fixed++;
        }
        newValue = newValue * adjust + fixed;
    }

    /** Returns: `v` truncated to an integer and limited to [minSize, maxSize].
     *
     * The limit is applied before the cast so an out-of-range value is never
     * converted to `long`. NaN maps to `minSize`.
     */
    private long toSize(const double v) const {
        if (!(v > minSize))
            return minSize;
        if (v >= maxSize)
            return maxSize;
        return cast(long) v;
    }
}