1414# limitations under the License. 
1515# 
1616
17- 
1817# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
1918# 
2019# Licensed under the Apache License, Version 2.0 (the "License"); 
3029# limitations under the License. 
3130# ============================================================================== 
3231
33- r"""Script to process the Imagenet dataset and upload to gcs.  
32+ r"""Script to process the Imagenet dataset 
3433
3534To run the script setup a virtualenv with the following libraries installed. 
36- - `gcloud`: Follow the instructions on 
37-   [cloud SDK docs](https://cloud.google.com/sdk/downloads) followed by 
38-   installing the python api using `pip install gcloud`. 
39- - `google-cloud-storage`: Install with `pip install google-cloud-storage` 
4035- `tensorflow`: Install with `pip install tensorflow` 
4136
4237Once you have all the above libraries setup, you should register on the 
4641- Validation Images: validation/ILSVRC2012_val_00000001.JPEG 
4742- Validation Labels: synset_labels.txt 
4843
49- To run the script to preprocess the raw dataset as TFRecords and upload to gcs, 
50- run the following command: 
51- 
52- ``` 
53- python imagenet_to_gcs.py \ 
54-   --project="TEST_PROJECT" \ 
55-   --gcs_output_path="gs://TEST_BUCKET/IMAGENET_DIR" \ 
56-   --raw_data_dir="path/to/imagenet" 
57- ``` 
58- 
5944""" 
6045
6146import  math 
7257tf .disable_eager_execution ()
7358
7459
75- from  google .cloud  import  storage 
76- 
77- flags .DEFINE_string (
78-     'project' , None , 'Google cloud project id for uploading the dataset.' )
79- flags .DEFINE_string (
80-     'gcs_output_path' , None , 'GCS path for uploading the dataset.' )
8160flags .DEFINE_string (
8261    'local_scratch_dir' , None , 'Scratch directory path for temporary files.' )
8362flags .DEFINE_string (
8463    'raw_data_dir' , None , 'Directory path for raw Imagenet dataset. ' 
8564    'Should have train and validation subdirectories inside it.' )
86- flags .DEFINE_boolean (
87-     'gcs_upload' , True , 'Set to false to not upload to gcs.' )
8865
8966FLAGS  =  flags .FLAGS 
9067
@@ -425,50 +402,9 @@ def make_shuffle_idx(n):
425402  return  training_records , validation_records 
426403
427404
428- def  upload_to_gcs (training_records , validation_records ):
429-   """Upload TF-Record files to GCS, at provided path.""" 
430- 
431-   # Find the GCS bucket_name and key_prefix for dataset files 
432-   path_parts  =  FLAGS .gcs_output_path [5 :].split ('/' , 1 )
433-   bucket_name  =  path_parts [0 ]
434-   if  len (path_parts ) ==  1 :
435-     key_prefix  =  '' 
436-   elif  path_parts [1 ].endswith ('/' ):
437-     key_prefix  =  path_parts [1 ]
438-   else :
439-     key_prefix  =  path_parts [1 ] +  '/' 
440- 
441-   client  =  storage .Client (project = FLAGS .project )
442-   bucket  =  client .get_bucket (bucket_name )
443- 
444-   def  _upload_files (filenames ):
445-     """Upload a list of files into a specifc subdirectory.""" 
446-     for  i , filename  in  enumerate (sorted (filenames )):
447-       blob  =  bucket .blob (key_prefix  +  os .path .basename (filename ))
448-       blob .upload_from_filename (filename )
449-       if  not  i  %  20 :
450-         logging .info ('Finished uploading file: %s' , filename )
451- 
452-   # Upload training dataset 
453-   logging .info ('Uploading the training data.' )
454-   _upload_files (training_records )
455- 
456-   # Upload validation dataset 
457-   logging .info ('Uploading the validation data.' )
458-   _upload_files (validation_records )
459- 
460- 
461405def  main (argv ):  # pylint: disable=unused-argument 
462406  logging .set_verbosity (logging .INFO )
463407
464-   if  FLAGS .gcs_upload  and  FLAGS .project  is  None :
465-     raise  ValueError ('GCS Project must be provided.' )
466- 
467-   if  FLAGS .gcs_upload  and  FLAGS .gcs_output_path  is  None :
468-     raise  ValueError ('GCS output path must be provided.' )
469-   elif  FLAGS .gcs_upload  and  not  FLAGS .gcs_output_path .startswith ('gs://' ):
470-     raise  ValueError ('GCS output path must start with gs://' )
471- 
472408  if  FLAGS .local_scratch_dir  is  None :
473409    raise  ValueError ('Scratch directory path must be provided.' )
474410
@@ -482,10 +418,6 @@ def main(argv):  # pylint: disable=unused-argument
482418  # Convert the raw data into tf-records 
483419  training_records , validation_records  =  convert_to_tf_records (raw_data_dir )
484420
485-   # Upload to GCS 
486-   if  FLAGS .gcs_upload :
487-     upload_to_gcs (training_records , validation_records )
488- 
489421
490422if  __name__  ==  '__main__' :
491423  app .run (main )
0 commit comments