From 1b8f7806f5f161b19089b0d1e8a6ddaf1da27aa2 Mon Sep 17 00:00:00 2001 From: Anna Date: Mon, 28 Dec 2020 21:48:31 -0500 Subject: [PATCH] refactor: put computation in interface This basically undoes the benefits of the previous commit. May end up being reverted. --- .../CursedWorkaround.cs | 7 +- NoSoliciting.CursedWorkaround/Models.cs | 72 ---------- NoSoliciting.Interface/Data.cs | 123 ++++++++++++++++++ .../NoSoliciting.Interface.csproj | 6 + .../NoSoliciting.Trainer.csproj | 4 + NoSoliciting.Trainer/Program.cs | 88 +------------ NoSoliciting.Trainer/data.csv | 2 +- 7 files changed, 143 insertions(+), 159 deletions(-) delete mode 100644 NoSoliciting.CursedWorkaround/Models.cs create mode 100644 NoSoliciting.Interface/Data.cs diff --git a/NoSoliciting.CursedWorkaround/CursedWorkaround.cs b/NoSoliciting.CursedWorkaround/CursedWorkaround.cs index 87e2d60..833de97 100644 --- a/NoSoliciting.CursedWorkaround/CursedWorkaround.cs +++ b/NoSoliciting.CursedWorkaround/CursedWorkaround.cs @@ -9,7 +9,7 @@ namespace NoSoliciting.CursedWorkaround { private MLContext Context { get; set; } = null!; private ITransformer Model { get; set; } = null!; private DataViewSchema Schema { get; set; } = null!; - private PredictionEngine PredictionEngine { get; set; } = null!; + private PredictionEngine PredictionEngine { get; set; } = null!; public override object? InitializeLifetimeService() { return null; @@ -17,15 +17,16 @@ namespace NoSoliciting.CursedWorkaround { public void Initialise(byte[] data) { this.Context = new MLContext(); + this.Context.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly); using var stream = new MemoryStream(data); var model = this.Context.Model.Load(stream, out var schema); this.Model = model; this.Schema = schema; - this.PredictionEngine = this.Context.Model.CreatePredictionEngine(this.Model, this.Schema); + this.PredictionEngine = this.Context.Model.CreatePredictionEngine(this.Model, this.Schema); } public string Classify(ushort channel, string message) { - return this.PredictionEngine.Predict(new MessageData(channel, message)).Category; + return this.PredictionEngine.Predict(new Data(channel, message)).Category; } public void Dispose() { diff --git a/NoSoliciting.CursedWorkaround/Models.cs b/NoSoliciting.CursedWorkaround/Models.cs deleted file mode 100644 index 63bdbbc..0000000 --- a/NoSoliciting.CursedWorkaround/Models.cs +++ /dev/null @@ -1,72 +0,0 @@ -using System.Globalization; -using System.Linq; -using System.Text.RegularExpressions; -using Microsoft.ML.Data; - -namespace NoSoliciting.CursedWorkaround { - public class MessageData { - private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); - - private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); - - private static readonly string[] PlotWords = { - "plot", - "apartment", - "apt", - }; - - private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled); - - private static readonly string[] TradeWords = { - "B> ", - "S> ", - "buy", - "sell", - "WTB", - "WTS", - }; - - private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled); - - public string? Category { get; } - - public uint Channel { get; } - - public string Message { get; } - - public float Weight { get; } = 1; - - public bool PartyFinder => this.Channel == 0; - - public bool Shout => this.Channel == 11 || this.Channel == 30; - - public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); - - public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); - - public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message); - - public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); - - public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message); - - public MessageData(uint channel, string message) { - this.Channel = channel; - this.Message = message; - } - } - - public class MessagePrediction { - [ColumnName("PredictedLabel")] - public string Category { get; set; } = null!; - - [ColumnName("Score")] - public float[] Probabilities { get; set; } = null!; - } - - public static class RmtExtensions { - public static bool ContainsIgnoreCase(this string haystack, string needle) { - return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0; - } - } -} diff --git a/NoSoliciting.Interface/Data.cs b/NoSoliciting.Interface/Data.cs new file mode 100644 index 0000000..967aff2 --- /dev/null +++ b/NoSoliciting.Interface/Data.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; + +namespace NoSoliciting.Interface { + [SuppressMessage("ReSharper", "UnusedMember.Global")] + public class Data { + [LoadColumn(0)] + public string? Category { get; set; } + + [LoadColumn(1)] + public uint Channel { get; set; } + + [LoadColumn(2)] + public string Message { get; set; } = null!; + + public Data() { + } + + public Data(ushort channel, string message) { + this.Channel = channel; + this.Message = message; + } + + #region computed + + [CustomMappingFactoryAttribute("Compute")] + public class ComputeContext : CustomMappingFactory { + private Dictionary Weights { get; } + + public ComputeContext() { + this.Weights = new Dictionary(); + } + + public ComputeContext(Dictionary weights) { + this.Weights = weights; + } + + public void Compute(Data data, Computed computed) { + data.Compute(computed, this.Weights); + } + + public override Action GetMapping() { + return this.Compute; + } + } + + private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); + + private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); + + private static readonly string[] PlotWords = { + "plot", + "apartment", + "apt", + }; + + private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled); + + private static readonly string[] TradeWords = { + "B> ", + "S> ", + "buy", + "sell", + "WTB", + "WTS", + }; + + private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled); + + public class Computed { + public float Weight { get; set; } = 1; + + public bool PartyFinder { get; set; } + + public bool Shout { get; set; } + + public bool ContainsWard { get; set; } + + public bool ContainsPlot { get; set; } + + public bool ContainsHousingNumbers { get; set; } + + public bool ContainsTradeWords { get; set; } + + public bool ContainsSketchUrl { get; set; } + } + + private void Compute(Computed output, IReadOnlyDictionary weights) { + if (this.Category != null && weights.TryGetValue(this.Category, out var weight)) { + output.Weight = weight; + } + output.PartyFinder = this.Channel == 0; + output.Shout = this.Channel == 11 || this.Channel == 30; + output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); + output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); + output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message); + output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); + output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message); + } + + #endregion + } + + public class Prediction { + [ColumnName("PredictedLabel")] + public string Category { get; set; } + + [ColumnName("Score")] + public float[] Probabilities { get; set; } + } + + internal static class Ext { + public static bool ContainsIgnoreCase(this string haystack, string needle) { + return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0; + } + } +} diff --git a/NoSoliciting.Interface/NoSoliciting.Interface.csproj b/NoSoliciting.Interface/NoSoliciting.Interface.csproj index 4ef46b5..77eb784 100644 --- a/NoSoliciting.Interface/NoSoliciting.Interface.csproj +++ b/NoSoliciting.Interface/NoSoliciting.Interface.csproj @@ -2,6 +2,12 @@ net48 + 8 + enable + + + + diff --git a/NoSoliciting.Trainer/NoSoliciting.Trainer.csproj b/NoSoliciting.Trainer/NoSoliciting.Trainer.csproj index 279f2f8..5387569 100644 --- a/NoSoliciting.Trainer/NoSoliciting.Trainer.csproj +++ b/NoSoliciting.Trainer/NoSoliciting.Trainer.csproj @@ -12,4 +12,8 @@ + + + + diff --git a/NoSoliciting.Trainer/Program.cs b/NoSoliciting.Trainer/Program.cs index af8631e..caf5e46 100644 --- a/NoSoliciting.Trainer/Program.cs +++ b/NoSoliciting.Trainer/Program.cs @@ -1,17 +1,15 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; using System.Linq; -using System.Text.RegularExpressions; using ConsoleTables; using CsvHelper; using CsvHelper.Configuration; -using CsvHelper.Configuration.Attributes; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Transforms.Text; +using NoSoliciting.Interface; namespace NoSoliciting.Trainer { internal static class Program { @@ -57,12 +55,12 @@ namespace NoSoliciting.Trainer { var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1); - void Compute(Data data, Data.Computed computed) { - data.Compute(computed, weights); - } + var compute = new Data.ComputeContext(weights); + + ctx.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly); var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category)) - .Append(ctx.Transforms.CustomMapping((Action) Compute, "Compute")) + .Append(ctx.Transforms.CustomMapping((Action) compute.Compute, "Compute")) .Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false)) .Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal")) // .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens", @@ -164,82 +162,6 @@ namespace NoSoliciting.Trainer { } } - [SuppressMessage("ReSharper", "UnusedMember.Global")] - internal class Data { - [LoadColumn(0), Index(0)] - public string? Category { get; set; } - - [LoadColumn(1), Index(1)] - public uint Channel { get; set; } - - [LoadColumn(2), Index(2)] - public string Message { get; set; } - - #region computed - - private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); - - private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); - - private static readonly string[] PlotWords = { - "plot", - "apartment", - "apt", - }; - - private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled); - - private static readonly string[] TradeWords = { - "B> ", - "S> ", - "buy", - "sell", - "WTB", - "WTS", - }; - - private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled); - - internal class Computed { - public float Weight { get; set; } - - public bool PartyFinder { get; set; } - - public bool Shout { get; set; } - - public bool ContainsWard { get; set; } - - public bool ContainsPlot { get; set; } - - public bool ContainsHousingNumbers { get; set; } - - public bool ContainsTradeWords { get; set; } - - public bool ContainsSketchUrl { get; set; } - } - - internal void Compute(Computed output, Dictionary weights) { - output.Weight = this.Category == null ? 1 : weights[this.Category]; - output.PartyFinder = this.Channel == 0; - output.Shout = this.Channel == 11 || this.Channel == 30; - output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); - output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); - output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message); - output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); - output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message); - } - - #endregion - } - - internal class Prediction { - [ColumnName("PredictedLabel")] - public string Category { get; set; } - - [ColumnName("Score")] - public float[] Probabilities { get; set; } - } - internal static class Ext { public static bool ContainsIgnoreCase(this string haystack, string needle) { return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0; diff --git a/NoSoliciting.Trainer/data.csv b/NoSoliciting.Trainer/data.csv index d359829..76caba2 100644 --- a/NoSoliciting.Trainer/data.csv +++ b/NoSoliciting.Trainer/data.csv @@ -1,4 +1,4 @@ -category,channel,message +Category,Channel,Message FC,0,[FC recruitment] Small/New FC looking for more members to join us. New and experienced welcomed. Send tell if interested! FC,0,"<> is recruiting! If you're looking for a med-sized social active FC, then we would love to meet you! /tell for more info?" FC,0, is recruiting! We're a slowly growing fc that would appreciate some new faces. /tell for more info or an inv <3