refactor(training): compute properties in pipeline
Hopefully no longer required the data structure to be updated when new computed properties are added. This should also reduce duplication and make it easier to make bigger changes to the model without needing to update the plugin.
This commit is contained in:
parent
83cb794dfc
commit
e24c54cfbc
|
@ -53,16 +53,16 @@ namespace NoSoliciting.Trainer {
|
|||
weights[category] = w;
|
||||
}
|
||||
|
||||
// apply class weights
|
||||
foreach (var record in records) {
|
||||
record.Weight = weights[record.Category!];
|
||||
}
|
||||
|
||||
var df = ctx.Data.LoadFromEnumerable(records);
|
||||
|
||||
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
|
||||
|
||||
void Compute(Data data, Data.Computed computed) {
|
||||
data.Compute(computed, weights);
|
||||
}
|
||||
|
||||
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
|
||||
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
|
||||
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
|
||||
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
|
||||
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
|
||||
|
@ -166,9 +166,21 @@ namespace NoSoliciting.Trainer {
|
|||
|
||||
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
||||
internal class Data {
|
||||
[LoadColumn(0), Index(0)]
|
||||
public string? Category { get; set; }
|
||||
|
||||
[LoadColumn(1), Index(1)]
|
||||
public uint Channel { get; set; }
|
||||
|
||||
[LoadColumn(2), Index(2)]
|
||||
public string Message { get; set; }
|
||||
|
||||
#region computed
|
||||
|
||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly string[] PlotWords = {
|
||||
"plot",
|
||||
"apartment",
|
||||
|
@ -188,31 +200,36 @@ namespace NoSoliciting.Trainer {
|
|||
|
||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||
|
||||
[LoadColumn(0), Index(0)]
|
||||
public string? Category { get; set; }
|
||||
internal class Computed {
|
||||
public float Weight { get; set; }
|
||||
|
||||
[LoadColumn(1), Index(1)]
|
||||
public uint Channel { get; set; }
|
||||
public bool PartyFinder { get; set; }
|
||||
|
||||
[LoadColumn(2), Index(2)]
|
||||
public string Message { get; set; }
|
||||
public bool Shout { get; set; }
|
||||
|
||||
[Ignore]
|
||||
public float Weight { get; set; } = 1;
|
||||
public bool ContainsWard { get; set; }
|
||||
|
||||
public bool PartyFinder => this.Channel == 0;
|
||||
public bool ContainsPlot { get; set; }
|
||||
|
||||
public bool Shout => this.Channel == 11 || this.Channel == 30;
|
||||
public bool ContainsHousingNumbers { get; set; }
|
||||
|
||||
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
public bool ContainsTradeWords { get; set; }
|
||||
|
||||
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
public bool ContainsSketchUrl { get; set; }
|
||||
}
|
||||
|
||||
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
|
||||
internal void Compute(Computed output, Dictionary<string, float> weights) {
|
||||
output.Weight = this.Category == null ? 1 : weights[this.Category];
|
||||
output.PartyFinder = this.Channel == 0;
|
||||
output.Shout = this.Channel == 11 || this.Channel == 30;
|
||||
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
|
||||
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
|
||||
}
|
||||
|
||||
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
|
||||
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
|
||||
#endregion
|
||||
}
|
||||
|
||||
internal class Prediction {
|
||||
|
|
Loading…
Reference in New Issue
Block a user