// NoSoliciting/NoSoliciting/Ml/MlFilter.cs
// (captured from a repository blame view: 232 lines, 8.4 KiB, C#)
//
// 2020-12-21 02:49:10 +00:00
using System;
using System.IO;
using System.Linq;
// 2020-12-21 02:49:10 +00:00
using System.Net;
using System.Security.Cryptography;
// 2020-12-21 02:49:10 +00:00
using System.Text;
using System.Threading.Tasks;
using NoSoliciting.Interface;
using NoSoliciting.Resources;
// 2020-12-21 02:49:10 +00:00
using YamlDotNet.Core;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions;
namespace NoSoliciting.Ml {
public class MlFilter : IDisposable {
    /// <summary>
    /// Message of the most recent failure while downloading the manifest or model, or while
    /// validating the model's checksum. Cleared when the manifest is downloaded and parsed successfully.
    /// </summary>
    public static string? LastError { get; private set; }

    private const string ManifestName = "manifest.yaml";
    private const string ModelName = "model.zip";

    /// <summary>Number of times a model whose checksum does not match the manifest is redownloaded.</summary>
    private const int MaxChecksumRetries = 3;

#if DEBUG
    private const string Url = "http://localhost:8000/manifest.yaml";
#else
    private const string Url = "https://annaclemens.io/assets/nosol/ml/manifest.yaml";
#endif

    /// <summary>Version of the manifest the loaded model was validated against.</summary>
    public uint Version { get; }

    /// <summary>Report URL taken from the manifest the loaded model was validated against.</summary>
    public Uri ReportUrl { get; }

    private IClassifier Classifier { get; }

    private MlFilter(uint version, Uri reportUrl, IClassifier classifier) {
        this.Classifier = classifier;
        this.Version = version;
        this.ReportUrl = reportUrl;
    }

    /// <summary>
    /// Classifies a message with the loaded model.
    /// </summary>
    /// <param name="channel">Channel identifier, passed through to the classifier unchanged.</param>
    /// <param name="message">Message text to classify.</param>
    /// <returns>
    /// The predicted category, or <c>MessageCategory.Normal</c> (after logging a warning) when the
    /// classifier's label is not a known category.
    /// </returns>
    public MessageCategory ClassifyMessage(ushort channel, string message) {
        var prediction = this.Classifier.Classify(channel, message);
        var category = MessageCategoryExt.FromString(prediction);

        if (category != null) {
            return (MessageCategory) category;
        }

        Plugin.Log.Warning($"Unknown message category: {prediction}");
        return MessageCategory.Normal;
    }

    /// <summary>
    /// Loads the ML filter. The remote manifest is downloaded first. The cached model is reused when
    /// the remote manifest is unavailable or its version matches the cached manifest; otherwise the
    /// model is downloaded. The model is then checked against the manifest hash before it is used.
    /// </summary>
    /// <param name="plugin">Plugin whose <c>MlStatus</c> is updated and whose config directory holds the cache.</param>
    /// <param name="showWindow">Not used by this method.</param>
    /// <returns>The loaded filter, or <c>null</c> if no model with a matching checksum could be obtained.</returns>
    public static async Task<MlFilter?> Load(Plugin plugin, bool showWindow) {
        plugin.MlStatus = MlFilterStatus.DownloadingManifest;

        // download and parse the remote manifest
        var remote = await DownloadManifest();
        if (remote == null) {
            Plugin.Log.Warning("Could not download manifest. Will attempt to fall back on cached version.");
        }

        // the manifest describing the model that ends up being used
        var manifest = remote?.manifest;
        // YAML text of the remote manifest; empty when falling back on the cached manifest,
        // in which case there is nothing new to download or to write back to the cache
        var source = remote?.source ?? string.Empty;
        // model zip file data
        byte[]? data = null;

        // reuse the cached model if we couldn't get the remote manifest OR the cached version is the same as remote version
        var localManifest = LoadCachedManifest(plugin);
        if (localManifest != null && (manifest == null || localManifest.Version == manifest.Version)) {
            try {
                data = await File.ReadAllBytesAsync(CachedFilePath(plugin, ModelName));
                manifest ??= localManifest;
            } catch (Exception e) when (e is IOException or UnauthorizedAccessException) {
                // no usable cached model; it is downloaded below when a remote manifest is available
            }
        }

        // only a remote manifest tells us where to download the model from
        if (manifest != null && !string.IsNullOrEmpty(source)) {
            plugin.MlStatus = MlFilterStatus.DownloadingModel;
            data ??= await DownloadModel(manifest.ModelUrl);
        }

        // give up if we couldn't get any data at this point
        if (manifest == null || data == null) {
            plugin.MlStatus = MlFilterStatus.Uninitialised;
            return null;
        }

        // never initialise (or cache) a model whose hash does not match the manifest
        data = await VerifyChecksum(data, manifest);
        if (data == null) {
            plugin.MlStatus = MlFilterStatus.Uninitialised;
            return null;
        }

        plugin.MlStatus = MlFilterStatus.Initialising;

        if (!string.IsNullOrEmpty(source)) {
            // write the model first and only record the new manifest once the model is on disk,
            // so the cached manifest never describes a model that failed to save
            if (await UpdateCachedFile(plugin, ModelName, data)) {
                await UpdateCachedFile(plugin, ManifestName, Encoding.UTF8.GetBytes(source));
            }
        }

        // initialise the classifier, releasing it if initialisation fails
        var classifier = new Classifier();
        try {
            classifier.Initialise(data);
        } catch {
            ((IClassifier) classifier).Dispose();
            throw;
        }

        return new MlFilter(
            manifest.Version,
            manifest.ReportUrl,
            classifier
        );
    }

    /// <summary>
    /// Compares the SHA-256 hash of <paramref name="data"/> with the manifest's hash. While they differ,
    /// redownloads the model from the manifest's URL, up to <see cref="MaxChecksumRetries"/> times.
    /// </summary>
    /// <param name="data">Model data obtained from the cache or a download.</param>
    /// <param name="manifest">Manifest providing the expected hash and the model URL.</param>
    /// <returns>Model data whose hash matches the manifest, or <c>null</c> if none could be obtained.</returns>
    private static async Task<byte[]?> VerifyChecksum(byte[] data, Manifest manifest) {
        var correctHash = manifest.Hash();
        using var sha = SHA256.Create();

        byte[]? candidate = data;
        for (var retries = 0; ; retries++) {
            if (candidate != null && sha.ComputeHash(candidate).SequenceEqual(correctHash)) {
                return candidate;
            }

            if (retries >= MaxChecksumRetries) {
                break;
            }

            Plugin.Log.Warning($"Model checksum did not match. Redownloading (attempt {retries + 1}/{MaxChecksumRetries})");
            candidate = await DownloadModel(manifest.ModelUrl);
        }

        // a null candidate means the last download failed, and DownloadModel already recorded why
        if (candidate != null) {
            Plugin.Log.Error("Model checksum still did not match after redownloading. Refusing to load it.");
            LastError = "Model checksum did not match the manifest.";
        }

        return null;
    }

    /// <summary>Downloads the model archive, returning <c>null</c> (and setting <see cref="LastError"/>) on a web error.</summary>
    private static async Task<byte[]?> DownloadModel(Uri url) {
        try {
            using var client = new WebClient();
            var data = await client.DownloadDataTaskAsync(url);
            return data;
        } catch (WebException e) {
            Plugin.Log.Error("Could not download newest model.");
            Plugin.Log.Error(e.ToString());
            LastError = e.Message;
            return null;
        }
    }

    /// <summary>Returns the path of <paramref name="name"/> in the plugin's config directory, creating the directory if needed.</summary>
    private static string CachedFilePath(Plugin plugin, string name) {
        var pluginFolder = plugin.Interface.ConfigDirectory.ToString();
        Directory.CreateDirectory(pluginFolder);
        return Path.Combine(pluginFolder, name);
    }

    /// <summary>
    /// Overwrites the cached file <paramref name="name"/> with <paramref name="data"/>. Caching is best-effort:
    /// IO and permission errors are logged rather than thrown.
    /// </summary>
    /// <returns><c>true</c> if the file was written, <c>false</c> if writing failed.</returns>
    private static async Task<bool> UpdateCachedFile(Plugin plugin, string name, byte[] data) {
        try {
            var cachePath = CachedFilePath(plugin, name);
            await using var file = File.Create(cachePath);
            await file.WriteAsync(data, 0, data.Length);
            await file.FlushAsync();
            return true;
        } catch (Exception e) when (e is IOException or UnauthorizedAccessException) {
            Plugin.Log.Warning($"Could not update cached file {name}: {e.Message}");
            return false;
        }
    }

    /// <summary>
    /// Downloads and parses the remote manifest.
    /// </summary>
    /// <returns>
    /// The parsed (non-null) manifest together with its YAML source, or <c>null</c> if the download or
    /// parse failed or the document was empty (in which case <see cref="LastError"/> is set).
    /// </returns>
    private static async Task<(Manifest manifest, string source)?> DownloadManifest() {
        try {
            using var client = new WebClient();
            var data = await client.DownloadStringTaskAsync(Url);
            Manifest? manifest = LoadYaml<Manifest>(data);

            // an empty YAML document deserialises to null; treat it like a failed download
            if (manifest == null) {
                Plugin.Log.Error("Could not download newest model manifest.");
                Plugin.Log.Error("The downloaded manifest was empty.");
                LastError = "The downloaded manifest was empty.";
                return null;
            }

            LastError = null;
            return (manifest, data);
        } catch (Exception e) when (e is WebException or YamlException) {
            Plugin.Log.Error("Could not download newest model manifest.");
            Plugin.Log.Error(e.ToString());
            LastError = e.Message;
            return null;
        }
    }

    /// <summary>Reads and parses the cached manifest, returning <c>null</c> if it is missing, unreadable or invalid.</summary>
    private static Manifest? LoadCachedManifest(Plugin plugin) {
        var manifestPath = CachedFilePath(plugin, ManifestName);
        if (!File.Exists(manifestPath)) {
            return null;
        }

        string data;
        try {
            data = File.ReadAllText(manifestPath);
        } catch (Exception e) when (e is IOException or UnauthorizedAccessException) {
            return null;
        }

        try {
            return LoadYaml<Manifest>(data);
        } catch (YamlException) {
            return null;
        }
    }

    /// <summary>Deserialises YAML with underscored property names, ignoring properties <typeparamref name="T"/> does not declare.</summary>
    private static T LoadYaml<T>(string data) {
        var de = new DeserializerBuilder()
            .WithNamingConvention(UnderscoredNamingConvention.Instance)
            .IgnoreUnmatchedProperties()
            .Build();
        return de.Deserialize<T>(data);
    }

    public void Dispose() {
        this.Classifier.Dispose();
    }
}
// 2021-02-25 01:58:47 +00:00
/// <summary>
/// Loading states of the ML filter. MlFilter.Load reports DownloadingManifest, DownloadingModel,
/// Initialising and Uninitialised; the remaining values are presumably set by other code (verify).
/// Numeric values are spelled out and match the implicit 0-based numbering.
/// </summary>
public enum MlFilterStatus {
    Uninitialised = 0,
    Preparing = 1,
    DownloadingManifest = 2,
    DownloadingModel = 3,
    Initialising = 4,
    Initialised = 5
}
/// <summary>Helpers for presenting <see cref="MlFilterStatus"/> values to the user.</summary>
public static class MlFilterStatusExt {
    /// <summary>
    /// Maps a status to its localised description from the Language resources. Any value without
    /// a mapping falls back to the enum member's name.
    /// </summary>
    public static string Description(this MlFilterStatus status) => status switch {
        MlFilterStatus.Uninitialised => Language.ModelStatusUninitialised,
        MlFilterStatus.Preparing => Language.ModelStatusPreparing,
        MlFilterStatus.DownloadingManifest => Language.ModelStatusDownloadingManifest,
        MlFilterStatus.DownloadingModel => Language.ModelStatusDownloadingModel,
        MlFilterStatus.Initialising => Language.ModelStatusInitialising,
        MlFilterStatus.Initialised => Language.ModelStatusInitialised,
        _ => status.ToString(),
    };
}
// 2020-12-21 02:49:10 +00:00
}