/**
Copyright: Copyright (c) 2021, Joakim Brännström. All rights reserved.
License: MPL-2
Author: Joakim Brännström (joakim.brannstrom@gmx.com)

This Source Code Form is subject to the terms of the Mozilla Public License,
v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
one at http://mozilla.org/MPL/2.0/.

Q-learning algorithm for training the schema generator.

Each mutation subtype has, per file, a state in the range
[SchemaQ.MinState, SchemaQ.MaxState] (i.e. [0,1000]). It determines the
probability that a mutant of that kind is part of the intermediate schema
generation.

The state is updated with feedback from if the schema successfully compiled and
executed the test suite OK.
*/
module dextool.plugin.mutate.backend.analyze.schema_ml;

import std.algorithm : joiner, clamp, min, max, filter;
import std.random : uniform01, MinstdRand0, unpredictableSeed;
import std.range : only;

import dextool.plugin.mutate.backend.type : Mutation;
import dextool.plugin.mutate.backend.database.type : SchemaStatus;

@safe:

/** Per-file, per-mutation-kind probability table.
 *
 * A state of `MaxState` corresponds to probability 1.0 and `MinState` to
 * probability 0.0 (see `use`).
 */
struct SchemaQ {
    import std.traits : EnumMembers;
    import my.hash;
    import my.path : Path;

    static immutable MinState = 0;
    static immutable MaxState = 1000;
    static immutable LearnRate = 0.01;

    /// Callback returning the mutation kinds that ended up with the given status.
    alias StatusData = Mutation.Kind[]delegate(SchemaStatus);

    /// Random source used by `use` and `scatterTick`.
    MinstdRand0 rnd0;
    /// State per mutation kind, keyed by the checksum of the file path.
    int[Mutation.Kind][Checksum64] state;
    /// Cache of path -> checksum, filled in by `checksum`.
    Checksum64[Path] pathCache;

    /// Returns: an empty table with an unpredictably seeded random source.
    static auto make() {
        return SchemaQ(MinstdRand0(unpredictableSeed));
    }

    this(MinstdRand0 rnd) {
        this.rnd0 = rnd;
    }

    this(typeof(state) st) {
        rnd0 = MinstdRand0(unpredictableSeed);
        state = st;
    }

    /// Deep copy of the state. The copy gets a freshly seeded random source
    /// and an empty path cache.
    SchemaQ dup() {
        typeof(state) st;
        foreach (a; state.byKeyValue) {
            int[Mutation.Kind] v;
            foreach (b; a.value.byKeyValue) {
                v[b.key] = b.value;
            }
            st[a.key] = v;
        }
        return SchemaQ(st);
    }

    /// Add a state for the `p` if it doesn't exist.
    void addIfNew(const Path p) {
        if (checksum(p) !in state) {
            state[checksum(p)] = (int[Mutation.Kind]).init;
        }
    }

    /** Update the state for all mutants.
     *
     * Every kind reported as `SchemaStatus.broken` gets one negative vote and
     * every kind reported as `SchemaStatus.ok` or `SchemaStatus.allKilled`
     * gets one positive vote. The net vote `n` of a kind scales its state by
     * `1.0 + n * LearnRate`, always moving it by at least one step. A net
     * vote of zero leaves the state untouched.
     *
     * Params:
     *  path = the file the feedback is for.
     *  data = callback that returns the kinds for a status.
     */
    void update(const Path path, scope StatusData data) {
        import std.math : round;

        addIfNew(path);
        const ch = checksum(path);

        // The votes are counted as integers. Accumulating `+-LearnRate` in a
        // double and comparing the sum against 1.0 is not reliable because an
        // equal number of punishments and rewards may end up one ulp away
        // from 1.0 which would be treated as a change.
        int[Mutation.Kind] votes;

        // punish
        foreach (k; data(SchemaStatus.broken))
            votes.update(k, () => -1, (ref int x) { x -= 1; });
        // reward
        foreach (k; only(data(SchemaStatus.ok), data(SchemaStatus.allKilled)).joiner)
            votes.update(k, () => 1, (ref int x) { x += 1; });

        // apply change
        foreach (v; votes.byKeyValue.filter!(a => a.value != 0)) {
            const double factor = 1.0 + v.value * LearnRate;
            // a kind that hasn't been seen before starts from MaxState.
            state[ch].update(v.key, () => cast(int) round(MaxState * factor), (ref int x) {
                // the +-1 guarantees progress when the scaling rounds back to x.
                if (v.value > 0)
                    x = max(x + 1, cast(int) round(x * factor));
                else
                    x = min(x - 1, cast(int) round(x * factor));
            });
        }

        // fix probability to be max P(1)
        foreach (k; votes.byKey) {
            state[ch].update(k, () => MaxState, (ref int x) {
                x = clamp(x, MinState, MaxState);
            });
        }
    }

    /** To allow those with zero probability to self heal give them a random +1 now and then.
     */
    void scatterTick() {
        foreach (p; state.byKeyValue) {
            foreach (k; p.value.byKeyValue.filter!(a => a.value == 0 && uniform01(rnd0) < 0.05)) {
                state[p.key][k.key] = 1;
            }
        }
    }

    /** Roll the dice to see if the mutant should be used.
     *
     * Params:
     *  p = path the mutant is located at.
     *  k = kind of mutant
     *  threshold = the mutants probability must be at least the threshold
     *      otherwise it will automatically fail.
     *
     * Return: true if the roll is positive, use the mutant.
     */
    bool use(const Path p, const Mutation.Kind k, const double threshold) {
        const s = getState(p, k) / cast(double) MaxState;
        return s >= threshold && uniform01(rnd0) < s;
    }

    /// Returns: true if the probability of success is zero.
    bool isZero(const Path p, const Mutation.Kind k) {
        return getState(p, k) == 0;
    }

    /// Returns: the checksum of `p`, computed once and then read from `pathCache`.
    private Checksum64 checksum(const Path p) {
        return pathCache.require(p, makeChecksum64(cast(const(ubyte)[]) p.toString));
    }

    /// Returns: the state for `k` in `p`. An unknown kind in a known file is
    /// inserted as `MaxState`; an unknown file yields `MaxState` without
    /// changing `state`.
    private int getState(const Path p, const Mutation.Kind k) {
        if (auto st = checksum(p) in state)
            return (*st).require(k, MaxState);
        return MaxState;
    }
}

struct Feature {
    import my.hash;

    // The path is extremely important because it allows the tool to clear out old data.
    Checksum64 path;

    Mutation.Kind kind;

    Checksum64[] context;

    /// Hash combining `kind`, `path` and every element of `context`.
    size_t toHash() @safe pure nothrow const @nogc {
        auto rval = (cast(size_t) kind).hashOf();
        rval = path.c0.hashOf(rval);
        foreach (a; context)
            rval = a.c0.hashOf(rval);
        return rval;
    }
}

@("shall update the table")
unittest {
    import std.random : MinstdRand0;
    import my.path : Path;

    const foo = Path("foo");
    SchemaQ q;
    q.rnd0 = MinstdRand0(42);

    Mutation.Kind[] r1(SchemaStatus s) {
        if (s == SchemaStatus.broken)
            return [Mutation.Kind.rorLE];
        if (s == SchemaStatus.ok)
            return [Mutation.Kind.rorLT];
        return null;
    }

    q.update(foo, &r1);
    const ch = q.pathCache[foo];
    assert(q.state[ch][Mutation.Kind.rorLE] == 990);
    assert(q.state[ch][Mutation.Kind.rorLT] == SchemaQ.MaxState);

    Mutation.Kind[] r2(SchemaStatus s) {
        if (s == SchemaStatus.broken)
            return [Mutation.Kind.rorLE, Mutation.Kind.rorLT];
        if (s == SchemaStatus.allKilled)
            return [Mutation.Kind.rorLT];
        return null;
    }

    q.update(foo, &r2);
    assert(q.state[ch][Mutation.Kind.rorLE] == 980);
    // in the last run it was one broken and one OK thus the net change was zero.
    assert(q.state[ch][Mutation.Kind.rorLT] == SchemaQ.MaxState);
}

/// Learns a schema size in the range [minSize, maxSize] from schema feedback.
struct SchemaSizeQ {
    static immutable LearnRate = 0.01;

    // Returns: an array of the nr of mutants schemas matching the condition.
    alias StatusData = long[]delegate(SchemaStatus);

    MinstdRand0 rnd0;
    long minSize;
    long maxSize;
    long currentSize;

    /// Returns: an instance with the given bounds. `currentSize` is left at
    /// its default value until the first `update`.
    static auto make(const long minSize, const long maxSize) {
        return SchemaSizeQ(MinstdRand0(unpredictableSeed), minSize, maxSize);
    }

    /** Adjust `currentSize` from the outcome of the schemas.
     *
     * Broken schemas smaller than `currentSize` shrink it and ok/allKilled
     * schemas larger than `currentSize` grow it. Each such schema contributes
     * `LearnRate * (size / totalMutants)` to the scaling plus a fixed step of
     * one. The result is clamped to [minSize, maxSize].
     *
     * Params:
     *  data = callback that returns the sizes of the schemas with a status.
     *  totalMutants = total number of mutants. If it isn't positive only the
     *      fixed steps are applied because the proportional part is undefined.
     */
    void update(scope StatusData data, const long totalMutants) {
        // Dividing by zero would turn `adjust` into inf/NaN and the cast of
        // that to long isn't well defined.
        const bool hasTotal = totalMutants > 0;
        const double total = cast(double) totalMutants;

        double adjust = 1.0;
        // ensure there is at least some change even though there is rounding
        // errors or some schemas are small.
        long fixed;
        foreach (const v; data(SchemaStatus.broken).filter!(a => a < currentSize)) {
            if (hasTotal)
                adjust -= LearnRate * (cast(double) v / total);
            fixed--;
        }
        foreach (const v; only(data(SchemaStatus.allKilled), data(SchemaStatus.ok)).joiner.filter!(
                a => a > currentSize)) {
            if (hasTotal)
                adjust += LearnRate * (cast(double) v / total);
            fixed++;
        }

        const double newValue = currentSize * adjust + fixed;
        currentSize = clamp(cast(long) newValue, minSize, maxSize);
    }
}

@("shall apply the fixed step when the total number of mutants is zero")
unittest {
    auto q = SchemaSizeQ.make(1, 100);
    q.currentSize = 10;

    long[] r(SchemaStatus s) {
        if (s == SchemaStatus.ok)
            return [20];
        return null;
    }

    q.update(&r, 0);
    assert(q.currentSize == 11);
}