1 /**
2 Copyright: Copyright (c) 2019, Joakim Brännström. All rights reserved.
3 License: MPL-2
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
6 This Source Code Form is subject to the terms of the Mozilla Public License,
7 v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
8 one at http://mozilla.org/MPL/2.0/.
9 */
10 module process;
11 
12 import core.sys.posix.unistd : pid_t;
13 import core.thread : Thread;
14 import core.time : dur, Duration;
15 import logger = std.experimental.logger;
16 import std.algorithm : filter, count, joiner, map;
17 import std.array : appender, empty, array;
18 import std.exception : collectException;
19 import std.stdio : File, fileno, writeln;
20 static import std.process;
21 
22 public import process.channel;
23 
24 version (unittest) {
25     import unit_threaded.assertions;
26     import std.file : remove;
27 }
28 
/** Wrap a process in a `Raii` guard.
 *
 * Params:
 *  p = the process to guard. Must be implicitly convertible to `Process`.
 *
 * Returns: a `Raii!T` holding `p`; its destructor calls `p.destroy()`.
 */
Raii!T raii(T)(T p) if (is(T : Process)) {
    return Raii!T(p);
}
33 
/** Scope guard that calls `destroy()` on the wrapped object when the guard is
 * destructed.
 *
 * The wrapped object is reachable through `alias this`, so the guard can be
 * used wherever a `T` is expected.
 *
 * NOTE: the struct is copyable. Every copy calls `destroy()` on the same
 * object when it goes out of scope, so `T.destroy` is presumably expected to
 * tolerate repeated calls - verify against the wrapped type.
 */
struct Raii(T) {
    /// The guarded object.
    T process;
    alias process this;

    ~this() {
        // A default-initialised guard (`Raii!T.init`) holds a null reference
        // when `T` is a class or interface. Calling a method through it would
        // crash, so there is nothing to clean up in that case.
        static if (is(T == class) || is(T == interface)) {
            if (process is null) {
                return;
            }
        }
        process.destroy();
    }
}
42 
/// Common interface for a spawned process and the decorators that wrap one.
interface Process {
    /// Returns: the channel used to access the process stdin and stdout.
    Channel pipe() @safe nothrow;

    /// Returns: the channel used to read the process stderr.
    ReadChannel stderr() @safe nothrow;

    /// Kill the process and clean up after it.
    void destroy() @safe;

    /// Kill the process.
    void kill() @safe nothrow;

    /// Block until the process has terminated.
    /// Returns: the exit status.
    int wait() @safe;

    /// Check for termination without blocking.
    /// Returns: `true` if the process has terminated.
    bool tryWait() @safe;

    /// Returns: the raw OS handle for the process ID.
    RawPid osHandle() @safe nothrow;

    /// Returns: the exit status of the process.
    int status() @safe;

    /// Returns: `true` if the process has terminated.
    bool terminated() @safe nothrow;
}
74 
/** A child process whose stdin/stdout and stderr are exposed as channels.
 *
 * The life cycle is tracked with a small state machine:
 *  - `running`: no kill signal sent and no exit status collected.
 *  - `terminated`: `kill` has been called, exit status not yet collected.
 *  - `exitCode`: the exit status has been collected and stored.
 *
 * The original description states that reads from the channels do not block;
 * that depends on the channel implementations in `process.channel`.
 */
class PipeProcess : Process {
    private {
        enum State {
            running,
            terminated,
            exitCode
        }

        std.process.ProcessPipes process;
        Pipe pipe_;
        ReadChannel stderr_;
        int status_;
        State st;
    }

    /// Take ownership of already spawned process pipes.
    this(std.process.ProcessPipes process) @safe {
        this.process = process;
        this.pipe_ = new Pipe(this.process.stdout, this.process.stdin);
        this.stderr_ = new FileReadChannel(this.process.stderr);
    }

    /// Returns: the OS process ID wrapped in a `RawPid`.
    override RawPid osHandle() nothrow @safe {
        return RawPid(process.pid.osHandle);
    }

    /// Returns: the channel connected to the process stdout/stdin.
    override Channel pipe() nothrow @safe {
        return pipe_;
    }

    /// Returns: the channel connected to the process stderr.
    override ReadChannel stderr() nothrow @safe {
        return stderr_;
    }

    /// Kill the process if it is still running, collect the exit status if
    /// that has not been done, and then destroy the channels.
    override void destroy() @safe {
        if (st == State.running) {
            // moves the state to `terminated`
            this.kill;
        }
        if (st != State.exitCode) {
            this.wait;
        }

        pipe_.destroy;
        stderr_.destroy;
    }

    /// Send SIGKILL to the process. Only has an effect in the `running` state.
    override void kill() nothrow @trusted {
        import core.sys.posix.signal : SIGKILL;

        if (st != State.running) {
            return;
        }

        // best effort; a failure to deliver the signal is ignored
        try {
            std.process.kill(process.pid, SIGKILL);
        } catch (Exception e) {
        }

        st = State.terminated;
    }

    /// Block until the process has exited.
    /// Returns: the exit status. It is cached, thus repeated calls are cheap.
    override int wait() @safe {
        if (st != State.exitCode) {
            status_ = std.process.wait(process.pid);
        }

        st = State.exitCode;

        return status_;
    }

    /** Check if the process has exited and collect the exit status if so.
     *
     * In the `terminated` state (kill already sent) this performs a blocking
     * wait to collect the exit status.
     *
     * Returns: `true` if the process has been killed or has exited.
     */
    override bool tryWait() @safe {
        if (st == State.running) {
            auto res = std.process.tryWait(process.pid);
            if (res.terminated) {
                status_ = res.status;
                st = State.exitCode;
            }
        } else if (st == State.terminated) {
            status_ = std.process.wait(process.pid);
            st = State.exitCode;
        }

        return st != State.running;
    }

    /// Returns: the collected exit status.
    /// Throws: `Exception` if the exit status has not been collected yet.
    override int status() @safe {
        if (st == State.exitCode) {
            return status_;
        }
        throw new Exception(
                "Process has not terminated and wait/tryWait been called to collect the exit status");
    }

    /// Returns: `true` if the process has been killed or its exit status collected.
    override bool terminated() @safe {
        return st == State.terminated || st == State.exitCode;
    }
}
198 
/** Spawn a process and wrap it in a `PipeProcess`.
 *
 * The parameters are forwarded unchanged to `std.process.pipeProcess`.
 */
Process pipeProcess(scope const(char[])[] args,
        std.process.Redirect redirect = std.process.Redirect.all,
        const string[string] env = null, std.process.Config config = std.process.Config.none,
        scope const(char)[] workDir = null) @safe {
    auto pipes = std.process.pipeProcess(args, redirect, env, config, workDir);
    return new PipeProcess(pipes);
}
205 
/** Run a command through the shell and wrap it in a `PipeProcess`.
 *
 * The parameters are forwarded unchanged to `std.process.pipeShell`.
 */
Process pipeShell(scope const(char)[] command,
        std.process.Redirect redirect = std.process.Redirect.all,
        const string[string] env = null, std.process.Config config = std.process.Config.none,
        scope const(char)[] workDir = null, string shellPath = std.process.nativeShell) @safe {
    auto pipes = std.process.pipeShell(command, redirect, env, config, workDir, shellPath);
    return new PipeProcess(pipes);
}
213 
/** Decorator that kills the wrapped process together with all its
 * descendants.
 *
 * The constructor asks the OS to move the process to its own process group
 * with `setpgid`. The return value of that call is ignored.
 * NOTE(review): POSIX specifies that `setpgid` fails with EACCES for a child
 * that has already executed `exec`, so the call may well be a no-op for a
 * process spawned via `std.process` - confirm.
 *
 * The descendants are found by walking `/proc` (see `getDeepChildren`).
 */
class Sandbox : Process {
    private {
        Process p;
    }

    /// Wrap `p` and try to move it to a separate process group.
    this(Process p) @safe nothrow {
        import core.sys.posix.unistd : setpgid;

        this.p = p;
        setpgid(p.osHandle, 0);
    }

    override RawPid osHandle() nothrow @safe {
        return p.osHandle;
    }

    override Channel pipe() nothrow @safe {
        return p.pipe;
    }

    override ReadChannel stderr() nothrow @safe {
        return p.stderr;
    }

    /// Kill the process tree and then destroy the wrapped process.
    override void destroy() @safe {
        this.kill;
        p.destroy;
    }

    /** Kill the wrapped process and all of its descendants.
     *
     * The descendants are recorded *before* the wrapped process is signalled.
     * Once the parent has died its children are re-parented and can no longer
     * be found via the parent's `/proc` entry, thus a scan made only after
     * the kill may miss them. A second scan is done after the kill to pick up
     * children that were spawned in between.
     */
    override void kill() nothrow @safe {
        static import core.sys.posix.signal;
        import core.sys.posix.sys.wait : waitpid, WNOHANG;

        // Collect all descendants of `pids`. Failures are ignored because a
        // process may disappear at any time while `/proc` is inspected.
        static RawPid[] update(RawPid[] pids) @trusted {
            auto app = appender!(RawPid[])();

            foreach (p; pids) {
                try {
                    app.put(getDeepChildren(p));
                } catch (Exception e) {
                }
            }

            return app.data;
        }

        static void killChildren(RawPid[] children) @trusted {
            foreach (const c; children) {
                core.sys.posix.signal.kill(c, core.sys.posix.signal.SIGKILL);
            }
        }

        static bool contains(const RawPid[] pids, const RawPid pid) @safe {
            foreach (const x; pids) {
                if (x.value == pid.value) {
                    return true;
                }
            }
            return false;
        }

        // snapshot while the parent is still alive
        auto children = update([p.osHandle]);
        p.kill;
        // pick up children spawned between the snapshot and the kill
        foreach (late; update([p.osHandle])) {
            if (!contains(children, late)) {
                children ~= late;
            }
        }

        auto reapChildren = appender!(RawPid[])();
        // if there ever are processes that are spawned with root permissions
        // or something happens that they can't be killed by "this" process
        // tree. Thus limit the iterations to a reasonable number
        for (int i = 0; !children.empty && i < 5; ++i) {
            reapChildren.put(children);
            killChildren(children);
            children = update(children);
        }

        // non-blocking attempt to reap; the result is ignored
        foreach (c; reapChildren.data) {
            () @trusted { waitpid(c, null, WNOHANG); }();
        }
    }

    override int wait() @safe {
        return p.wait;
    }

    override bool tryWait() @safe {
        return p.tryWait;
    }

    override int status() @safe {
        return p.status;
    }

    override bool terminated() @safe {
        return p.terminated;
    }
}
302 
/// Wrap `p` in a `Sandbox` decorator.
Sandbox sandbox(Process p) @safe {
    auto boxed = new Sandbox(p);
    return boxed;
}
306 
@("shall terminate a group of processes")
unittest {
    import std.algorithm : count;
    import std.datetime.stopwatch : StopWatch, AutoStart;

    // a script that leaves three `sleep` children hanging around
    immutable scriptName = makeScript(`#!/bin/bash
sleep 10m &
sleep 10m &
sleep 10m
`);
    scope (exit)
        remove(scriptName);

    auto proc = pipeProcess([scriptName]).sandbox.raii;

    // wait until the script has spawned all of its children
    while (getDeepChildren(proc.osHandle).count < 3) {
        Thread.sleep(50.dur!"msecs");
    }
    const before = getDeepChildren(proc.osHandle).count;

    proc.kill;
    Thread.sleep(500.dur!"msecs"); // wait for the OS to kill the children
    const after = getDeepChildren(proc.osHandle).count;

    proc.wait.shouldEqual(-9);
    proc.terminated.shouldBeTrue;
    before.shouldEqual(3);
    after.shouldEqual(0);
}
334 
/** Terminate the process after the timeout.
 *
 * A background thread, spawned with `std.concurrency`, watches the clock and
 * kills the wrapped process when the deadline has passed. The owner thread
 * communicates with it by message passing (`Msg` requests, `Reply` answers).
 */
class Timeout : Process {
    import std.algorithm : among;
    import std.datetime : Clock, Duration;
    import std.concurrency;

    private {
        /// Requests sent from the owner to the background thread.
        enum Msg {
            none,
            stop,
            status,
        }

        /// Answers sent from the background thread to the owner.
        enum Reply {
            none,
            running,
            normalDeath,
            killedByTimeout,
        }

        Process p;
        shared KillProcess kp;
        Tid background;
        // the last reply received from the background thread
        Reply backgroundReply;
    }

    /// Wrap `p` and start the background thread that enforces `timeout`.
    this(Process p, Duration timeout) @trusted {
        this.p = p;
        this.kp = cast(shared) new KillProcess(p);
        background = spawn(&checkProcess, p.osHandle, timeout, kp);
    }

    /// Serialises calls to `Process.kill` with a mutex because both the owner
    /// and the background thread may kill the process.
    private static class KillProcess {
        import core.sync.mutex : Mutex;

        private {
            Process p;
            Mutex mtx;
        }
        this(Process p) {
            this.p = p;
            this.mtx = new Mutex();
        }

        void kill() @trusted nothrow {
            this.mtx.lock_nothrow();
            scope (exit)
                this.mtx.unlock_nothrow();
            p.kill;
        }
    }

    /** Body of the background thread.
     *
     * Polls until the deadline is reached, a `Msg.stop` is received or the
     * process no longer exists. The last message sent to the owner is always
     * either `Reply.killedByTimeout` or `Reply.normalDeath`.
     */
    private static void checkProcess(RawPid p, Duration timeout, shared KillProcess kp) {
        import core.sys.posix.signal : SIGKILL;
        import std.algorithm : max;
        import std.variant : Variant;
        static import core.sys.posix.signal;

        const stopAt = Clock.currTime + timeout;
        // the purpose is to poll the process often "enough" that if it
        // terminates early `Process` detects it fast enough. 1000 is chosen
        // because it "feels good".
        auto sleepInterval = max(20, timeout.total!"msecs" / 1000).dur!"msecs";

        bool forceStop;
        // set when the loop is left because the process no longer exists
        bool processGone;
        Msg msg;
        while (!forceStop && Clock.currTime < stopAt) {
            msg = Msg.none;
            // unknown messages are swallowed by the `Variant` handler
            const hasMsg = receiveTimeout(sleepInterval, (Msg x) { msg = x; }, (Variant x) {
            },);

            final switch (msg) {
            case Msg.none:
                break;
            case Msg.stop:
                forceStop = true;
                break;
            case Msg.status:
                send(ownerTid, Reply.running);
                break;
            }

            // signal 0 only checks whether the process exists
            if (!hasMsg && (core.sys.posix.signal.kill(p, 0) == -1)) {
                processGone = true;
                break;
            }
        }

        // The process may disappear on its own shortly before the deadline
        // and only be noticed after it. That is not a timeout and must not be
        // reported as one.
        if (!forceStop && !processGone && Clock.currTime >= stopAt) {
            (cast() kp).kill;
            send(ownerTid, Reply.killedByTimeout);
        } else {
            send(ownerTid, Reply.normalDeath);
        }
    }

    override RawPid osHandle() nothrow @safe {
        return p.osHandle;
    }

    override Channel pipe() nothrow @safe {
        return p.pipe;
    }

    override ReadChannel stderr() nothrow @safe {
        return p.stderr;
    }

    /// Stop the background thread, if its final reply has not been received
    /// yet, and destroy the wrapped process.
    override void destroy() @trusted {
        if (backgroundReply.among(Reply.none, Reply.running)) {
            send(background, Msg.stop);
            backgroundReply = receiveOnly!Reply;
        }
        p.destroy;
    }

    override void kill() nothrow @trusted {
        (cast() kp).kill;
    }

    /// Poll `tryWait` every 20 ms until the process has terminated.
    /// Returns: the exit status.
    override int wait() @trusted {
        while (!this.tryWait) {
            Thread.sleep(20.dur!"msecs");
        }
        return p.wait;
    }

    override bool tryWait() @safe {
        return p.tryWait;
    }

    override int status() @safe {
        return p.status;
    }

    override bool terminated() @safe {
        return p.terminated;
    }

    /// Returns: `true` if the background thread reported that it killed the
    /// process because the timeout was reached.
    bool timeoutTriggered() @trusted {
        if (backgroundReply.among(Reply.none, Reply.running)) {
            send(background, Msg.status);
            backgroundReply = receiveOnly!Reply;
        }
        return backgroundReply == Reply.killedByTimeout;
    }
}
483 
/// Wrap `p` in a `Timeout` decorator that kills it after `timeout`.
Timeout timeout(Process p, Duration timeout) @safe {
    auto guarded = new Timeout(p, timeout);
    return guarded;
}
487 
@("shall kill the process after the timeout")
unittest {
    import std.datetime.stopwatch : StopWatch, AutoStart;

    auto proc = pipeProcess(["sleep", "1m"]).timeout(100.dur!"msecs").raii;

    // measure how long it takes until the timeout kills the process
    auto watch = StopWatch(AutoStart.yes);
    proc.wait;
    watch.stop;

    watch.peek.shouldBeGreaterThan(100.dur!"msecs");
    watch.peek.shouldBeSmallerThan(500.dur!"msecs");
    proc.wait.shouldEqual(-9);
    proc.terminated.shouldBeTrue;
    proc.status.shouldEqual(-9);
    proc.timeoutTriggered.shouldBeTrue;
}
504 
/** Measure the runtime of a process.
 *
 * The stopwatch is started when the decorator is constructed and stopped the
 * first time `wait` or `tryWait` observes that the process has terminated.
 */
class MeasureTime : Process {
    import std.datetime.stopwatch : StopWatch;

    private {
        Process p;
        StopWatch sw;
    }

    /// Wrap `p` and start measuring.
    this(Process p) @safe nothrow @nogc {
        this.p = p;
        sw.start;
    }

    override RawPid osHandle() nothrow @safe {
        return p.osHandle;
    }

    override Channel pipe() nothrow @safe {
        return p.pipe;
    }

    override ReadChannel stderr() nothrow @safe {
        return p.stderr;
    }

    override void destroy() @safe {
        p.destroy;
    }

    override void kill() nothrow @safe {
        p.kill;
    }

    /** Block until the process has terminated and stop the stopwatch.
     *
     * The call is always forwarded to the wrapped process. A process that has
     * been killed reports `terminated` before its exit status has been
     * collected; skipping the inner `wait` in that case would make the
     * following `status` call throw.
     *
     * Returns: the exit status.
     */
    override int wait() @safe {
        const exitStatus = p.wait;
        // `StopWatch.stop` must only be called on a running stopwatch
        if (sw.running) {
            sw.stop;
        }
        return exitStatus;
    }

    /** Non-blocking check for termination. Stops the stopwatch the first time
     * the process is seen as terminated.
     *
     * Returns: `true` if the process has terminated.
     */
    override bool tryWait() @safe {
        const done = p.tryWait;
        if (done && sw.running) {
            sw.stop;
        }
        return done;
    }

    override int status() @safe {
        return p.status;
    }

    override bool terminated() @safe {
        return p.terminated;
    }

    /// Returns: the time elapsed so far, or the final runtime once the
    /// stopwatch has been stopped.
    Duration time() @safe nothrow const @nogc {
        return sw.peek;
    }
}
567 
/// Wrap `p` in a `MeasureTime` decorator. The measurement starts immediately.
MeasureTime measureTime(Process p) @safe nothrow {
    auto timed = new MeasureTime(p);
    return timed;
}
571 
/// Strongly typed wrapper around an OS process ID.
struct RawPid {
    /// The process ID.
    pid_t value;
    // lets a `RawPid` be used wherever a `pid_t` is expected
    alias value this;
}
576 
/** Find the direct children of a process by reading
 * `/proc/<pid>/task/<pid>/children`.
 *
 * A process can terminate at any moment. Failing to open or read the file is
 * therefore treated the same as the process not existing.
 *
 * Params:
 *  parentPid = process ID to inspect.
 *
 * Returns: the process IDs of the direct children, or `null` if the process
 * doesn't exist or its children couldn't be read.
 */
RawPid[] getShallowChildren(const int parentPid) {
    import std.algorithm : filter, splitter;
    import std.conv : to;
    import std.file : exists;
    import std.path : buildPath;

    const pidStr = parentPid.to!string;
    const pidPath = buildPath("/proc", pidStr);
    if (!exists(pidPath)) {
        return null;
    }

    // The process may vanish between the check above and the read below.
    string line;
    try {
        line = File(buildPath(pidPath, "task", pidStr, "children")).readln;
    } catch (Exception e) {
        logger.trace(e.msg).collectException;
        return null;
    }

    auto children = appender!(RawPid[])();
    foreach (const p; line.splitter(" ").filter!(a => !a.empty)) {
        try {
            children.put(p.to!pid_t.RawPid);
        } catch (Exception e) {
            // not a number; skip the entry
            logger.trace(e.msg).collectException;
        }
    }

    return children.data;
}
600 
/** Find all descendants of a process.
 *
 * The process tree is walked breadth first, thus the direct children come
 * first and the leaves end up at the back.
 *
 * Returns: a list of all descendants with the leaves at the back.
 */
RawPid[] getDeepChildren(const int parentPid) {
    // the result doubles as the work queue: everything from `next` and
    // onwards is still to be expanded
    auto found = appender!(RawPid[])();
    found.put(getShallowChildren(parentPid));

    for (size_t next = 0; next < found.data.length; ++next) {
        found.put(getShallowChildren(found.data[next]));
    }

    return found.data;
}
619 
/** Block until the process has pending data on stdout or stderr.
 *
 * Polls both channels every 20 ms.
 */
void waitForPendingData(Process p) {
    // Keep waiting only while *neither* channel has pending data. The negation
    // must cover both operands; otherwise pending data on stderr keeps the
    // loop spinning forever.
    while (!(p.pipe.hasPendingData || p.stderr.hasPendingData)) {
        Thread.sleep(20.dur!"msecs");
    }
}
626 
/// A chunk of data read from a process, tagged with the stream it came from.
struct DrainElement {
    /// The stream the data was read from.
    enum Type {
        stdout,
        stderr,
    }

    Type type;
    /// The raw bytes that were read.
    const(ubyte)[] data;

    /// Returns: an input range that iterates the data as UTF-8 code units.
    auto byUTF8() @safe pure nothrow const @nogc {
        static import std.utf;

        auto text = cast(const(char)[]) data;
        return std.utf.byUTF!(const(char))(text);
    }

    /// Returns: `true` if the element carries no data.
    bool empty() @safe pure nothrow const @nogc {
        return !data.length;
    }
}
647 
/** A range that drains a process stdout/stderr until it terminates.
 *
 * There may be `DrainElement` that are empty. The very first element is
 * always empty because nothing is read until `popFront` is called.
 */
struct DrainRange {
    /// The phases the range moves through, in order.
    enum State {
        start,
        draining,
        lastStdout,
        lastStderr,
        lastElement,
        empty,
    }

    private {
        Process p;
        DrainElement front_;
        State st;
    }

    this(Process p) @safe pure nothrow @nogc {
        this.p = p;
    }

    /// Returns: the chunk that the latest `popFront` read. May be empty.
    DrainElement front() @safe pure nothrow const @nogc {
        assert(!empty, "Can't get front of an empty range");
        return front_;
    }

    /// Read the next chunk, blocking in 20 ms steps while the process has
    /// open pipes but no data to read.
    void popFront() @safe {
        assert(!empty, "Can't pop front of an empty range");

        bool pipesOpen() {
            return p.pipe.hasData || p.stderr.hasData;
        }

        // stdout takes priority over stderr; at most one chunk is read
        void fetch() @safe {
            if (p.pipe.hasData && p.pipe.hasPendingData) {
                front_ = DrainElement(DrainElement.Type.stdout, p.pipe.read(4096));
                return;
            }
            if (p.stderr.hasData && p.stderr.hasPendingData) {
                front_ = DrainElement(DrainElement.Type.stderr, p.stderr.read(4096));
            }
        }

        void blockUntilData() @safe {
            import core.thread : Thread;
            import core.time : dur;

            for (;;) {
                if (front_.data.length != 0 || !pipesOpen) {
                    break;
                }

                fetch();
                if (front_.data.length == 0) {
                    () @trusted { Thread.sleep(20.dur!"msecs"); }();
                }
            }
        }

        front_ = DrainElement.init;

        final switch (st) {
        case State.start:
            st = State.draining;
            blockUntilData();
            break;
        case State.draining:
            if (!pipesOpen) {
                st = State.lastStdout;
            } else {
                blockUntilData();
            }
            break;
        case State.lastStdout:
            st = State.lastStderr;
            fetch();
            // stay in this state for as long as stdout has more to give
            if (p.pipe.hasData && p.pipe.hasPendingData) {
                st = State.lastStdout;
            }
            break;
        case State.lastStderr:
            st = State.lastElement;
            fetch();
            // stay in this state for as long as stderr has more to give
            if (p.stderr.hasData && p.stderr.hasPendingData) {
                st = State.lastStderr;
            }
            break;
        case State.lastElement:
            st = State.empty;
            break;
        case State.empty:
            break;
        }
    }

    /// Returns: `true` when the final state has been reached.
    bool empty() @safe pure nothrow const @nogc {
        return st == State.empty;
    }
}
744 
/// Returns: a range that drains the stdout/stderr of `p` until they are empty.
DrainRange drain(Process p) @safe pure nothrow @nogc {
    auto drainer = DrainRange(p);
    return drainer;
}
749 
/** Read the output of a process line by line.
 *
 * The data from stdout and stderr is accumulated in one buffer which is split
 * on `'\n'`. The newline is not part of the returned line. Each line is a
 * freshly allocated copy.
 */
struct DrainByLineCopyRange {
    private {
        Process process;
        DrainRange range;
        // data that has been read but not yet handed out as a line
        const(ubyte)[] buf;
        const(char)[] line;
    }

    this(Process p) @safe pure nothrow @nogc {
        process = p;
        range = p.drain;
    }

    /// Returns: the current line. May be empty.
    string front() @trusted pure nothrow const @nogc {
        import std.exception : assumeUnique;

        assert(!empty, "Can't get front of an empty range");
        return line.assumeUnique;
    }

    /// Advance to the next line, reading more data from the process until a
    /// newline is found or the process output is exhausted.
    void popFront() @safe {
        assert(!empty, "Can't pop front of an empty range");
        import std.algorithm : countUntil;
        import std.array : array;
        static import std.utf;

        // position of the first newline in `buf`, -1 if there is none
        ptrdiff_t newlineAt;
        for (;;) {
            if (!range.empty) {
                range.popFront;
            }
            if (!range.empty) {
                buf ~= range.front.data;
            }

            newlineAt = buf.countUntil('\n');
            if (range.empty || newlineAt != -1) {
                break;
            }
        }

        const(ubyte)[] rawLine;
        if (newlineAt != -1) {
            // hand out everything up to the newline and drop the newline
            rawLine = buf[0 .. newlineAt];
            buf = buf[newlineAt + 1 .. $];
        } else if (buf.length != 0) {
            // no more newlines will come; the remainder is the last line
            rawLine = buf;
            buf = null;
        }

        line = std.utf.byUTF!(const(char))(cast(const(char)[]) rawLine).array;
    }

    /// Returns: `true` when there is neither more data to read nor a line to
    /// hand out.
    bool empty() @safe pure nothrow const @nogc {
        return range.empty && buf.length == 0 && line.length == 0;
    }
}
821 
@("shall drain the process output by line")
unittest {
    import std.algorithm : filter, count, joiner, map;
    import std.array : array;

    auto proc = pipeProcess(["dd", "if=/dev/zero", "bs=10", "count=3"]).raii;
    // empty lines carry no information and are dropped
    auto lines = proc.drainByLineCopy.filter!"!a.empty".array;

    lines.length.shouldEqual(3);
    lines.joiner.count.shouldBeGreaterThan(30);
    proc.wait.shouldEqual(0);
    proc.terminated.shouldBeTrue;
}
835 
/// Returns: a range that reads the output of `p` line by line.
auto drainByLineCopy(Process p) @safe {
    auto byLine = DrainByLineCopyRange(p);
    return byLine;
}
839 
/// Read and discard the output of the process until it is done executing.
/// Returns: `p`, to allow chaining.
Process drainToNull(Process p) @safe {
    auto sink = p.drain;
    while (!sink.empty) {
        sink.popFront;
    }
    return p;
}
846 
/// Drain the output from the process into an output range.
/// Returns: `p`, to allow chaining.
Process drain(T)(Process p, ref T range) {
    for (auto chunks = p.drain; !chunks.empty; chunks.popFront) {
        range.put(chunks.front);
    }
    return p;
}
854 
@("shall drain the output of a process while it is running with a separation of stdout and stderr")
unittest {
    auto proc = pipeProcess(["dd", "if=/dev/urandom", "bs=10", "count=3"]).raii;
    auto elements = proc.drain.array;

    // this is just a sanity check. It has to be kind a high because there is
    // some wiggleroom allowed
    elements.count.shouldBeSmallerThan(50);

    // dd writes 3 blocks of 10 bytes to stdout
    const stdoutBytes = elements.filter!(a => a.type == DrainElement.Type.stdout)
        .map!(a => a.data)
        .joiner
        .count;
    stdoutBytes.shouldEqual(30);

    const stderrChunks = elements.filter!(a => a.type == DrainElement.Type.stderr).count;
    stderrChunks.shouldBeGreaterThan(0);

    proc.wait.shouldEqual(0);
    proc.terminated.shouldBeTrue;
}
873 
@("shall kill the process tree when the timeout is reached")
unittest {
    immutable script = makeScript(`#!/bin/bash
sleep 10m
`);
    scope (exit)
        remove(script);

    auto proc = pipeProcess([script]).sandbox.timeout(1.dur!"seconds").raii;

    // wait until the script has spawned its child
    while (getDeepChildren(proc.osHandle).count < 1) {
        Thread.sleep(50.dur!"msecs");
    }
    const before = getDeepChildren(proc.osHandle).count;

    // draining blocks until the timeout has killed the process
    const drained = proc.drain.array;
    const after = getDeepChildren(proc.osHandle).count;

    proc.wait.shouldEqual(-9);
    proc.terminated.shouldBeTrue;
    before.shouldEqual(1);
    after.shouldEqual(0);
}
895 
/** Write `script` to an executable file next to the running executable.
 *
 * The file name is derived from the path of the running executable, the base
 * name of `file` and `line`, thus each call site gets its own file.
 *
 * Params:
 *  script = content of the script. A trailing newline is appended.
 *  file = defaults to the file of the caller.
 *  line = defaults to the line of the caller.
 *
 * Returns: the path to the created script. The caller is responsible for
 * removing it.
 */
string makeScript(string script, string file = __FILE__, uint line = __LINE__) {
    import core.sys.posix.sys.stat;
    import std.file : getAttributes, setAttributes, thisExePath;
    import std.stdio : File;
    import std.path : baseName;
    import std.conv : to;

    immutable fname = thisExePath ~ "_" ~ file.baseName ~ line.to!string ~ ".sh";

    {
        // the file is closed when leaving the scope, before the mode is changed
        auto fout = File(fname, "w");
        fout.writeln(script);
    }

    // add the execute bit for user, group and others
    const mode = getAttributes(fname);
    setAttributes(fname, mode | S_IXUSR | S_IXGRP | S_IXOTH);
    return fname;
}