NoSoliciting/NoSoliciting/Ml/MlFilter.cs

233 lines
8.5 KiB
C#

using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Dalamud.Logging;
using NoSoliciting.Interface;
using NoSoliciting.Resources;
using YamlDotNet.Core;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions;
namespace NoSoliciting.Ml {
public class MlFilter : IDisposable {
public static string? LastError { get; private set; }
private const string ManifestName = "manifest.yaml";
private const string ModelName = "model.zip";
#if DEBUG
private const string Url = "http://localhost:8000/manifest.yaml";
#else
private const string Url = "https://annaclemens.io/assets/nosol/ml/manifest.yaml";
#endif
public uint Version { get; }
public Uri ReportUrl { get; }
private IClassifier Classifier { get; }
private MlFilter(uint version, Uri reportUrl, IClassifier classifier) {
this.Classifier = classifier;
this.Version = version;
this.ReportUrl = reportUrl;
}
public MessageCategory ClassifyMessage(ushort channel, string message) {
var prediction = this.Classifier.Classify(channel, message);
var category = MessageCategoryExt.FromString(prediction);
if (category != null) {
return (MessageCategory) category;
}
PluginLog.LogWarning($"Unknown message category: {prediction}");
return MessageCategory.Normal;
}
public static async Task<MlFilter?> Load(Plugin plugin, bool showWindow) {
plugin.MlStatus = MlFilterStatus.DownloadingManifest;
// download and parse the remote manifest
var manifest = await DownloadManifest();
if (manifest == null) {
PluginLog.LogWarning("Could not download manifest. Will attempt to fall back on cached version.");
}
// model zip file data
byte[]? data = null;
// load the cached manifest
var localManifest = LoadCachedManifest(plugin);
// if there is a cached manifest and we either couldn't download/parse the remote OR the cached version is the same as remote version
if (localManifest != null && (manifest?.Item1 == null || localManifest.Version == manifest.Value.manifest.Version)) {
try {
// try to reach the cached model
data = await File.ReadAllBytesAsync(CachedFilePath(plugin, ModelName));
// set the manifest to our local one and an empty string for the source
manifest ??= (localManifest, string.Empty);
} catch (IOException) {
// ignored
}
}
// if there is source for the manifest
if (!string.IsNullOrEmpty(manifest?.source)) {
plugin.MlStatus = MlFilterStatus.DownloadingModel;
// download the model if necessary
data ??= await DownloadModel(manifest!.Value.manifest!.ModelUrl);
}
// give up if we couldn't get any data at this point
if (data == null) {
plugin.MlStatus = MlFilterStatus.Uninitialised;
return null;
}
// validate checksum
var retries = 0;
const int maxRetries = 3;
var correctHash = manifest!.Value.manifest!.Hash();
using (var sha = SHA256.Create()) {
var hash = sha.ComputeHash(data);
while (!hash.SequenceEqual(correctHash) && retries < maxRetries) {
PluginLog.Warning($"Model checksum did not match. Redownloading (attempt {retries + 1}/{maxRetries})");
retries += 1;
data = await DownloadModel(manifest!.Value.manifest!.ModelUrl);
if (data != null) {
hash = sha.ComputeHash(data);
}
}
}
// give up if we couldn't get any data at this point
if (data == null) {
plugin.MlStatus = MlFilterStatus.Uninitialised;
return null;
}
plugin.MlStatus = MlFilterStatus.Initialising;
// if there is source for the manifest
if (!string.IsNullOrEmpty(manifest!.Value.source)) {
// update the cached files
UpdateCachedFile(plugin, ModelName, data);
UpdateCachedFile(plugin, ManifestName, Encoding.UTF8.GetBytes(manifest.Value.source));
}
// initialise the classifier
var classifier = new Classifier();
classifier.Initialise(data);
return new MlFilter(
manifest.Value.manifest!.Version,
manifest.Value.manifest!.ReportUrl,
classifier
);
}
private static async Task<byte[]?> DownloadModel(Uri url) {
try {
using var client = new WebClient();
var data = await client.DownloadDataTaskAsync(url);
return data;
} catch (WebException e) {
PluginLog.LogError("Could not download newest model.");
PluginLog.LogError(e.ToString());
LastError = e.Message;
return null;
}
}
private static string CachedFilePath(Plugin plugin, string name) {
var pluginFolder = plugin.Interface.ConfigDirectory.ToString();
Directory.CreateDirectory(pluginFolder);
return Path.Combine(pluginFolder, name);
}
private static async void UpdateCachedFile(Plugin plugin, string name, byte[] data) {
var cachePath = CachedFilePath(plugin, name);
var file = File.Create(cachePath);
await file.WriteAsync(data, 0, data.Length);
await file.FlushAsync();
await file.DisposeAsync();
}
private static async Task<(Manifest manifest, string source)?> DownloadManifest() {
try {
using var client = new WebClient();
var data = await client.DownloadStringTaskAsync(Url);
LastError = null;
return (LoadYaml<Manifest>(data), data);
} catch (Exception e) when (e is WebException or YamlException) {
PluginLog.LogError("Could not download newest model manifest.");
PluginLog.LogError(e.ToString());
LastError = e.Message;
return null;
}
}
private static Manifest? LoadCachedManifest(Plugin plugin) {
var manifestPath = CachedFilePath(plugin, ManifestName);
if (!File.Exists(manifestPath)) {
return null;
}
string data;
try {
data = File.ReadAllText(manifestPath);
} catch (IOException) {
return null;
}
try {
return LoadYaml<Manifest>(data);
} catch (YamlException) {
return null;
}
}
private static T LoadYaml<T>(string data) {
var de = new DeserializerBuilder()
.WithNamingConvention(UnderscoredNamingConvention.Instance)
.IgnoreUnmatchedProperties()
.Build();
return de.Deserialize<T>(data);
}
public void Dispose() {
this.Classifier.Dispose();
}
}
public enum MlFilterStatus {
Uninitialised,
Preparing,
DownloadingManifest,
DownloadingModel,
Initialising,
Initialised,
}
public static class MlFilterStatusExt {
public static string Description(this MlFilterStatus status) {
return status switch {
MlFilterStatus.Uninitialised => Language.ModelStatusUninitialised,
MlFilterStatus.Preparing => Language.ModelStatusPreparing,
MlFilterStatus.DownloadingManifest => Language.ModelStatusDownloadingManifest,
MlFilterStatus.DownloadingModel => Language.ModelStatusDownloadingModel,
MlFilterStatus.Initialising => Language.ModelStatusInitialising,
MlFilterStatus.Initialised => Language.ModelStatusInitialised,
_ => status.ToString(),
};
}
}
}