Skip to content

Commit d2bf3e7

Browse files
authored
Restore OVA ability to preserve key names on predicted label (dotnet#3101)
1 parent e5cbca7 commit d2bf3e7

File tree

7 files changed

+105
-39
lines changed

7 files changed

+105
-39
lines changed

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

+10-2
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,16 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
441441
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
442442
{
443443
var cols = new List<SchemaShape.Column>();
444-
if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value))
445-
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
444+
if (labelColumn != null && labelColumn.Value.IsKey)
445+
{
446+
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
447+
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
448+
{
449+
if (metaCol.ItemType is TextDataViewType)
450+
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
451+
cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false));
452+
}
453+
}
446454
cols.AddRange(GetTrainerOutputAnnotation());
447455
return cols;
448456
}

src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs

+40-10
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
390390
if (trainSchema?.Label == null)
391391
return mapper; // We don't even have a label identified in a training schema.
392392
var keyType = trainSchema.Label.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
393-
if (keyType == null || !CanWrap(mapper, keyType))
393+
if (keyType == null)
394394
return mapper;
395395

396396
// Great!! All checks pass.
@@ -409,11 +409,19 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
409409
/// from the model of a bindable mapper)</param>
410410
/// <returns>Whether we can call <see cref="LabelNameBindableMapper.CreateBound{T}"/> with
411411
/// this mapper and expect it to succeed</returns>
412-
internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameType)
412+
internal static bool CanWrapTrainingLabels(ISchemaBoundMapper mapper, DataViewType labelNameType)
413+
{
414+
if (GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.TrainingLabelValues, out var scoreType))
415+
// Check that the type is vector, and is of compatible size with the score output.
416+
return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize();
417+
return false;
418+
}
419+
420+
internal static bool GetTypesForWrapping(ISchemaBoundMapper mapper, DataViewType labelNameType, string metaKind, out DataViewType scoreType)
413421
{
414422
Contracts.AssertValue(mapper);
415423
Contracts.AssertValue(labelNameType);
416-
424+
scoreType = null;
417425
ISchemaBoundRowMapper rowMapper = mapper as ISchemaBoundRowMapper;
418426
if (rowMapper == null)
419427
return false; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so.
@@ -423,12 +431,30 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
423431
var scoreCol = outSchema.GetColumnOrNull(AnnotationUtils.Const.ScoreValueKind.Score);
424432
if (!outSchema.TryGetColumnIndex(AnnotationUtils.Const.ScoreValueKind.Score, out scoreIdx))
425433
return false; // The mapper doesn't even publish a score column to attach the metadata to.
426-
if (outSchema[scoreIdx].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type != null)
427-
return false; // The mapper publishes a score column, and already produces its own slot names.
428-
var scoreType = outSchema[scoreIdx].Type;
434+
if (outSchema[scoreIdx].Annotations.Schema.GetColumnOrNull(metaKind)?.Type != null)
435+
return false; // The mapper publishes a score column, and already produces its own metakind.
436+
scoreType = outSchema[scoreIdx].Type;
437+
return true;
438+
}
429439

430-
// Check that the type is vector, and is of compatible size with the score output.
431-
return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
440+
/// <summary>
441+
/// This is a utility method used to determine whether <see cref="LabelNameBindableMapper"/>
442+
/// can or should be used to wrap <paramref name="mapper"/>. This will not throw, since the
443+
/// desired behavior in the event that it cannot be wrapped, is to just back off to the original
444+
/// "unwrapped" bound mapper.
445+
/// </summary>
446+
/// <param name="mapper">The mapper we are seeing if we can wrap</param>
447+
/// <param name="labelNameType">The type of the label names from the metadata (either
448+
/// originating from the key value metadata of the training label column, or deserialized
449+
/// from the model of a bindable mapper)</param>
450+
/// <returns>Whether we can call <see cref="LabelNameBindableMapper.CreateBound{T}"/> with
451+
/// this mapper and expect it to succeed</returns>
452+
internal static bool CanWrapSlotNames(ISchemaBoundMapper mapper, DataViewType labelNameType)
453+
{
454+
if (GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.SlotNames, out var scoreType))
455+
// Check that the type is vector, and is of compatible size with the score output.
456+
return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
457+
return false;
432458
}
433459

434460
internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
@@ -449,8 +475,12 @@ internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoun
449475
{
450476
trainSchema.Label.Value.GetKeyValues(ref value);
451477
};
452-
453-
return LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.SlotNames, CanWrap);
478+
var resultMapper = mapper;
479+
if (CanWrapTrainingLabels(resultMapper, type))
480+
resultMapper = LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)resultMapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.TrainingLabelValues, CanWrapTrainingLabels);
481+
if (CanWrapSlotNames(resultMapper, type))
482+
resultMapper = LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)resultMapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.SlotNames, CanWrapSlotNames);
483+
return resultMapper;
454484
}
455485

456486
[BestFriend]

src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs

+5-15
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,12 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string
6262
{
6363
var scoreColMetadata = mapper.OutputSchema[scoreColIndex].Annotations;
6464

65-
var slotColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames);
66-
if (slotColumn?.Type is VectorDataViewType slotColVecType && (ulong)slotColVecType.Size == predColKeyType.Count)
65+
var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.TrainingLabelValues);
66+
if (trainLabelColumn?.Type is VectorDataViewType trainLabelColVecType && (ulong)trainLabelColVecType.Size == predColKeyType.Count)
6767
{
68-
Contracts.Assert(slotColVecType.Size > 0);
69-
_predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata<int>, slotColVecType.RawType,
70-
scoreColMetadata, slotColumn.Value);
71-
}
72-
else
73-
{
74-
var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.TrainingLabelValues);
75-
if (trainLabelColumn?.Type is VectorDataViewType trainLabelColVecType && (ulong)trainLabelColVecType.Size == predColKeyType.Count)
76-
{
77-
Contracts.Assert(trainLabelColVecType.Size > 0);
78-
_predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata<int>, trainLabelColVecType.RawType,
79-
scoreColMetadata, trainLabelColumn.Value);
80-
}
68+
Contracts.Assert(trainLabelColVecType.Size > 0);
69+
_predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata<int>, trainLabelColVecType.RawType,
70+
scoreColMetadata, trainLabelColumn.Value);
8171
}
8272
}
8373
}

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr
131131
int size = cursor.Label + 1;
132132
Utils.EnsureSize(ref labelHistogram, size);
133133
Utils.EnsureSize(ref featureHistogram, size);
134-
Utils.EnsureSize(ref featureHistogram[cursor.Label], featureCount);
134+
if (featureHistogram[cursor.Label] == null)
135+
featureHistogram[cursor.Label] = new int[featureCount];
135136
labelHistogram[cursor.Label] += 1;
136137
labelCount = labelCount < size ? size : labelCount;
137138

test/Microsoft.ML.Functional.Tests/Training.cs

+35-1
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ public void ContinueTrainingSymbolicStochasticGradientDescent()
438438
}
439439

440440
/// <summary>
441-
/// Training: Meta-compononts function as expected. For OVA (one-versus-all), a user will be able to specify only
441+
/// Training: Meta-components function as expected. For OVA (one-versus-all), a user will be able to specify only
442442
/// binary classifier trainers. If they specify a different model class there should be a compile error.
443443
/// </summary>
444444
[Fact]
@@ -467,5 +467,39 @@ public void MetacomponentsFunctionAsExpectedOva()
467467
// Evaluate the model.
468468
var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions);
469469
}
470+
471+
/// <summary>
472+
/// Training: Meta-components function as expected. For OVA (one-versus-all), this test checks that a pipeline whose
473+
/// label is converted to a key with MapValueToKey can map PredictedLabel back with MapKeyToValue and still be evaluated.
474+
/// </summary>
475+
[Fact]
476+
public void MetacomponentsFunctionWithKeyHandling()
477+
{
478+
var mlContext = new MLContext(seed: 1);
479+
480+
var data = mlContext.Data.LoadFromTextFile<Iris>(GetDataPath(TestDatasets.iris.trainFilename),
481+
hasHeader: TestDatasets.iris.fileHasHeader,
482+
separatorChar: TestDatasets.iris.fileSeparator);
483+
484+
// Create a model training an OVA trainer with a binary classifier.
485+
var binaryClassificationTrainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
486+
new LbfgsLogisticRegressionBinaryTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1, });
487+
var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
488+
.AppendCacheCheckpoint(mlContext)
489+
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
490+
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryClassificationTrainer))
491+
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
492+
493+
// Fit the binary classification pipeline.
494+
var binaryClassificationModel = binaryClassificationPipeline.Fit(data);
495+
496+
// Transform the data
497+
var binaryClassificationPredictions = binaryClassificationModel.Transform(data);
498+
499+
// Evaluate the model.
500+
var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions);
501+
502+
Assert.Equal(0.4367, binaryClassificationMetrics.LogLoss, 4);
503+
}
470504
}
471505
}

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,22 @@ void PredictAndMetadata()
3737

3838
var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',', hasHeader: true);
3939
var testData = ml.Data.CreateEnumerable<IrisData>(testLoader, false);
40-
40+
4141
// During prediction we will get Score column with 3 float values.
4242
// We need to find a way to map each score to its original label.
43-
// In order to do what we need to get SlotNames from Score column.
44-
// Slot names on top of Score column represent original labels for i-th value in Score array.
45-
VBuffer<ReadOnlyMemory<char>> slotNames = default;
46-
engine.OutputSchema[nameof(IrisPrediction.Score)].GetSlotNames(ref slotNames);
43+
// In order to do that, we need to get TrainingLabelValues from the Score column.
44+
// TrainingLabelValues on the Score column hold the original label for the i-th value in the Score array.
45+
VBuffer<ReadOnlyMemory<char>> originalLabels = default;
46+
engine.OutputSchema[nameof(IrisPrediction.Score)].Annotations.GetValue(AnnotationUtils.Kinds.TrainingLabelValues, ref originalLabels);
4747
// Since we apply MapValueToKey estimator with default parameters, key values
4848
// depend on the order of occurrence in the data file, which is "Iris-setosa", "Iris-versicolor", "Iris-virginica".
4949
// So if we have a Score column equal to [0.2, 0.3, 0.5], that means the score for
5050
// Iris-setosa is 0.2
5151
// Iris-versicolor is 0.3
5252
// Iris-virginica is 0.5.
53-
Assert.True(slotNames.GetItemOrDefault(0).ToString() == "Iris-setosa");
54-
Assert.True(slotNames.GetItemOrDefault(1).ToString() == "Iris-versicolor");
55-
Assert.True(slotNames.GetItemOrDefault(2).ToString() == "Iris-virginica");
53+
Assert.Equal("Iris-setosa", originalLabels.GetItemOrDefault(0).ToString());
54+
Assert.Equal("Iris-versicolor", originalLabels.GetItemOrDefault(1).ToString());
55+
Assert.Equal("Iris-virginica", originalLabels.GetItemOrDefault(2).ToString());
5656

5757
// Let's look how we can convert key value for PredictedLabel to original labels.
5858
// We need to read KeyValues for "PredictedLabel" column.

test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Linq;
56
using Microsoft.ML.Calibrators;
67
using Microsoft.ML.Data;
78
using Microsoft.ML.RunTests;
@@ -83,12 +84,14 @@ public void MetacomponentsFeaturesRenamed()
8384
var data = loader.Load(GetDataPath(TestDatasets.irisData.trainFilename));
8485

8586
var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
86-
new SdcaNonCalibratedBinaryTrainer.Options {
87+
new SdcaNonCalibratedBinaryTrainer.Options
88+
{
8789
LabelColumnName = "Label",
8890
FeatureColumnName = "Vars",
8991
MaximumNumberOfIterations = 100,
9092
Shuffle = true,
91-
NumberOfThreads = 1, });
93+
NumberOfThreads = 1,
94+
});
9295

9396
var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
9497
.Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest)

0 commit comments

Comments
 (0)