refactor: put computation in interface
This basically undoes the benefits of the previous commit. May end up being reverted.
This commit is contained in:
parent
effe41a345
commit
1b8f7806f5
|
@ -9,7 +9,7 @@ namespace NoSoliciting.CursedWorkaround {
|
|||
private MLContext Context { get; set; } = null!;
|
||||
private ITransformer Model { get; set; } = null!;
|
||||
private DataViewSchema Schema { get; set; } = null!;
|
||||
private PredictionEngine<MessageData, MessagePrediction> PredictionEngine { get; set; } = null!;
|
||||
private PredictionEngine<Data, Prediction> PredictionEngine { get; set; } = null!;
|
||||
|
||||
public override object? InitializeLifetimeService() {
|
||||
return null;
|
||||
|
@ -17,15 +17,16 @@ namespace NoSoliciting.CursedWorkaround {
|
|||
|
||||
public void Initialise(byte[] data) {
|
||||
this.Context = new MLContext();
|
||||
this.Context.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly);
|
||||
using var stream = new MemoryStream(data);
|
||||
var model = this.Context.Model.Load(stream, out var schema);
|
||||
this.Model = model;
|
||||
this.Schema = schema;
|
||||
this.PredictionEngine = this.Context.Model.CreatePredictionEngine<MessageData, MessagePrediction>(this.Model, this.Schema);
|
||||
this.PredictionEngine = this.Context.Model.CreatePredictionEngine<Data, Prediction>(this.Model, this.Schema);
|
||||
}
|
||||
|
||||
public string Classify(ushort channel, string message) {
|
||||
return this.PredictionEngine.Predict(new MessageData(channel, message)).Category;
|
||||
return this.PredictionEngine.Predict(new Data(channel, message)).Category;
|
||||
}
|
||||
|
||||
public void Dispose() {
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
using System.Globalization;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using Microsoft.ML.Data;
|
||||
|
||||
namespace NoSoliciting.CursedWorkaround {
|
||||
public class MessageData {
|
||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly string[] PlotWords = {
|
||||
"plot",
|
||||
"apartment",
|
||||
"apt",
|
||||
};
|
||||
|
||||
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
|
||||
|
||||
private static readonly string[] TradeWords = {
|
||||
"B> ",
|
||||
"S> ",
|
||||
"buy",
|
||||
"sell",
|
||||
"WTB",
|
||||
"WTS",
|
||||
};
|
||||
|
||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||
|
||||
public string? Category { get; }
|
||||
|
||||
public uint Channel { get; }
|
||||
|
||||
public string Message { get; }
|
||||
|
||||
public float Weight { get; } = 1;
|
||||
|
||||
public bool PartyFinder => this.Channel == 0;
|
||||
|
||||
public bool Shout => this.Channel == 11 || this.Channel == 30;
|
||||
|
||||
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
|
||||
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
|
||||
|
||||
public MessageData(uint channel, string message) {
|
||||
this.Channel = channel;
|
||||
this.Message = message;
|
||||
}
|
||||
}
|
||||
|
||||
public class MessagePrediction {
|
||||
[ColumnName("PredictedLabel")]
|
||||
public string Category { get; set; } = null!;
|
||||
|
||||
[ColumnName("Score")]
|
||||
public float[] Probabilities { get; set; } = null!;
|
||||
}
|
||||
|
||||
public static class RmtExtensions {
|
||||
public static bool ContainsIgnoreCase(this string haystack, string needle) {
|
||||
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
|
||||
}
|
||||
}
|
||||
}
|
123
NoSoliciting.Interface/Data.cs
Normal file
123
NoSoliciting.Interface/Data.cs
Normal file
|
@ -0,0 +1,123 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Globalization;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.Transforms;
|
||||
|
||||
namespace NoSoliciting.Interface {
|
||||
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
||||
public class Data {
|
||||
[LoadColumn(0)]
|
||||
public string? Category { get; set; }
|
||||
|
||||
[LoadColumn(1)]
|
||||
public uint Channel { get; set; }
|
||||
|
||||
[LoadColumn(2)]
|
||||
public string Message { get; set; } = null!;
|
||||
|
||||
public Data() {
|
||||
}
|
||||
|
||||
public Data(ushort channel, string message) {
|
||||
this.Channel = channel;
|
||||
this.Message = message;
|
||||
}
|
||||
|
||||
#region computed
|
||||
|
||||
[CustomMappingFactoryAttribute("Compute")]
|
||||
public class ComputeContext : CustomMappingFactory<Data, Computed> {
|
||||
private Dictionary<string, float> Weights { get; }
|
||||
|
||||
public ComputeContext() {
|
||||
this.Weights = new Dictionary<string, float>();
|
||||
}
|
||||
|
||||
public ComputeContext(Dictionary<string, float> weights) {
|
||||
this.Weights = weights;
|
||||
}
|
||||
|
||||
public void Compute(Data data, Computed computed) {
|
||||
data.Compute(computed, this.Weights);
|
||||
}
|
||||
|
||||
public override Action<Data, Computed> GetMapping() {
|
||||
return this.Compute;
|
||||
}
|
||||
}
|
||||
|
||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly string[] PlotWords = {
|
||||
"plot",
|
||||
"apartment",
|
||||
"apt",
|
||||
};
|
||||
|
||||
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
|
||||
|
||||
private static readonly string[] TradeWords = {
|
||||
"B> ",
|
||||
"S> ",
|
||||
"buy",
|
||||
"sell",
|
||||
"WTB",
|
||||
"WTS",
|
||||
};
|
||||
|
||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||
|
||||
public class Computed {
|
||||
public float Weight { get; set; } = 1;
|
||||
|
||||
public bool PartyFinder { get; set; }
|
||||
|
||||
public bool Shout { get; set; }
|
||||
|
||||
public bool ContainsWard { get; set; }
|
||||
|
||||
public bool ContainsPlot { get; set; }
|
||||
|
||||
public bool ContainsHousingNumbers { get; set; }
|
||||
|
||||
public bool ContainsTradeWords { get; set; }
|
||||
|
||||
public bool ContainsSketchUrl { get; set; }
|
||||
}
|
||||
|
||||
private void Compute(Computed output, IReadOnlyDictionary<string, float> weights) {
|
||||
if (this.Category != null && weights.TryGetValue(this.Category, out var weight)) {
|
||||
output.Weight = weight;
|
||||
}
|
||||
output.PartyFinder = this.Channel == 0;
|
||||
output.Shout = this.Channel == 11 || this.Channel == 30;
|
||||
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
|
||||
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
public class Prediction {
|
||||
[ColumnName("PredictedLabel")]
|
||||
public string Category { get; set; }
|
||||
|
||||
[ColumnName("Score")]
|
||||
public float[] Probabilities { get; set; }
|
||||
}
|
||||
|
||||
internal static class Ext {
|
||||
public static bool ContainsIgnoreCase(this string haystack, string needle) {
|
||||
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,6 +2,12 @@
|
|||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net48</TargetFramework>
|
||||
<LangVersion>8</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.ML" Version="1.5.4" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -12,4 +12,8 @@
|
|||
<PackageReference Include="Microsoft.ML" Version="1.5.4" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\NoSoliciting.Interface\NoSoliciting.Interface.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using ConsoleTables;
|
||||
using CsvHelper;
|
||||
using CsvHelper.Configuration;
|
||||
using CsvHelper.Configuration.Attributes;
|
||||
using Microsoft.ML;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.Transforms.Text;
|
||||
using NoSoliciting.Interface;
|
||||
|
||||
namespace NoSoliciting.Trainer {
|
||||
internal static class Program {
|
||||
|
@ -57,12 +55,12 @@ namespace NoSoliciting.Trainer {
|
|||
|
||||
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
|
||||
|
||||
void Compute(Data data, Data.Computed computed) {
|
||||
data.Compute(computed, weights);
|
||||
}
|
||||
var compute = new Data.ComputeContext(weights);
|
||||
|
||||
ctx.ComponentCatalog.RegisterAssembly(typeof(Data).Assembly);
|
||||
|
||||
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
|
||||
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
|
||||
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) compute.Compute, "Compute"))
|
||||
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
|
||||
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
|
||||
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
|
||||
|
@ -164,82 +162,6 @@ namespace NoSoliciting.Trainer {
|
|||
}
|
||||
}
|
||||
|
||||
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
||||
internal class Data {
|
||||
[LoadColumn(0), Index(0)]
|
||||
public string? Category { get; set; }
|
||||
|
||||
[LoadColumn(1), Index(1)]
|
||||
public uint Channel { get; set; }
|
||||
|
||||
[LoadColumn(2), Index(2)]
|
||||
public string Message { get; set; }
|
||||
|
||||
#region computed
|
||||
|
||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly string[] PlotWords = {
|
||||
"plot",
|
||||
"apartment",
|
||||
"apt",
|
||||
};
|
||||
|
||||
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
|
||||
|
||||
private static readonly string[] TradeWords = {
|
||||
"B> ",
|
||||
"S> ",
|
||||
"buy",
|
||||
"sell",
|
||||
"WTB",
|
||||
"WTS",
|
||||
};
|
||||
|
||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||
|
||||
internal class Computed {
|
||||
public float Weight { get; set; }
|
||||
|
||||
public bool PartyFinder { get; set; }
|
||||
|
||||
public bool Shout { get; set; }
|
||||
|
||||
public bool ContainsWard { get; set; }
|
||||
|
||||
public bool ContainsPlot { get; set; }
|
||||
|
||||
public bool ContainsHousingNumbers { get; set; }
|
||||
|
||||
public bool ContainsTradeWords { get; set; }
|
||||
|
||||
public bool ContainsSketchUrl { get; set; }
|
||||
}
|
||||
|
||||
internal void Compute(Computed output, Dictionary<string, float> weights) {
|
||||
output.Weight = this.Category == null ? 1 : weights[this.Category];
|
||||
output.PartyFinder = this.Channel == 0;
|
||||
output.Shout = this.Channel == 11 || this.Channel == 30;
|
||||
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
|
||||
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
internal class Prediction {
|
||||
[ColumnName("PredictedLabel")]
|
||||
public string Category { get; set; }
|
||||
|
||||
[ColumnName("Score")]
|
||||
public float[] Probabilities { get; set; }
|
||||
}
|
||||
|
||||
internal static class Ext {
|
||||
public static bool ContainsIgnoreCase(this string haystack, string needle) {
|
||||
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
category,channel,message
|
||||
Category,Channel,Message
|
||||
FC,0,[FC recruitment] Small/New FC looking for more members to join us. New and experienced welcomed. Send tell if interested!
|
||||
FC,0,"<<GRIND>> is recruiting! If you're looking for a med-sized social active FC, then we would love to meet you! /tell for more info?"
|
||||
FC,0,<Panic> is recruiting! We're a slowly growing fc that would appreciate some new faces. /tell for more info or an inv <3
|
||||
|
|
|
Loading…
Reference in New Issue
Block a user