@@ -51,15 +51,17 @@ def resize_by_area(img, size):
51
51
52
52
class ImageProblem (problem .Problem ):
53
53
54
- def example_reading_spec (self , label_key = None ):
55
- if label_key is None :
56
- label_key = "image/class/label"
54
+ def example_reading_spec (self , label_repr = None ):
55
+ if label_repr is None :
56
+ label_repr = ( "image/class/label" , tf . FixedLenFeature (( 1 ,), tf . int64 ))
57
57
58
58
data_fields = {
59
59
"image/encoded" : tf .FixedLenFeature ((), tf .string ),
60
60
"image/format" : tf .FixedLenFeature ((), tf .string ),
61
- label_key : tf .VarLenFeature (tf .int64 )
62
61
}
62
+ label_key , label_type = label_repr # pylint: disable=unpacking-non-sequence
63
+ data_fields [label_key ] = label_type
64
+
63
65
data_items_to_decoders = {
64
66
"inputs" :
65
67
tf .contrib .slim .tfexample_decoder .Image (
@@ -244,8 +246,9 @@ def hparams(self, defaults, unused_model_hparams):
244
246
245
247
def example_reading_spec (self ):
246
248
label_key = "image/unpadded_label"
249
+ label_type = tf .VarLenFeature (tf .int64 )
247
250
return super (ImageFSNS , self ).example_reading_spec (
248
- self , label_key = label_key )
251
+ label_repr = ( label_key , label_type ) )
249
252
250
253
251
254
class Image2ClassProblem (ImageProblem ):
@@ -283,10 +286,8 @@ def generator(self, data_dir, tmp_dir, is_training):
283
286
284
287
def hparams (self , defaults , unused_model_hparams ):
285
288
p = defaults
286
- small_modality = "%s:small_image_modality" % registry .Modalities .IMAGE
287
- modality = small_modality if self .is_small else registry .Modalities .IMAGE
288
- p .input_modality = {"inputs" : (modality , None )}
289
- p .target_modality = ("%s:2d" % registry .Modalities .CLASS_LABEL ,
289
+ p .input_modality = {"inputs" : (registry .Modalities .IMAGE , None )}
290
+ p .target_modality = (registry .Modalities .CLASS_LABEL ,
290
291
self .num_classes )
291
292
p .batch_size_multiplier = 4 if self .is_small else 256
292
293
p .max_expected_batch_size_per_shard = 8 if self .is_small else 2
@@ -382,6 +383,38 @@ def preprocess_example(self, example, mode, unused_hparams):
382
383
return example
383
384
384
385
386
+ @registry .register_problem
387
+ class ImageImagenet64 (Image2ClassProblem ):
388
+ """Imagenet rescaled to 64x64."""
389
+
390
+ def dataset_filename (self ):
391
+ return "image_imagenet" # Reuse Imagenet data.
392
+
393
+ @property
394
+ def is_small (self ):
395
+ return True # Modalities like for CIFAR.
396
+
397
+ @property
398
+ def num_classes (self ):
399
+ return 1000
400
+
401
+ def generate_data (self , data_dir , tmp_dir , task_id = - 1 ):
402
+ # TODO(lukaszkaiser): find a better way than printing this.
403
+ print ("To generate the ImageNet dataset in the proper format, follow "
404
+ "instructions at https://github.com/tensorflow/models/blob/master"
405
+ "/inception/README.md#getting-started" )
406
+
407
+ def preprocess_example (self , example , mode , unused_hparams ):
408
+ inputs = example ["inputs" ]
409
+ # Just resize with area.
410
+ if self ._was_reversed :
411
+ example ["inputs" ] = resize_by_area (inputs , 64 )
412
+ else :
413
+ example = imagenet_preprocess_example (example , mode )
414
+ example ["inputs" ] = resize_by_area (example ["inputs" ], 64 )
415
+ return example
416
+
417
+
385
418
@registry .register_problem
386
419
class Img2imgImagenet (ImageProblem ):
387
420
"""Imagenet rescaled to 8x8 for input and 32x32 for output."""
@@ -623,9 +656,11 @@ def class_labels(self):
623
656
]
624
657
625
658
def preprocess_example (self , example , mode , unused_hparams ):
659
+ example ["inputs" ].set_shape ([_CIFAR10_IMAGE_SIZE , _CIFAR10_IMAGE_SIZE , 3 ])
626
660
if mode == tf .estimator .ModeKeys .TRAIN :
627
661
example ["inputs" ] = common_layers .cifar_image_augmentation (
628
662
example ["inputs" ])
663
+ example ["inputs" ] = tf .to_int64 (example ["inputs" ])
629
664
return example
630
665
631
666
def generator (self , data_dir , tmp_dir , is_training ):
@@ -649,6 +684,7 @@ def generator(self, data_dir, tmp_dir, is_training):
649
684
class ImageCifar10Plain (ImageCifar10 ):
650
685
651
686
def preprocess_example (self , example , mode , unused_hparams ):
687
+ example ["inputs" ].set_shape ([_CIFAR10_IMAGE_SIZE , _CIFAR10_IMAGE_SIZE , 3 ])
652
688
example ["inputs" ] = tf .to_int64 (example ["inputs" ])
653
689
return example
654
690
0 commit comments