diff --git a/NoSoliciting.Trainer/Program.cs b/NoSoliciting.Trainer/Program.cs index 0311895..af8631e 100644 --- a/NoSoliciting.Trainer/Program.cs +++ b/NoSoliciting.Trainer/Program.cs @@ -53,16 +53,16 @@ namespace NoSoliciting.Trainer { weights[category] = w; } - // apply class weights - foreach (var record in records) { - record.Weight = weights[record.Category!]; - } - var df = ctx.Data.LoadFromEnumerable(records); var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1); + void Compute(Data data, Data.Computed computed) { + data.Compute(computed, weights); + } + var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category)) + .Append(ctx.Transforms.CustomMapping((Action) Compute, "Compute")) .Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false)) .Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal")) // .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens", @@ -166,9 +166,21 @@ namespace NoSoliciting.Trainer { [SuppressMessage("ReSharper", "UnusedMember.Global")] internal class Data { + [LoadColumn(0), Index(0)] + public string? Category { get; set; } + + [LoadColumn(1), Index(1)] + public uint Channel { get; set; } + + [LoadColumn(2), Index(2)] + public string Message { get; set; } + + #region computed + private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase); + private static readonly string[] PlotWords = { "plot", "apartment", @@ -188,31 +200,36 @@ namespace NoSoliciting.Trainer { private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled); - [LoadColumn(0), Index(0)] - public string? Category { get; set; } + internal class Computed { + public float Weight { get; set; } - [LoadColumn(1), Index(1)] - public uint Channel { get; set; } + public bool PartyFinder { get; set; } - [LoadColumn(2), Index(2)] - public string Message { get; set; } + public bool Shout { get; set; } - [Ignore] - public float Weight { get; set; } = 1; + public bool ContainsWard { get; set; } - public bool PartyFinder => this.Channel == 0; + public bool ContainsPlot { get; set; } - public bool Shout => this.Channel == 11 || this.Channel == 30; + public bool ContainsHousingNumbers { get; set; } - public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); + public bool ContainsTradeWords { get; set; } - public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); + public bool ContainsSketchUrl { get; set; } + } - public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message); + internal void Compute(Computed output, Dictionary weights) { + output.Weight = this.Category == null ? 1 : weights[this.Category]; + output.PartyFinder = this.Channel == 0; + output.Shout = this.Channel == 11 || this.Channel == 30; + output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message); + output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message); + output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message); + output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); + output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message); + } - public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word)); - - public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message); + #endregion } internal class Prediction {