refactor(training): compute properties in pipeline

Hopefully this no longer requires the data structure to be updated when new computed properties are added. This should also reduce duplication and make it easier to make bigger changes to the model without needing to update the plugin.
This commit is contained in:
Anna 2020-12-28 21:01:35 -05:00
parent 83cb794dfc
commit e24c54cfbc
Signed by: anna
GPG Key ID: 0B391D8F06FCD9E0

View File

@ -53,16 +53,16 @@ namespace NoSoliciting.Trainer {
weights[category] = w;
}
// apply class weights
foreach (var record in records) {
record.Weight = weights[record.Category!];
}
var df = ctx.Data.LoadFromEnumerable(records);
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
// Two-argument adapter used as the custom-mapping callback: forwards to
// Data.Compute, supplying the captured per-category weights as the extra argument.
void Compute(Data data, Data.Computed computed) => data.Compute(computed, weights);
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
@ -166,9 +166,21 @@ namespace NoSoliciting.Trainer {
[SuppressMessage("ReSharper", "UnusedMember.Global")]
internal class Data {
[LoadColumn(0), Index(0)]
public string? Category { get; set; }
[LoadColumn(1), Index(1)]
public uint Channel { get; set; }
[LoadColumn(2), Index(2)]
public string Message { get; set; }
#region computed
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
private static readonly string[] PlotWords = {
"plot",
"apartment",
@ -188,31 +200,36 @@ namespace NoSoliciting.Trainer {
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
[LoadColumn(0), Index(0)]
public string? Category { get; set; }
internal class Computed {
public float Weight { get; set; }
[LoadColumn(1), Index(1)]
public uint Channel { get; set; }
public bool PartyFinder { get; set; }
[LoadColumn(2), Index(2)]
public string Message { get; set; }
public bool Shout { get; set; }
[Ignore]
public float Weight { get; set; } = 1;
public bool ContainsWard { get; set; }
public bool PartyFinder => this.Channel == 0;
public bool ContainsPlot { get; set; }
public bool Shout => this.Channel == 11 || this.Channel == 30;
public bool ContainsHousingNumbers { get; set; }
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
public bool ContainsTradeWords { get; set; }
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
public bool ContainsSketchUrl { get; set; }
}
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
/// <summary>
/// Fills <paramref name="output"/> with the values derived from this record's
/// category, channel and message text.
/// </summary>
/// <param name="output">Object that receives every computed value.</param>
/// <param name="weights">
/// Weight lookup keyed by category; it is only consulted when
/// <see cref="Category"/> is not null.
/// </param>
internal void Compute(Computed output, Dictionary<string, float> weights) {
    // A record without a category gets a weight of 1; otherwise the weight
    // registered for its category is used.
    if (this.Category == null) {
        output.Weight = 1;
    } else {
        output.Weight = weights[this.Category];
    }

    // Channel-based flags.
    var channel = this.Channel;
    output.PartyFinder = channel == 0;
    output.Shout = channel == 11 || channel == 30;

    // Message-based flags: keyword lists are matched case-insensitively,
    // with the regexes covering the remaining patterns.
    var text = this.Message;

    var mentionsWard = text.ContainsIgnoreCase("ward");
    output.ContainsWard = mentionsWard || WardRegex.IsMatch(text);

    var mentionsPlot = PlotWords.Any(candidate => text.ContainsIgnoreCase(candidate));
    output.ContainsPlot = mentionsPlot || PlotRegex.IsMatch(text);

    output.ContainsHousingNumbers = NumbersRegex.IsMatch(text);
    output.ContainsTradeWords = TradeWords.Any(candidate => text.ContainsIgnoreCase(candidate));
    output.ContainsSketchUrl = SketchUrlRegex.IsMatch(text);
}
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
#endregion
}
internal class Prediction {