1 /**
2 Copyright: Copyright (c) 2020, Joakim Brännström. All rights reserved.
3 License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost Software License 1.0)
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
This module contains some simple statistics functionality. It isn't intended to
be a full-blown statistics package; for that, see
[mir](http://mir-algorithm.libmir.org/mir_math_stat.html). I wrote this module
because I had problems using **mir** and only needed a small subset of its
functionality.

The functions probably contain rounding errors etc., so be aware. But they seem
to work well enough for simple needs.
14 */
15 module my.stat;
16 
17 import logger = std.experimental.logger;
18 import std.algorithm : map, sum, min, max, maxElement, sort;
19 import std.array : appender, array;
20 import std.ascii : newline;
21 import std.format : formattedWrite, format;
22 import std.math : pow, sqrt, log, ceil, floor, log10, isInfinity, SQRT2;
23 import std.random : uniform;
24 import std.range : isOutputRange, put, iota, take;
25 import std.stdio : writeln;
26 
27 @safe:
28 
/// Example usage of the statistics, histogram and normal distribution helpers.
unittest {
    auto small = makeData([3, 14, 18, 24, 29]);

    writeln(small.basicStat);

    writeln(small.histogram(3));
    writeln(small.histogram(3).toBar);

    const stdNormal = NormDistribution(0, 1);

    auto samples = stdNormal.pdf.take(10000).makeData;
    writeln(samples.basicStat);
    writeln(samples.stdError);

    auto hist = samples.histogram(21);
    writeln(hist.toBar);
    writeln(hist.mode);

    writeln(stdNormal.cdf(1) - stdNormal.cdf(-1));
}
48 
/// Sample data converted to `double`, ready for the statistics functions.
struct StatData {
    /// The samples.
    double[] value;

    /// Number of samples. `const` so it also works on const/immutable data.
    size_t length() const pure nothrow @nogc {
        return value.length;
    }
}
56 
/** Convert user data to a representation useful for simple statistics calculations.

Every element of `raw` is cast to `double`.

Throws: `Exception` when `raw` contains fewer than two samples.
*/
StatData makeData(T)(T raw) {
    import std.algorithm : map;
    import std.array : array;

    double[] samples = raw.map!(x => cast(double) x).array;
    if (samples.length < 2)
        throw new Exception("Too few samples");
    return StatData(samples);
}
66 
/// Arithmetic mean of a sample.
struct Mean {
    /// The mean.
    double value;
}

/// Compute the arithmetic mean, `sum(x) / N`, of `data`.
Mean mean(StatData data) {
    const total = sum(data.value);
    return Mean(total / cast(double) data.length);
}
75 
/// Corrected sample standard deviation (the name used on Wikipedia), i.e. the
/// square root of the Bessel-corrected (`N - 1`) sample variance.
struct SampleStdDev {
    /// The standard deviation.
    double value;
}

/** Compute the corrected sample standard deviation of `data`.

Params:
    data = the samples.
    mean = the mean of `data`, computed beforehand.

Returns: `sqrt(sum((x - mean)^2) / (N - 1))`.
*/
SampleStdDev sampleStdDev(StatData data, Mean mean) {
    const mu = mean.value;
    const squaredDeviations = data.value.map!(x => pow(x - mu, 2.0)).sum;
    const count = cast(double) data.length;
    return SampleStdDev(sqrt(squaredDeviations / (count - 1.0)));
}
86 
/// Median of a sample.
struct Median {
    /// The median.
    double value;
}

/** Compute the median of `data_`.

For an even number of samples the median is the mean of the two middle values.

The samples are sorted in a private copy. `StatData` is passed by value but its
slice shares memory with the caller, so sorting `data_.value` directly would
silently reorder the caller's samples.
*/
Median median(StatData data_)
in (data_.value.length > 0, "median of an empty sample is undefined") {
    auto data = data_.value.dup;
    sort(data);

    const mid = data.length / 2;
    if (data.length % 2 == 0)
        return Median((data[mid - 1] + data[mid]) / 2.0);
    return Median(data[mid]);
}
98 
/** Fixed-width histogram over the closed range `[low, high]`.

The range is split into `nrBuckets` buckets of equal width `interval`. A value
on the upper edge (or one that rounds past it) is counted in the last bucket.
*/
struct Histogram {
    /// Number of values counted in each bucket.
    long[] buckets;
    /// Lower edge of the first bucket.
    double low;
    /// Upper edge of the last bucket.
    double high;
    /// Width of every bucket.
    double interval;

    /** Create an empty histogram.

    Params:
        low = lower edge of the histogram.
        high = upper edge of the histogram; must be greater than `low`.
        nrBuckets = number of buckets; must be greater than 1.
    */
    this(double low, double high, long nrBuckets)
    in (nrBuckets > 1, "failed nrBuckets > 1")
    in (low < high, "failed low < high") {
        this.low = low;
        this.high = high;
        interval = (high - low) / cast(double) nrBuckets;
        // Allocate exactly nrBuckets zero-initialized buckets. Re-deriving the
        // count as ceil((high - low) / interval) can round up to nrBuckets + 1
        // (e.g. low = 0, high = 1, nrBuckets = 49).
        buckets = new long[cast(size_t) nrBuckets];
    }

    /** Count `v` in the bucket it falls into.

    `v` must be within `[low, high]`.
    */
    void put(const double v)
    in (v >= low && v <= high, "v must be in the range [low, high]") {
        const idx = cast(long) floor((v - low) / interval);
        assert(idx >= 0);

        // v == high (or rounding close to it) maps one past the end; such
        // values belong to the last bucket.
        if (idx < buckets.length)
            buckets[idx] += 1;
        else
            buckets[$ - 1] += 1;
    }

    /// Format the bounds and every bucket as `[from, to]:count`.
    string toString() @safe const {
        auto buf = appender!string;
        toString(buf);
        return buf.data;
    }

    /// ditto
    void toString(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        import std.range : put;

        formattedWrite(w, "Histogram(low:%s, high:%s, interval:%s, buckets: [",
                low, high, interval);
        foreach (const i; 0 .. buckets.length) {
            if (i != 0)
                put(w, ", ");
            formattedWrite(w, "[%s, %s]:%s", (low + i * interval),
                    (low + (i + 1) * interval), buckets[i]);
        }
        put(w, "])");
    }

    /** Render one line per bucket: the bucket index, its range, a bar of `#`
    and the count. Bars are scaled down so that the longest one is at most 42
    characters wide.
    */
    string toBar() @safe const {
        auto buf = appender!string;
        toBar(buf);
        return buf.data;
    }

    /// ditto
    void toBar(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        import std.range : put;
        import std.range : repeat;

        immutable maxWidth = 42;
        // Scale factor applied to every count so the largest bar fits maxWidth.
        const fit = () {
            const m = maxElement(buckets);
            if (m > maxWidth)
                return cast(double) m / cast(double) maxWidth;
            return 1.0;
        }();

        // Width used to right-align the bucket index column.
        const indexWidth = cast(int) ceil(log10(buckets.length) + 1);

        foreach (const i; 0 .. buckets.length) {
            const row = format("[%.3f, %.3f]", (low + i * interval), (low + (i + 1) * interval));
            formattedWrite(w, "%*s %30s: %-(%s%) %s", indexWidth, i, row,
                    repeat("#", cast(size_t)(buckets[i] / fit)), buckets[i]);
            put(w, newline);
        }
    }
}
173 
/** Build a histogram of `data` with `nrBuckets` buckets spanning the smallest
to the largest sample.

If every sample has the same value, the range is widened symmetrically around
that value. `Histogram` requires `low < high`; without widening, constant data
would fail the contract, or divide by a zero bucket width in release builds.
*/
Histogram histogram(StatData data, long nrBuckets)
in (data.value.length > 0, "histogram requires at least one sample") {
    import std.math : fabs;

    double low = data.value[0];
    double high = data.value[0];
    foreach (const v; data.value) {
        low = min(low, v);
        high = max(high, v);
    }

    if (low == high) {
        // Relative padding so that large magnitudes are not lost to rounding.
        const pad = max(1.0, fabs(low)) * 0.5;
        low -= pad;
        high += pad;
    }

    auto hist = Histogram(low, high, nrBuckets);
    foreach (const v; data.value)
        hist.put(v);

    return hist;
}
190 
/// Most frequent value, estimated from a histogram.
struct Mode {
    /// The estimated mode.
    double value;
}

/** Estimate the mode of the data summarized by `hist`.

Returns: the midpoint of the bucket with the highest count. When several
buckets share the highest count, the first (lowest) one is used.
*/
Mode mode(Histogram hist) {
    size_t best = 0;
    foreach (const i; 1 .. hist.buckets.length) {
        if (hist.buckets[i] > hist.buckets[best])
            best = i;
    }

    // Every bucket, including bucket 0, is represented by its midpoint.
    return Mode(hist.low + (best + 0.5) * hist.interval);
}
207 
/// Summary of a sample: mean, median and corrected sample standard deviation.
struct BasicStat {
    Mean mean;
    Median median;
    SampleStdDev sd;

    /// Format as `BasicStat(mean:..., median:..., stdev: ...)`.
    string toString() @safe const {
        auto sink = appender!string;
        this.toString(sink);
        return sink.data;
    }

    /// ditto
    void toString(Writer)(ref Writer w) const if (isOutputRange!(Writer, char)) {
        const avg = mean.value;
        const mid = median.value;
        const dev = sd.value;
        formattedWrite(w, "BasicStat(mean:%s, median:%s, stdev: %s)", avg, mid, dev);
    }
}
224 
/// Compute the mean, median and corrected sample standard deviation of `data`.
BasicStat basicStat(StatData data) {
    // Evaluated in the order mean, median, standard deviation.
    auto avg = mean(data);
    auto mid = median(data);
    auto spread = sampleStdDev(data, avg);
    return BasicStat(avg, mid, spread);
}
229 
/// Parameters of a normal distribution: its mean and standard deviation.
struct NormDistribution {
    double mean, sd;
}
234 
/** Infinite input range of random samples drawn from the normal distribution `nd`.

Despite its name, this range produces random variates; it does not evaluate a
density function. The implementation is the Marsaglia polar method, ported
from the C++ standard library. Each accepted pair of uniform draws produces
two samples; the second one is cached and used by the next `popFront`.
*/
struct NormalDistributionPdf {
    NormDistribution nd;
    private double front_;
    private double V;
    private bool Vhot;

    /// The current sample.
    double front() @safe pure nothrow {
        assert(!empty, "Can't get front of an empty range");
        return front_;
    }

    /// Advance to the next sample.
    void popFront() @safe {
        assert(!empty, "Can't pop front of an empty range");

        import std.random : uniform;

        double standard;

        if (!Vhot) {
            // Rejection-sample a point inside the unit circle, excluding the origin.
            double x, y, r2;
            do {
                x = uniform(-1.0, 1.0);
                y = uniform(-1.0, 1.0);
                r2 = x * x + y * y;
            }
            while (r2 > 1 || r2 == 0);

            double scale = sqrt(-2.0 * log(r2) / r2);
            // Keep the second sample for the next call.
            V = y * scale;
            Vhot = true;
            standard = x * scale;
        } else {
            Vhot = false;
            standard = V;
        }

        front_ = standard * nd.sd + nd.mean;
    }

    /// The range never ends.
    enum bool empty = false;
}
279 
/// Create an infinite range of random samples from `nd`, primed so that
/// `front` already holds the first sample.
NormalDistributionPdf pdf(NormDistribution nd) {
    auto sampler = NormalDistributionPdf(nd);
    sampler.popFront();
    return sampler;
}
285 
/** Cumulative distribution function of the normal distribution `nd` at `x`.

Computed as `erfc(-(x - mean) / (sd * sqrt(2))) / 2`. Infinite `x` maps
directly to 0 (negative infinity) or 1 (positive infinity).
*/
double cdf(NormDistribution nd, double x)
in (nd.sd > 0, "domain error") {
    import core.stdc.math : erfc;

    if (isInfinity(x))
        return x < 0 ? 0.0 : 1.0;

    const z = (x - nd.mean) / (nd.sd * SQRT2);
    return cast(double) erfc(-z) / 2.0;
}
300 
/// Standard error of the mean.
struct StdMeanError {
    /// The estimated standard error.
    double value;
}

/** Estimate the standard error of the mean of `data` by bootstrapping.

`nrResamples` resamples are drawn with replacement, each containing as many
values as `data`. The result is the sample standard deviation of their means.
Because the resamples are random, the result varies between calls.

Params:
    data = the samples; at least two are required.
    nrResamples = number of bootstrap resamples; at least two are required.
*/
StdMeanError stdError(StatData data, long nrResamples = 200)
in (data.value.length > 1)
in (nrResamples > 1, "at least two resamples are required") {
    const len = data.value.length;

    // Each resample must have the full size `len`. Dividing the sum of a
    // smaller resample (bootstrap's default of 5 values) by `len` does not
    // estimate the mean at all.
    auto means = new double[cast(size_t) nrResamples];
    foreach (ref m; means)
        m = bootstrap(data, cast(long) len).sum / cast(double) len;

    auto meansData = StatData(means);
    return StdMeanError(sampleStdDev(meansData, meansData.mean).value);
}
316 
/** Lazily draw `min(minSamples, data.value.length)` values from `data`, with
replacement.

The returned range is lazy: each traversal draws new random indices.
*/
auto bootstrap(StatData data, long minSamples = 5)
in (minSamples > 0)
in (data.value.length > 1) {
    const len = data.value.length;
    // `uniform` defaults to the half-open interval "[)", so the upper bound
    // must be `len` (not `len - 1`) for the last sample to be selectable.
    return iota(min(minSamples, len)).map!(a => uniform(0, len))
        .map!(a => data.value[a]);
}