
Commit 2f40e2e

One fix in Image package and one fix in google.datalab.ml (#359)

* Fix an issue in the image classification package: the trainer incorrectly depended on DataFlow. Fix an issue in google.datalab.ml.Summary and google.datalab.FeatureSliceView.
* Remove unused import in the image classification package.
* Fix trailing whitespace and indent.
* Fix indent.

1 parent 2137c84 commit 2f40e2e

7 files changed (+32 / -59 lines changed)

google/datalab/ml/_feature_slice_view.py

Lines changed: 2 additions & 4 deletions
@@ -58,10 +58,8 @@ def plot(self, data):
     """
     import IPython
 
-    if sys.version_info.major > 2:
-      basestring = (str, bytes)  # for python 3 compatibility
-
-    if isinstance(data, basestring):
+    if ((sys.version_info.major > 2 and isinstance(data, str)) or
+        (sys.version_info.major <= 2 and isinstance(data, basestring))):
       data = bq.Query(data)
 
     if isinstance(data, bq.Query):
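
For context, the replacement avoids rebinding basestring to (str, bytes) and instead gates the isinstance check on the interpreter version. A minimal standalone sketch of the same idea, using a hypothetical is_string helper that is not part of the commit:

import sys


def is_string(value):
  # Version-gated check: compare against str on Python 3, basestring on Python 2.
  if sys.version_info.major > 2:
    return isinstance(value, str)
  return isinstance(value, basestring)  # noqa: F821 -- basestring exists only on Python 2


# Usage mirroring FeatureSliceView.plot: wrap a query string, pass a Query object through.
data = 'SELECT * FROM my_table'
if is_string(data):
  print('would wrap the string in bq.Query(...)')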

google/datalab/ml/_summary.py

Lines changed: 6 additions & 7 deletions
@@ -95,10 +95,9 @@ def get_events(self, event_names):
       events (i.e. 'loss' under trains_set/, and 'loss' under eval_set/.)
     """
 
-    if sys.version_info.major > 2:
-      basestring = (str, bytes)  # for python 3 compatibility
-
-    event_names = [event_names] if isinstance(event_names, basestring) else event_names
+    if ((sys.version_info.major > 2 and isinstance(event_names, str)) or
+        (sys.version_info.major <= 2 and isinstance(event_names, basestring))):
+      event_names = [event_names]
 
     all_events = self.list_events()
     dirs_to_look = set()
@@ -152,10 +151,10 @@ def plot(self, event_names, x_axis='step'):
       x_axis: whether to use step or time as x axis.
     """
 
-    if sys.version_info.major > 2:
-      basestring = (str, bytes)  # for python 3 compatibility
+    if ((sys.version_info.major > 2 and isinstance(event_names, str)) or
+        (sys.version_info.major <= 2 and isinstance(event_names, basestring))):
+      event_names = [event_names]
 
-    event_names = [event_names] if isinstance(event_names, basestring) else event_names
     events_list = self.get_events(event_names)
     for event_name, dir_event_dict in zip(event_names, events_list):
       for dir, df in dir_event_dict.iteritems():

solutionbox/image_classification/mltoolbox/image/classification/_cloud.py

Lines changed: 5 additions & 3 deletions
@@ -16,15 +16,13 @@
 """Cloud implementation for preprocessing, training and prediction for inception model.
 """
 
-import apache_beam as beam
+
 import base64
 import datetime
 import logging
 import os
 import urllib
 
-from . import _predictor
-from . import _preprocess
 from . import _util
 
 
@@ -39,7 +37,9 @@ class Cloud(object):
   def preprocess(train_dataset, output_dir, eval_dataset, checkpoint, pipeline_option):
     """Preprocess data in Cloud with DataFlow."""
 
+    import apache_beam as beam
     import google.datalab.utils
+    from . import _preprocess
 
     if checkpoint is None:
       checkpoint = _util._DEFAULT_CHECKPOINT_GSURL
@@ -154,7 +154,9 @@ def predict(model_id, image_files, resize, show_image):
   def batch_predict(dataset, model_dir, output_csv, output_bq_table, pipeline_option):
     """Batch predict running in cloud."""
 
+    import apache_beam as beam
     import google.datalab.utils
+    from . import _predictor
 
     if output_csv is None and output_bq_table is None:
       raise ValueError('output_csv and output_bq_table cannot both be None.')
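
The change above moves the apache_beam import (and the modules that pull it in) from module level into the functions that actually launch pipelines, so importing the package for training no longer requires DataFlow to be installed. A minimal sketch of the pattern, with hypothetical function names rather than the package's real API:

def train(input_dir, output_dir):
  # No Beam import here: this path works even when apache_beam is not installed.
  print('training on %s, writing to %s' % (input_dir, output_dir))


def preprocess(train_dataset, output_dir):
  # Beam is imported lazily, only when a preprocessing pipeline is actually requested.
  import apache_beam as beam
  print('preprocessing %s with Beam %s into %s' % (train_dataset, beam.__version__, output_dir))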

solutionbox/image_classification/mltoolbox/image/classification/_local.py

Lines changed: 7 additions & 3 deletions
@@ -16,13 +16,11 @@
 """Local implementation for preprocessing, training and prediction for inception model.
 """
 
-import apache_beam as beam
+
 import datetime
 
 
 from . import _model
-from . import _predictor
-from . import _preprocess
 from . import _trainer
 from . import _util
 
@@ -34,7 +32,9 @@ class Local(object):
  def preprocess(train_dataset, output_dir, eval_dataset, checkpoint):
     """Preprocess data locally."""
 
+    import apache_beam as beam
     from google.datalab.utils import LambdaJob
+    from . import _preprocess
 
     if checkpoint is None:
       checkpoint = _util._DEFAULT_CHECKPOINT_GSURL
@@ -70,6 +70,8 @@ def train(input_dir, batch_size, max_steps, output_dir, checkpoint):
   def predict(model_dir, image_files, resize, show_image):
     """Predict using an model in a local or GCS directory."""
 
+    from . import _predictor
+
     images = _util.load_images(image_files, resize=resize)
     labels_and_scores = _predictor.predict(model_dir, images)
     results = zip(image_files, images, labels_and_scores)
@@ -80,7 +82,9 @@ def predict(model_dir, image_files, resize, show_image):
   def batch_predict(dataset, model_dir, output_csv, output_bq_table):
     """Batch predict running locally."""
 
+    import apache_beam as beam
     from google.datalab.utils import LambdaJob
+    from . import _predictor
 
     if output_csv is None and output_bq_table is None:
       raise ValueError('output_csv and output_bq_table cannot both be None.')

solutionbox/image_classification/mltoolbox/image/classification/_predictor.py

Lines changed: 3 additions & 1 deletion
@@ -95,7 +95,9 @@ class LoadImagesDoFn(beam.DoFn):
   """A DoFn that reads image from url."""
 
   def process(self, element):
-    with _util.open_local_or_gcs(element['image_url'], 'r') as ff:
+    from tensorflow.python.lib.io import file_io as tf_file_io
+
+    with tf_file_io.FileIO(element['image_url'], 'r') as ff:
       image_bytes = ff.read()
     out_element = {'image_bytes': image_bytes}
     out_element.update(element)
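
The rewritten DoFn relies on tensorflow.python.lib.io.file_io, whose FileIO class reads both local paths and gs:// URLs, so the custom _util.open_local_or_gcs helper is no longer needed inside the pipeline. A small usage sketch with a placeholder path:

from tensorflow.python.lib.io import file_io as tf_file_io

# Placeholder path; a gs://bucket/object URL is handled the same way.
image_path = '/tmp/example.jpg'

with tf_file_io.FileIO(image_path, 'r') as f:
  image_bytes = f.read()
print('read %d bytes' % len(image_bytes))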

solutionbox/image_classification/mltoolbox/image/classification/_preprocess.py

Lines changed: 3 additions & 2 deletions
@@ -87,10 +87,11 @@ class ReadImageAndConvertToJpegDoFn(beam.DoFn):
   """
 
   def process(self, element):
-    uri, label_id = element
+    from tensorflow.python.lib.io import file_io as tf_file_io
 
+    uri, label_id = element
     try:
-      with _util.open_local_or_gcs(uri, mode='r') as f:
+      with tf_file_io.FileIO(uri, 'r') as f:
         img = Image.open(f).convert('RGB')
     # A variety of different calling libraries throw different exceptions here.
     # They all correspond to an unreadable file so we treat them equivalently.

solutionbox/image_classification/mltoolbox/image/classification/_util.py

Lines changed: 6 additions & 39 deletions
@@ -16,13 +16,9 @@
 """Reusable utility functions.
 """
 
-from apache_beam.io import gcsio
 import collections
-import glob
 import multiprocessing
 import os
-import logging
-import time
 
 import tensorflow as tf
 from tensorflow.python.lib.io import file_io
@@ -44,41 +40,12 @@ def default_project():
   return Context.default().project_id
 
 
-def open_local_or_gcs(path, mode):
-  """Opens the given path."""
-  if path.startswith('gs://'):
-    try:
-      return gcsio.GcsIO().open(path, mode)
-    except Exception as e:  # pylint: disable=broad-except
-      # Currently we retry exactly once, to work around flaky gcs calls.
-      logging.error('Retrying after exception reading gcs file: %s', e)
-      time.sleep(10)
-      return gcsio.GcsIO().open(path, mode)
-  else:
-    return open(path, mode)
-
-
-def file_exists(path):
-  """Returns whether the file exists."""
-  if path.startswith('gs://'):
-    return gcsio.GcsIO().exists(path)
-  else:
-    return os.path.exists(path)
-
-
-def glob_files(path):
-  if path.startswith('gs://'):
-    return gcsio.GcsIO().glob(path)
-  else:
-    return glob.glob(path)
-
-
 def _get_latest_data_dir(input_dir):
   latest_file = os.path.join(input_dir, 'latest')
-  if not file_exists(latest_file):
+  if not file_io.file_exists(latest_file):
     raise Exception(('Cannot find "latest" file in "%s". ' +
                      'Please use a preprocessing output dir.') % input_dir)
-  with open_local_or_gcs(latest_file, 'r') as f:
+  with file_io.FileIO(latest_file, 'r') as f:
     dir_name = f.read().rstrip()
   return os.path.join(input_dir, dir_name)
 
@@ -88,16 +55,16 @@ def get_train_eval_files(input_dir):
   data_dir = _get_latest_data_dir(input_dir)
   train_pattern = os.path.join(data_dir, 'train*.tfrecord.gz')
   eval_pattern = os.path.join(data_dir, 'eval*.tfrecord.gz')
-  train_files = glob_files(train_pattern)
-  eval_files = glob_files(eval_pattern)
+  train_files = file_io.get_matching_files(train_pattern)
+  eval_files = file_io.get_matching_files(eval_pattern)
   return train_files, eval_files
 
 
 def get_labels(input_dir):
   """Get a list of labels from preprocessed output dir."""
   data_dir = _get_latest_data_dir(input_dir)
   labels_file = os.path.join(data_dir, 'labels')
-  with open_local_or_gcs(labels_file, mode='r') as f:
+  with file_io.FileIO(labels_file, 'r') as f:
     labels = f.read().rstrip().split('\n')
   return labels
 
@@ -254,7 +221,7 @@ def load_images(image_files, resize=True):
 
   images = []
   for image_file in image_files:
-    with open_local_or_gcs(image_file, 'r') as ff:
+    with file_io.FileIO(image_file, 'r') as ff:
       images.append(ff.read())
   if resize is False:
     return images
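
With the gcsio-backed helpers deleted, the module leans on tensorflow's file_io for existence checks, glob-style matching, and reads, all of which accept local paths and gs:// URLs alike. A short usage sketch with placeholder paths, assuming a TensorFlow version that still exposes tensorflow.python.lib.io.file_io:

from tensorflow.python.lib.io import file_io

# Placeholder directory; a gs://bucket/prefix behaves identically.
data_dir = '/tmp/preprocessed'

if file_io.file_exists(data_dir + '/latest'):  # stands in for the removed file_exists()
  # stands in for the removed glob_files()
  train_files = file_io.get_matching_files(data_dir + '/train*.tfrecord.gz')
  # stands in for the removed open_local_or_gcs()
  with file_io.FileIO(data_dir + '/latest', 'r') as f:
    print(f.read().rstrip(), len(train_files))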
