Skip to content

Commit

Permalink
Add Ranking AutoML Sample (dotnet#852)
Browse files Browse the repository at this point in the history
* Initial add of project

* Update ranking sample

* Get sample working

* Updates based on feedback

* Add refitting on validation and test data sets

* Update console headers

* Iteration print improvements

* Correct validationData

* Printing NDCG@1,3,10 & DCG@10

* Printing NDCG@1,3,10 & DCG@10

* Add readme

* Update based on feedback

* Use new DcgTruncation property

* Update to latest AutoML package

* Review feedback

* Wording for 1st refit step

* Update to include original label in output

Co-authored-by: Justin Ormont <[email protected]>
  • Loading branch information
jwood803 and justinormont authored Jan 1, 2021
1 parent 2990008 commit 1c804f5
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 7 deletions.
33 changes: 26 additions & 7 deletions samples/csharp/common/AutoML/ConsoleHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,26 @@ public static void PrintBinaryClassificationMetrics(string name, BinaryClassific
public static void PrintMulticlassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
Console.WriteLine($"************************************************************");
}

public static void PrintRankingMetrics(string name, RankingMetrics metrics, uint optimizationMetricTruncationLevel)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} ranking model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@{optimizationMetricTruncationLevel}) = {metrics?.NormalizedDiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" Discounted Cumulative Gain (DCG@{optimizationMetricTruncationLevel}) = {metrics?.DiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}");
}

public static void ShowDataViewInConsole(MLContext mlContext, IDataView dataView, int numberOfRows = 4)
{
string msg = string.Format("Show data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
Expand Down Expand Up @@ -89,6 +98,11 @@ internal static void PrintIterationMetrics(int iteration, string trainerName, Re
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", Width);
}

internal static void PrintIterationMetrics(int iteration, string trainerName, RankingMetrics metrics, double? runtimeInSeconds)
{
CreateRow($"{iteration,-4} {trainerName,-15} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1}", Width);
}

internal static void PrintIterationException(Exception ex)
{
Console.WriteLine($"Exception during AutoML iteration: {ex}");
Expand All @@ -109,6 +123,11 @@ internal static void PrintRegressionMetricsHeader()
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", Width);
}

internal static void PrintRankingMetricsHeader()
{
CreateRow($"{"",-4} {"Trainer",-15} {"NDCG@1",9} {"NDCG@3",9} {"NDCG@10",9} {"DCG@10",9} {"Duration",9}", Width);
}

private static void CreateRow(string message, int width)
{
Console.WriteLine("|" + message.PadRight(width - 2) + "|");
Expand Down Expand Up @@ -239,10 +258,10 @@ private void AppendTableRow(ICollection<string[]> tableRows,

tableRows.Add(new[]
{
columnName,
GetColumnDataType(columnName),
columnPurpose
});
columnName,
GetColumnDataType(columnName),
columnPurpose
});
}

private void AppendTableRows(ICollection<string[]> tableRows,
Expand Down
23 changes: 23 additions & 0 deletions samples/csharp/common/AutoML/ProgressHandlers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,27 @@ public void Report(RunDetail<RegressionMetrics> iterationResult)
}
}
}

public class RankingExperimentProgressHandler : IProgress<RunDetail<RankingMetrics>>
{
private int _iterationIndex;

public void Report(RunDetail<RankingMetrics> iterationResult)
{
if (_iterationIndex++ == 0)
{
ConsoleHelper.PrintRankingMetricsHeader();
}

if (iterationResult.Exception != null)
{
ConsoleHelper.PrintIterationException(iterationResult.Exception);
}
else
{
ConsoleHelper.PrintIterationMetrics(_iterationIndex, iterationResult.TrainerName,
iterationResult.ValidationMetrics, iterationResult.RuntimeInSeconds);
}
}
}
}
Loading

0 comments on commit 1c804f5

Please sign in to comment.