From 9e247bee7b3a4b5109c94570b4bca52d87b35a5c Mon Sep 17 00:00:00 2001 From: Anna Date: Sun, 31 Jan 2021 23:59:04 -0500 Subject: [PATCH] fix: kill old classifiers --- NoSoliciting/Ml/MlFilter.cs | 63 ++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/NoSoliciting/Ml/MlFilter.cs b/NoSoliciting/Ml/MlFilter.cs index cc7b5b3..e9fed84 100644 --- a/NoSoliciting/Ml/MlFilter.cs +++ b/NoSoliciting/Ml/MlFilter.cs @@ -74,20 +74,22 @@ namespace NoSoliciting.Ml { UpdateCachedFile(plugin, ModelName, data); UpdateCachedFile(plugin, ManifestName, Encoding.UTF8.GetBytes(manifest.Item2)); - using var exe = Resource.AsStream("NoSoliciting.NoSoliciting.MessageClassifier.exe"); var pluginFolder = Util.PluginFolder(plugin); - Directory.CreateDirectory(pluginFolder); - var exePath = Path.Combine(pluginFolder, "NoSoliciting.MessageClassifier.exe"); - using (var exeFile = File.Create(exePath)) { - await exe.CopyToAsync(exeFile); - } - var startInfo = new ProcessStartInfo(exePath) { - CreateNoWindow = true, - UseShellExecute = false, - }; - var process = Process.Start(startInfo); + var pidPath = Path.Combine(pluginFolder, "classifier.pid"); + // close the old classifier if it's still open + CloseOldClassifier(pidPath); + + var exePath = await ExtractClassifier(pluginFolder); + + var process = StartClassifier(exePath, pidPath); + var client = await CreateClassifierClient(data); + + return new MlFilter(manifest.Item1.Version, process!, client); + } + + private static async Task> CreateClassifierClient(byte[] data) { var serviceProvider = new ServiceCollection() .AddNamedPipeIpcClient("client", (_, options) => { options.PipeName = "NoSoliciting.MessageClassifier"; @@ -99,8 +101,45 @@ namespace NoSoliciting.Ml { var client = clientFactory.CreateClient("client"); await client.InvokeAsync(classifier => classifier.Initialise(data)); + return client; + } - return new MlFilter(manifest.Item1.Version, process!, client); + private static Process StartClassifier(string exePath, string pidPath) { + var startInfo = new ProcessStartInfo(exePath) { + CreateNoWindow = true, + UseShellExecute = false, + }; + var process = Process.Start(startInfo); + File.WriteAllText(pidPath, process!.Id.ToString()); + return process; + } + + private static async Task ExtractClassifier(string pluginFolder) { + using var exe = Resource.AsStream("NoSoliciting.NoSoliciting.MessageClassifier.exe"); + Directory.CreateDirectory(pluginFolder); + var exePath = Path.Combine(pluginFolder, "NoSoliciting.MessageClassifier.exe"); + using var exeFile = File.Create(exePath); + await exe.CopyToAsync(exeFile); + + return exePath; + } + + private static void CloseOldClassifier(string pidPath) { + if (!File.Exists(pidPath)) { + return; + } + + if (!int.TryParse(File.ReadAllText(pidPath).Trim(), out var pid)) { + return; + } + + try { + var old = Process.GetProcessById(pid); + old.Kill(); + old.WaitForExit(); + } catch (ArgumentException) { + // ignore + } } private static async Task DownloadModel(Uri url) {