refactor: put computation in interface

This basically undoes the benefits of the previous commit. May end up being reverted.
This commit is contained in:
Anna 2020-12-28 21:48:31 -05:00
parent effe41a345
commit 1b8f7806f5
7 changed files with 143 additions and 159 deletions

View File

@ -9,7 +9,7 @@ namespace NoSoliciting.CursedWorkaround {
private MLContext Context { get; set; } = null!;
private ITransformer Model { get; set; } = null!;
private DataViewSchema Schema { get; set; } = null!;
private PredictionEngine<MessageData, MessagePrediction> PredictionEngine { get; set; } = null!;
private PredictionEngine<Data, Prediction> PredictionEngine { get; set; } = null!;
public override object? InitializeLifetimeService() {
return null;
@ -17,15 +17,16 @@ namespace NoSoliciting.CursedWorkaround {
public void Initialise(byte[] data) {
this.Context = new MLContext();
this.Context.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly);
using var stream = new MemoryStream(data);
var model = this.Context.Model.Load(stream, out var schema);
this.Model = model;
this.Schema = schema;
this.PredictionEngine = this.Context.Model.CreatePredictionEngine<MessageData, MessagePrediction>(this.Model, this.Schema);
this.PredictionEngine = this.Context.Model.CreatePredictionEngine<Data, Prediction>(this.Model, this.Schema);
}
public string Classify(ushort channel, string message) {
return this.PredictionEngine.Predict(new MessageData(channel, message)).Category;
return this.PredictionEngine.Predict(new Data(channel, message)).Category;
}
public void Dispose() {

View File

@ -1,72 +0,0 @@
using System.Globalization;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.ML.Data;
namespace NoSoliciting.CursedWorkaround {
public class MessageData {
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly string[] PlotWords = {
"plot",
"apartment",
"apt",
};
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
private static readonly string[] TradeWords = {
"B> ",
"S> ",
"buy",
"sell",
"WTB",
"WTS",
};
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
public string? Category { get; }
public uint Channel { get; }
public string Message { get; }
public float Weight { get; } = 1;
public bool PartyFinder => this.Channel == 0;
public bool Shout => this.Channel == 11 || this.Channel == 30;
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
public MessageData(uint channel, string message) {
this.Channel = channel;
this.Message = message;
}
}
public class MessagePrediction {
[ColumnName("PredictedLabel")]
public string Category { get; set; } = null!;
[ColumnName("Score")]
public float[] Probabilities { get; set; } = null!;
}
public static class RmtExtensions {
public static bool ContainsIgnoreCase(this string haystack, string needle) {
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
}
}
}

View File

@ -0,0 +1,123 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
namespace NoSoliciting.Interface {
[SuppressMessage("ReSharper", "UnusedMember.Global")]
public class Data {
[LoadColumn(0)]
public string? Category { get; set; }
[LoadColumn(1)]
public uint Channel { get; set; }
[LoadColumn(2)]
public string Message { get; set; } = null!;
public Data() {
}
public Data(ushort channel, string message) {
this.Channel = channel;
this.Message = message;
}
#region computed
[CustomMappingFactoryAttribute("Compute")]
public class ComputeContext : CustomMappingFactory<Data, Computed> {
private Dictionary<string, float> Weights { get; }
public ComputeContext() {
this.Weights = new Dictionary<string, float>();
}
public ComputeContext(Dictionary<string, float> weights) {
this.Weights = weights;
}
public void Compute(Data data, Computed computed) {
data.Compute(computed, this.Weights);
}
public override Action<Data, Computed> GetMapping() {
return this.Compute;
}
}
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly string[] PlotWords = {
"plot",
"apartment",
"apt",
};
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
private static readonly string[] TradeWords = {
"B> ",
"S> ",
"buy",
"sell",
"WTB",
"WTS",
};
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
public class Computed {
public float Weight { get; set; } = 1;
public bool PartyFinder { get; set; }
public bool Shout { get; set; }
public bool ContainsWard { get; set; }
public bool ContainsPlot { get; set; }
public bool ContainsHousingNumbers { get; set; }
public bool ContainsTradeWords { get; set; }
public bool ContainsSketchUrl { get; set; }
}
private void Compute(Computed output, IReadOnlyDictionary<string, float> weights) {
if (this.Category != null && weights.TryGetValue(this.Category, out var weight)) {
output.Weight = weight;
}
output.PartyFinder = this.Channel == 0;
output.Shout = this.Channel == 11 || this.Channel == 30;
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
}
#endregion
}
public class Prediction {
[ColumnName("PredictedLabel")]
public string Category { get; set; }
[ColumnName("Score")]
public float[] Probabilities { get; set; }
}
internal static class Ext {
public static bool ContainsIgnoreCase(this string haystack, string needle) {
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
}
}
}

View File

@ -2,6 +2,12 @@
<PropertyGroup>
<TargetFramework>net48</TargetFramework>
<LangVersion>8</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.ML" Version="1.5.4" />
</ItemGroup>
</Project>

View File

@ -12,4 +12,8 @@
<PackageReference Include="Microsoft.ML" Version="1.5.4" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\NoSoliciting.Interface\NoSoliciting.Interface.csproj" />
</ItemGroup>
</Project>

View File

@ -1,17 +1,15 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using ConsoleTables;
using CsvHelper;
using CsvHelper.Configuration;
using CsvHelper.Configuration.Attributes;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Text;
using NoSoliciting.Interface;
namespace NoSoliciting.Trainer {
internal static class Program {
@ -57,12 +55,12 @@ namespace NoSoliciting.Trainer {
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
void Compute(Data data, Data.Computed computed) {
data.Compute(computed, weights);
}
var compute = new Data.ComputeContext(weights);
ctx.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly);
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) compute.Compute, "Compute"))
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
@ -164,82 +162,6 @@ namespace NoSoliciting.Trainer {
}
}
[SuppressMessage("ReSharper", "UnusedMember.Global")]
internal class Data {
[LoadColumn(0), Index(0)]
public string? Category { get; set; }
[LoadColumn(1), Index(1)]
public uint Channel { get; set; }
[LoadColumn(2), Index(2)]
public string Message { get; set; }
#region computed
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly string[] PlotWords = {
"plot",
"apartment",
"apt",
};
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
private static readonly string[] TradeWords = {
"B> ",
"S> ",
"buy",
"sell",
"WTB",
"WTS",
};
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
internal class Computed {
public float Weight { get; set; }
public bool PartyFinder { get; set; }
public bool Shout { get; set; }
public bool ContainsWard { get; set; }
public bool ContainsPlot { get; set; }
public bool ContainsHousingNumbers { get; set; }
public bool ContainsTradeWords { get; set; }
public bool ContainsSketchUrl { get; set; }
}
internal void Compute(Computed output, Dictionary<string, float> weights) {
output.Weight = this.Category == null ? 1 : weights[this.Category];
output.PartyFinder = this.Channel == 0;
output.Shout = this.Channel == 11 || this.Channel == 30;
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
}
#endregion
}
internal class Prediction {
[ColumnName("PredictedLabel")]
public string Category { get; set; }
[ColumnName("Score")]
public float[] Probabilities { get; set; }
}
internal static class Ext {
public static bool ContainsIgnoreCase(this string haystack, string needle) {
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;

View File

@ -1,4 +1,4 @@
category,channel,message
Category,Channel,Message
FC,0,[FC recruitment] Small/New FC looking for more members to join us. New and experienced welcomed. Send tell if interested!
FC,0,"<<GRIND>> is recruiting! If you're looking for a med-sized social active FC, then we would love to meet you! /tell for more info?"
FC,0,<Panic> is recruiting! We're a slowly growing fc that would appreciate some new faces. /tell for more info or an inv <3

1 category Category channel Channel message Message
2 FC 0 [FC recruitment] Small/New FC looking for more members to join us. New and experienced welcomed. Send tell if interested!
3 FC 0 <<GRIND>> is recruiting! If you're looking for a med-sized social active FC, then we would love to meet you! /tell for more info?
4 FC 0 <Panic> is recruiting! We're a slowly growing fc that would appreciate some new faces. /tell for more info or an inv <3