feat(trainer): add trainer to actual repo
This commit is contained in:
parent
431f330229
commit
bd05abb5e0
|
@ -0,0 +1 @@
|
|||
model.zip
|
|
@ -0,0 +1,15 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>netcoreapp3.1</TargetFramework>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ConsoleTables" Version="2.4.2" />
|
||||
<PackageReference Include="CsvHelper" Version="18.0.0" />
|
||||
<PackageReference Include="Microsoft.ML" Version="1.5.4" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -0,0 +1,231 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using ConsoleTables;
|
||||
using CsvHelper;
|
||||
using CsvHelper.Configuration;
|
||||
using CsvHelper.Configuration.Attributes;
|
||||
using Microsoft.ML;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.Transforms.Text;
|
||||
|
||||
namespace NoSoliciting.Trainer {
|
||||
internal static class Program {
|
||||
private static void Main(string[] args) {
|
||||
var full = args[0] == "create";
|
||||
|
||||
var ctx = new MLContext(1);
|
||||
|
||||
using var fileStream = new FileStream("../../../data.csv", FileMode.Open);
|
||||
using var stream = new StreamReader(fileStream);
|
||||
using var csv = new CsvReader(stream, new CsvConfiguration(CultureInfo.InvariantCulture) {
|
||||
HeaderValidated = null,
|
||||
});
|
||||
|
||||
var records = csv.GetRecords<Data>().ToList();
|
||||
var classes = new Dictionary<string, uint>();
|
||||
|
||||
foreach (var record in records) {
|
||||
// normalise the message
|
||||
record.Message = Util.Normalise(record.Message);
|
||||
|
||||
// keep track of how many message of each category we have
|
||||
if (!classes.ContainsKey(record.Category!)) {
|
||||
classes[record.Category] = 0;
|
||||
}
|
||||
|
||||
classes[record.Category] += 1;
|
||||
}
|
||||
|
||||
// calculate class weights
|
||||
var weights = new Dictionary<string, float>();
|
||||
foreach (var (category, count) in classes) {
|
||||
var nSamples = (float) records.Count;
|
||||
var nClasses = (float) classes.Count;
|
||||
var nSamplesJ = (float) count;
|
||||
|
||||
var w = nSamples / (nClasses * nSamplesJ);
|
||||
|
||||
weights[category] = w;
|
||||
}
|
||||
|
||||
// apply class weights
|
||||
foreach (var record in records) {
|
||||
record.Weight = weights[record.Category!];
|
||||
}
|
||||
|
||||
var df = ctx.Data.LoadFromEnumerable(records);
|
||||
|
||||
var ttd = ctx.Data.TrainTestSplit(df, 0.2, seed: 1);
|
||||
|
||||
var pipeline = ctx.Transforms.Conversion.MapValueToKey("Label", nameof(Data.Category))
|
||||
.Append(ctx.Transforms.Text.NormalizeText("MsgNormal", nameof(Data.Message), keepPunctuations: false))
|
||||
.Append(ctx.Transforms.Text.TokenizeIntoWords("MsgTokens", "MsgNormal"))
|
||||
// .Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgTokens",
|
||||
// "the",
|
||||
// "a",
|
||||
// "of",
|
||||
// "in",
|
||||
// "for",
|
||||
// "from",
|
||||
// "and",
|
||||
// "discord"
|
||||
// ))
|
||||
.Append(ctx.Transforms.Text.RemoveDefaultStopWords("MsgNoDefStop", "MsgTokens"))
|
||||
.Append(ctx.Transforms.Text.RemoveStopWords("MsgNoStop", "MsgNoDefStop",
|
||||
"discord"
|
||||
))
|
||||
.Append(ctx.Transforms.Conversion.MapValueToKey("MsgKey", "MsgNoStop"))
|
||||
.Append(ctx.Transforms.Text.ProduceNgrams("MsgNgrams", "MsgKey", weighting: NgramExtractingEstimator.WeightingCriteria.Tf))
|
||||
.Append(ctx.Transforms.NormalizeLpNorm("FeaturisedMessage", "MsgNgrams"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("CPartyFinder", "PartyFinder"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("CShout", "Shout"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("CTrade", "ContainsTradeWords"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("CSketch", "ContainsSketchUrl"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("HasWard", "ContainsWard"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("HasPlot", "ContainsPlot"))
|
||||
.Append(ctx.Transforms.Conversion.ConvertType("HasNumbers", "ContainsHousingNumbers"))
|
||||
.Append(ctx.Transforms.Concatenate("Features", "FeaturisedMessage", "CPartyFinder", "CShout", "CTrade", "HasWard", "HasPlot", "HasNumbers", "CSketch"))
|
||||
.Append(ctx.MulticlassClassification.Trainers.SdcaMaximumEntropy(exampleWeightColumnName: "Weight"))
|
||||
.Append(ctx.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
|
||||
|
||||
var train = full ? df : ttd.TrainSet;
|
||||
|
||||
var model = pipeline.Fit(train);
|
||||
|
||||
ctx.Model.Save(model, df.Schema, @"../../../model.zip");
|
||||
|
||||
var testPredictions = model.Transform(ttd.TestSet);
|
||||
var eval = ctx.MulticlassClassification.Evaluate(testPredictions);
|
||||
|
||||
var predEngine = ctx.Model.CreatePredictionEngine<Data, Prediction>(model);
|
||||
|
||||
var slotNames = new VBuffer<ReadOnlyMemory<char>>();
|
||||
predEngine.OutputSchema["Score"].GetSlotNames(ref slotNames);
|
||||
var names = slotNames.DenseValues()
|
||||
.Select(column => column.ToString())
|
||||
.ToList();
|
||||
|
||||
var cols = new string[1 + names.Count];
|
||||
cols[0] = "";
|
||||
for (var j = 0; j < names.Count; j++) {
|
||||
cols[j + 1] = names[j];
|
||||
}
|
||||
|
||||
var table = new ConsoleTable(cols);
|
||||
|
||||
for (var i = 0; i < names.Count; i++) {
|
||||
var name = names[i];
|
||||
var confuse = eval.ConfusionMatrix.Counts[i];
|
||||
|
||||
var row = new object[1 + confuse.Count];
|
||||
row[0] = name;
|
||||
for (var j = 0; j < confuse.Count; j++) {
|
||||
row[j + 1] = confuse[j];
|
||||
}
|
||||
|
||||
table.AddRow(row);
|
||||
}
|
||||
|
||||
Console.WriteLine(table.ToString());
|
||||
|
||||
Console.WriteLine($"Log loss : {eval.LogLoss * 100}");
|
||||
Console.WriteLine($"Macro acc: {eval.MacroAccuracy * 100}");
|
||||
Console.WriteLine($"Micro acc: {eval.MicroAccuracy * 100}");
|
||||
|
||||
if (full) {
|
||||
return;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
var msg = Console.ReadLine()!.Trim();
|
||||
|
||||
var parts = msg.Split(' ', 2);
|
||||
|
||||
ushort.TryParse(parts[0], out var channel);
|
||||
|
||||
var input = new Data {
|
||||
Channel = channel,
|
||||
// PartyFinder = channel == 0,
|
||||
Message = parts[1],
|
||||
};
|
||||
|
||||
var pred = predEngine.Predict(input);
|
||||
|
||||
Console.WriteLine(pred.Category);
|
||||
for (var i = 0; i < names.Count; i++) {
|
||||
Console.WriteLine($" {names[i]}: {pred.Probabilities[i] * 100}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[SuppressMessage("ReSharper", "UnusedMember.Global")]
|
||||
internal class Data {
|
||||
private static readonly Regex WardRegex = new Regex(@"w.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
|
||||
private static readonly Regex PlotRegex = new Regex(@"p.{0,2}\d", RegexOptions.Compiled | RegexOptions.IgnoreCase);
|
||||
private static readonly string[] PlotWords = {
|
||||
"plot",
|
||||
"apartment",
|
||||
"apt",
|
||||
};
|
||||
|
||||
private static readonly Regex NumbersRegex = new Regex(@"\d{1,2}.{0,2}\d{1,2}", RegexOptions.Compiled);
|
||||
|
||||
private static readonly string[] TradeWords = {
|
||||
"B> ",
|
||||
"S> ",
|
||||
"buy",
|
||||
"sell",
|
||||
"WTB",
|
||||
"WTS",
|
||||
};
|
||||
|
||||
private static readonly Regex SketchUrlRegex = new Regex(@"\.com-\w+\.\w+", RegexOptions.IgnoreCase | RegexOptions.Compiled);
|
||||
|
||||
[LoadColumn(0), Index(0)]
|
||||
public string? Category { get; set; }
|
||||
|
||||
[LoadColumn(1), Index(1)]
|
||||
public uint Channel { get; set; }
|
||||
|
||||
[LoadColumn(2), Index(2)]
|
||||
public string Message { get; set; }
|
||||
|
||||
[Ignore]
|
||||
public float Weight { get; set; } = 1;
|
||||
|
||||
public bool PartyFinder => this.Channel == 0;
|
||||
|
||||
public bool Shout => this.Channel == 11 || this.Channel == 30;
|
||||
|
||||
public bool ContainsWard => this.Message.ContainsIgnoreCase("ward") || WardRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsPlot => PlotWords.Any(word => this.Message.ContainsIgnoreCase(word)) || PlotRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsHousingNumbers => NumbersRegex.IsMatch(this.Message);
|
||||
|
||||
public bool ContainsTradeWords => TradeWords.Any(word => this.Message.ContainsIgnoreCase(word));
|
||||
|
||||
public bool ContainsSketchUrl => SketchUrlRegex.IsMatch(this.Message);
|
||||
}
|
||||
|
||||
internal class Prediction {
|
||||
[ColumnName("PredictedLabel")]
|
||||
public string Category { get; set; }
|
||||
|
||||
[ColumnName("Score")]
|
||||
public float[] Probabilities { get; set; }
|
||||
}
|
||||
|
||||
internal static class Ext {
|
||||
public static bool ContainsIgnoreCase(this string haystack, string needle) {
|
||||
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(haystack, needle, CompareOptions.IgnoreCase) >= 0;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace NoSoliciting.Trainer {
|
||||
public static class Util {
|
||||
private static readonly Dictionary<char, string> Replacements = new Dictionary<char, string> {
|
||||
// numerals
|
||||
['\ue055'] = "1",
|
||||
['\ue056'] = "2",
|
||||
['\ue057'] = "3",
|
||||
['\ue058'] = "4",
|
||||
['\ue059'] = "5",
|
||||
['\ue099'] = "10",
|
||||
['\ue09a'] = "11",
|
||||
['\ue09b'] = "12",
|
||||
['\ue09c'] = "13",
|
||||
['\ue09d'] = "14",
|
||||
['\ue09e'] = "15",
|
||||
['\ue09f'] = "16",
|
||||
['\ue0a0'] = "17",
|
||||
['\ue0a1'] = "18",
|
||||
['\ue0a2'] = "19",
|
||||
['\ue0a3'] = "20",
|
||||
['\ue0a4'] = "21",
|
||||
['\ue0a5'] = "22",
|
||||
['\ue0a6'] = "23",
|
||||
['\ue0a7'] = "24",
|
||||
['\ue0a8'] = "25",
|
||||
['\ue0a9'] = "26",
|
||||
['\ue0aa'] = "27",
|
||||
['\ue0ab'] = "28",
|
||||
['\ue0ac'] = "29",
|
||||
['\ue0ad'] = "30",
|
||||
['\ue0ae'] = "31",
|
||||
|
||||
// symbols
|
||||
['\ue0af'] = "+",
|
||||
['\ue070'] = "?",
|
||||
|
||||
// letters in other sets
|
||||
['\ue022'] = "A",
|
||||
['\ue024'] = "_A",
|
||||
['\ue0b0'] = "E",
|
||||
};
|
||||
|
||||
private const char LowestReplacement = '\ue022';
|
||||
|
||||
public static string Normalise(string input) {
|
||||
if (input == null) {
|
||||
throw new ArgumentNullException(nameof(input), "input cannot be null");
|
||||
}
|
||||
|
||||
// replace ffxiv private use chars
|
||||
var builder = new StringBuilder(input.Length);
|
||||
foreach (var c in input) {
|
||||
if (c < LowestReplacement) {
|
||||
goto AppendNormal;
|
||||
}
|
||||
|
||||
// alphabet
|
||||
if (c >= 0xe071 && c <= 0xe08a) {
|
||||
builder.Append((char) (c - 0xe030));
|
||||
continue;
|
||||
}
|
||||
|
||||
// 0 to 9
|
||||
if (c >= 0xe060 && c <= 0xe069) {
|
||||
builder.Append((char) (c - 0xe030));
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1 to 9
|
||||
if (c >= 0xe0b1 && c <= 0xe0b9) {
|
||||
builder.Append((char) (c - 0xe080));
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1 to 9 again
|
||||
if (c >= 0xe090 && c <= 0xe098) {
|
||||
builder.Append((char) (c - 0xe05f));
|
||||
continue;
|
||||
}
|
||||
|
||||
// replacements in map
|
||||
if (Replacements.TryGetValue(c, out var rep)) {
|
||||
builder.Append(rep);
|
||||
continue;
|
||||
}
|
||||
|
||||
AppendNormal:
|
||||
builder.Append(c);
|
||||
}
|
||||
|
||||
input = builder.ToString();
|
||||
|
||||
// NFKD unicode normalisation
|
||||
return input.Normalize(NormalizationForm.FormKD);
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -13,6 +13,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NoSoliciting.CursedWorkarou
|
|||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NoSoliciting.Interface", "NoSoliciting.Interface\NoSoliciting.Interface.csproj", "{E88E57AB-EFB8-4F2F-93DB-F63123638C44}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NoSoliciting.Trainer", "NoSoliciting.Trainer\NoSoliciting.Trainer.csproj", "{3D774127-F7A9-4B6D-AB2F-3AAF80D15586}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
|
@ -35,6 +37,10 @@ Global
|
|||
{E88E57AB-EFB8-4F2F-93DB-F63123638C44}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{E88E57AB-EFB8-4F2F-93DB-F63123638C44}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{E88E57AB-EFB8-4F2F-93DB-F63123638C44}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{3D774127-F7A9-4B6D-AB2F-3AAF80D15586}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{3D774127-F7A9-4B6D-AB2F-3AAF80D15586}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{3D774127-F7A9-4B6D-AB2F-3AAF80D15586}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{3D774127-F7A9-4B6D-AB2F-3AAF80D15586}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
|
Loading…
Reference in New Issue