refactor(training): compute properties in pipeline

Hopefully the data structure no longer needs to be updated when new computed properties are added. This should also reduce duplication and make it easier to make bigger changes to the model without needing to update the plugin.
This commit is contained in:
Anna 2020-12-28 21:01:35 -05:00
parent 83cb794dfc
commit e24c54cfbc
Signed by: anna
GPG Key ID: 0B391D8F06FCD9E0

View File

@ -53,16 +53,16 @@ namespace NoSoliciting.Trainer {
weights[category] = w; weights[category] = w;
} }
// apply class weights
foreach (var record in records) {
record.Weight = weights[record.Category!];
}
var df = ctx.Data.LoadFromEnumerable(records); var df = ctx.Data.LoadFromEnumerable(records);
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1); var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
void Compute(Data data, Data.Computed computed) {
data.Compute(computed, weights);
}
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category)) var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false)) .Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal")) .Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens", // .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
@ -166,9 +166,21 @@ namespace NoSoliciting.Trainer {
[SuppressMessage("ReSharper", "UnusedMember.Global")] [SuppressMessage("ReSharper", "UnusedMember.Global")]
internal class Data { internal class Data {
[LoadColumn(0), Index(0)]
public string? Category { get; set; }
[LoadColumn(1), Index(1)]
public uint Channel { get; set; }
[LoadColumn(2), Index(2)]
public string Message { get; set; }
#region computed
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly string[] PlotWords = { private static readonly string[] PlotWords = {
"plot", "plot",
"apartment", "apartment",
@ -188,31 +200,36 @@ namespace NoSoliciting.Trainer {
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled); private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
[LoadColumn(0), Index(0)] internal class Computed {
public string? Category { get; set; } public float Weight { get; set; }
[LoadColumn(1), Index(1)] public bool PartyFinder { get; set; }
public uint Channel { get; set; }
[LoadColumn(2), Index(2)] public bool Shout { get; set; }
public string Message { get; set; }
[Ignore] public bool ContainsWard { get; set; }
public float Weight { get; set; } = 1;
public bool PartyFinder => this.Channel == 0; public bool ContainsPlot { get; set; }
public bool Shout => this.Channel == 11 || this.Channel == 30; public bool ContainsHousingNumbers { get; set; }
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); public bool ContainsTradeWords { get; set; }
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); public bool ContainsSketchUrl { get; set; }
}
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message); internal void Compute(Computed output, Dictionary<string, float> weights) {
output.Weight = this.Category == null ? 1 : weights[this.Category];
output.PartyFinder = this.Channel == 0;
output.Shout = this.Channel == 11 || this.Channel == 30;
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
}
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); #endregion
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
} }
internal class Prediction { internal class Prediction {