@@ -390,7 +390,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
390
390
if ( trainSchema ? . Label == null )
391
391
return mapper ; // We don't even have a label identified in a training schema.
392
392
var keyType = trainSchema . Label . Value . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . KeyValues ) ? . Type as VectorDataViewType ;
393
- if ( keyType == null || ! CanWrap ( mapper , keyType ) )
393
+ if ( keyType == null )
394
394
return mapper ;
395
395
396
396
// Great!! All checks pass.
@@ -409,11 +409,19 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
409
409
/// from the model of a bindable mapper)</param>
410
410
/// <returns>Whether we can call <see cref="LabelNameBindableMapper.CreateBound{T}"/> with
411
411
/// this mapper and expect it to succeed</returns>
412
- internal static bool CanWrap ( ISchemaBoundMapper mapper , DataViewType labelNameType )
412
+ internal static bool CanWrapTrainingLabels ( ISchemaBoundMapper mapper , DataViewType labelNameType )
413
+ {
414
+ if ( GetTypesForWrapping ( mapper , labelNameType , AnnotationUtils . Kinds . TrainingLabelValues , out var scoreType ) )
415
+ // Check that the type is vector, and is of compatible size with the score output.
416
+ return labelNameType is VectorDataViewType vectorType && vectorType . Size == scoreType . GetVectorSize ( ) ;
417
+ return false ;
418
+ }
419
+
420
+ internal static bool GetTypesForWrapping ( ISchemaBoundMapper mapper , DataViewType labelNameType , string metaKind , out DataViewType scoreType )
413
421
{
414
422
Contracts . AssertValue ( mapper ) ;
415
423
Contracts . AssertValue ( labelNameType ) ;
416
-
424
+ scoreType = null ;
417
425
ISchemaBoundRowMapper rowMapper = mapper as ISchemaBoundRowMapper ;
418
426
if ( rowMapper == null )
419
427
return false ; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so.
@@ -423,12 +431,30 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
423
431
var scoreCol = outSchema . GetColumnOrNull ( AnnotationUtils . Const . ScoreValueKind . Score ) ;
424
432
if ( ! outSchema . TryGetColumnIndex ( AnnotationUtils . Const . ScoreValueKind . Score , out scoreIdx ) )
425
433
return false ; // The mapper doesn't even publish a score column to attach the metadata to.
426
- if ( outSchema [ scoreIdx ] . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . SlotNames ) ? . Type != null )
427
- return false ; // The mapper publishes a score column, and already produces its own slot names.
428
- var scoreType = outSchema [ scoreIdx ] . Type ;
434
+ if ( outSchema [ scoreIdx ] . Annotations . Schema . GetColumnOrNull ( metaKind ) ? . Type != null )
435
+ return false ; // The mapper publishes a score column, and already produces its own metakind.
436
+ scoreType = outSchema [ scoreIdx ] . Type ;
437
+ return true ;
438
+ }
429
439
430
- // Check that the type is vector, and is of compatible size with the score output.
431
- return labelNameType is VectorDataViewType vectorType && vectorType . Size == scoreType . GetVectorSize ( ) && vectorType . ItemType == TextDataViewType . Instance ;
440
+ /// <summary>
441
+ /// This is a utility method used to determine whether <see cref="LabelNameBindableMapper"/>
442
+ /// can or should be used to wrap <paramref name="mapper"/>. This will not throw, since the
443
+ /// desired behavior in the event that it cannot be wrapped, is to just back off to the original
444
+ /// "unwrapped" bound mapper.
445
+ /// </summary>
446
+ /// <param name="mapper">The mapper we are seeing if we can wrap</param>
447
+ /// <param name="labelNameType">The type of the label names from the metadata (either
448
+ /// originating from the key value metadata of the training label column, or deserialized
449
+ /// from the model of a bindable mapper)</param>
450
+ /// <returns>Whether we can call <see cref="LabelNameBindableMapper.CreateBound{T}"/> with
451
+ /// this mapper and expect it to succeed</returns>
452
+ internal static bool CanWrapSlotNames ( ISchemaBoundMapper mapper , DataViewType labelNameType )
453
+ {
454
+ if ( GetTypesForWrapping ( mapper , labelNameType , AnnotationUtils . Kinds . SlotNames , out var scoreType ) )
455
+ // Check that the type is vector, and is of compatible size with the score output.
456
+ return labelNameType is VectorDataViewType vectorType && vectorType . Size == scoreType . GetVectorSize ( ) && vectorType . ItemType == TextDataViewType . Instance ;
457
+ return false ;
432
458
}
433
459
434
460
internal static ISchemaBoundMapper WrapCore < T > ( IHostEnvironment env , ISchemaBoundMapper mapper , RoleMappedSchema trainSchema )
@@ -449,8 +475,12 @@ internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoun
449
475
{
450
476
trainSchema . Label . Value . GetKeyValues ( ref value ) ;
451
477
} ;
452
-
453
- return LabelNameBindableMapper . CreateBound < T > ( env , ( ISchemaBoundRowMapper ) mapper , type as VectorDataViewType , getter , AnnotationUtils . Kinds . SlotNames , CanWrap ) ;
478
+ var resultMapper = mapper ;
479
+ if ( CanWrapTrainingLabels ( resultMapper , type ) )
480
+ resultMapper = LabelNameBindableMapper . CreateBound < T > ( env , ( ISchemaBoundRowMapper ) resultMapper , type as VectorDataViewType , getter , AnnotationUtils . Kinds . TrainingLabelValues , CanWrapTrainingLabels ) ;
481
+ if ( CanWrapSlotNames ( resultMapper , type ) )
482
+ resultMapper = LabelNameBindableMapper . CreateBound < T > ( env , ( ISchemaBoundRowMapper ) resultMapper , type as VectorDataViewType , getter , AnnotationUtils . Kinds . SlotNames , CanWrapSlotNames ) ;
483
+ return resultMapper ;
454
484
}
455
485
456
486
[ BestFriend ]
0 commit comments