refactor(training): compute properties in pipeline
Hopefully this no longer requires the data structure to be updated when new computed properties are added. It should also reduce duplication and make it easier to make bigger changes to the model without needing to update the plugin.
This commit is contained in:
parent
83cb794dfc
commit
e24c54cfbc
|
@ -53,16 +53,16 @@ namespace NoSoliciting.Trainer {
|
||||||
weights[category] = w;
|
weights[category] = w;
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply class weights
|
|
||||||
foreach (var record in records) {
|
|
||||||
record.Weight = weights[record.Category!];
|
|
||||||
}
|
|
||||||
|
|
||||||
var df = ctx.Data.LoadFromEnumerable(records);
|
var df = ctx.Data.LoadFromEnumerable(records);
|
||||||
|
|
||||||
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
|
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
|
||||||
|
|
||||||
|
void Compute(Data data, Data.Computed computed) {
|
||||||
|
data.Compute(computed, weights);
|
||||||
|
}
|
||||||
|
|
||||||
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
|
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
|
||||||
|
.Append(ctx.Transforms.CustomMapping((Action<Data, Data.Computed>) Compute, "Compute"))
|
||||||
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
|
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
|
||||||
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
|
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
|
||||||
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
|
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
|
||||||
|
@ -166,9 +166,21 @@ namespace NoSoliciting.Trainer {
|
||||||
|
|
||||||
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
||||||
internal class Data {
|
internal class Data {
|
||||||
|
[LoadColumn(0), Index(0)]
|
||||||
|
public string? Category { get; set; }
|
||||||
|
|
||||||
|
[LoadColumn(1), Index(1)]
|
||||||
|
public uint Channel { get; set; }
|
||||||
|
|
||||||
|
[LoadColumn(2), Index(2)]
|
||||||
|
public string Message { get; set; }
|
||||||
|
|
||||||
|
#region computed
|
||||||
|
|
||||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||||
|
|
||||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||||
|
|
||||||
private static readonly string[] PlotWords = {
|
private static readonly string[] PlotWords = {
|
||||||
"plot",
|
"plot",
|
||||||
"apartment",
|
"apartment",
|
||||||
|
@ -188,31 +200,36 @@ namespace NoSoliciting.Trainer {
|
||||||
|
|
||||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||||
|
|
||||||
[LoadColumn(0), Index(0)]
|
internal class Computed {
|
||||||
public string? Category { get; set; }
|
public float Weight { get; set; }
|
||||||
|
|
||||||
[LoadColumn(1), Index(1)]
|
public bool PartyFinder { get; set; }
|
||||||
public uint Channel { get; set; }
|
|
||||||
|
|
||||||
[LoadColumn(2), Index(2)]
|
public bool Shout { get; set; }
|
||||||
public string Message { get; set; }
|
|
||||||
|
|
||||||
[Ignore]
|
public bool ContainsWard { get; set; }
|
||||||
public float Weight { get; set; } = 1;
|
|
||||||
|
|
||||||
public bool PartyFinder => this.Channel == 0;
|
public bool ContainsPlot { get; set; }
|
||||||
|
|
||||||
public bool Shout => this.Channel == 11 || this.Channel == 30;
|
public bool ContainsHousingNumbers { get; set; }
|
||||||
|
|
||||||
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
public bool ContainsTradeWords { get; set; }
|
||||||
|
|
||||||
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
public bool ContainsSketchUrl { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
|
internal void Compute(Computed output, Dictionary<string, float> weights) {
|
||||||
|
output.Weight = this.Category == null ? 1 : weights[this.Category];
|
||||||
|
output.PartyFinder = this.Channel == 0;
|
||||||
|
output.Shout = this.Channel == 11 || this.Channel == 30;
|
||||||
|
output.ContainsWard = this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||||
|
output.ContainsPlot = PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||||
|
output.ContainsHousingNumbers = NumbersRegex.IsMatch(this.Message);
|
||||||
|
output.ContainsTradeWords = TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||||
|
output.ContainsSketchUrl = SketchUrlRegex.IsMatch(this.Message);
|
||||||
|
}
|
||||||
|
|
||||||
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
#endregion
|
||||||
|
|
||||||
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class Prediction {
|
internal class Prediction {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user