Skip to content

Commit 45f1353

Browse files
nnegreyleahecole
authored and
Doug Mahugh
committed
automl: add vision object detection samples for atuoml ga (GoogleCloudPlatform#2614)
* automl: add vision object detection samples for atuoml ga * Update tests * update test resource file used * Consistently use double quotes * Move test imports to top of file * license year 2020 * Use centralized testing project for automl, improve comment with links to docs Co-authored-by: Leah E. Cole <[email protected]>
1 parent ee042e5 commit 45f1353

9 files changed

+350
-0
lines changed
2.25 MB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_dataset(project_id, display_name):
17+
"""Create a dataset."""
18+
# [START automl_vision_object_detection_create_dataset]
19+
from google.cloud import automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# display_name = "your_datasets_display_name"
24+
25+
client = automl.AutoMlClient()
26+
27+
# A resource that represents Google Cloud Platform location.
28+
project_location = client.location_path(project_id, "us-central1")
29+
metadata = automl.types.ImageObjectDetectionDatasetMetadata()
30+
dataset = automl.types.Dataset(
31+
display_name=display_name,
32+
image_object_detection_dataset_metadata=metadata,
33+
)
34+
35+
# Create a dataset with the dataset metadata in the region.
36+
response = client.create_dataset(project_location, dataset)
37+
38+
created_dataset = response.result()
39+
40+
# Display the dataset information
41+
print("Dataset name: {}".format(created_dataset.name))
42+
print("Dataset id: {}".format(created_dataset.name.split("/")[-1]))
43+
# [END automl_vision_object_detection_create_dataset]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
from google.cloud import automl
19+
import pytest
20+
21+
import vision_object_detection_create_dataset
22+
23+
24+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
25+
26+
27+
@pytest.mark.slow
28+
def test_vision_object_detection_create_dataset(capsys):
29+
# create dataset
30+
dataset_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
31+
vision_object_detection_create_dataset.create_dataset(
32+
PROJECT_ID, dataset_name
33+
)
34+
out, _ = capsys.readouterr()
35+
assert "Dataset id: " in out
36+
37+
# Delete the created dataset
38+
dataset_id = out.splitlines()[1].split()[2]
39+
client = automl.AutoMlClient()
40+
dataset_full_id = client.dataset_path(
41+
PROJECT_ID, "us-central1", dataset_id
42+
)
43+
response = client.delete_dataset(dataset_full_id)
44+
response.result()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_model(project_id, dataset_id, display_name):
17+
"""Create a model."""
18+
# [START automl_vision_object_detection_create_model]
19+
from google.cloud import automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# dataset_id = "YOUR_DATASET_ID"
24+
# display_name = "your_models_display_name"
25+
26+
client = automl.AutoMlClient()
27+
28+
# A resource that represents Google Cloud Platform location.
29+
project_location = client.location_path(project_id, "us-central1")
30+
# Leave model unset to use the default base model provided by Google
31+
# train_budget_milli_node_hours: The actual train_cost will be equal or
32+
# less than this value.
33+
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageobjectdetectionmodelmetadata
34+
metadata = automl.types.ImageObjectDetectionModelMetadata(
35+
train_budget_milli_node_hours=24000
36+
)
37+
model = automl.types.Model(
38+
display_name=display_name,
39+
dataset_id=dataset_id,
40+
image_object_detection_model_metadata=metadata,
41+
)
42+
43+
# Create a model with the model metadata in the region.
44+
response = client.create_model(project_location, model)
45+
46+
print("Training operation name: {}".format(response.operation.name))
47+
print("Training started...")
48+
# [END automl_vision_object_detection_create_model]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl
18+
import pytest
19+
20+
import vision_object_detection_create_model
21+
22+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
23+
DATASET_ID = os.environ["OBJECT_DETECTION_DATASET_ID"]
24+
25+
26+
@pytest.mark.slow
27+
def test_vision_object_detection_create_model(capsys):
28+
vision_object_detection_create_model.create_model(
29+
PROJECT_ID, DATASET_ID, "object_test_create_model"
30+
)
31+
out, _ = capsys.readouterr()
32+
assert "Training started" in out
33+
34+
# Cancel the operation
35+
operation_id = out.split("Training operation name: ")[1].split("\n")[0]
36+
client = automl.AutoMlClient()
37+
client.transport._operations_client.cancel_operation(operation_id)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def deploy_model(project_id, model_id):
17+
"""Deploy a model with a specified node count."""
18+
# [START automl_vision_object_detection_deploy_model_node_count]
19+
from google.cloud import automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# model_id = "YOUR_MODEL_ID"
24+
25+
client = automl.AutoMlClient()
26+
# Get the full path of the model.
27+
model_full_id = client.model_path(project_id, "us-central1", model_id)
28+
29+
# node count determines the number of nodes to deploy the model on.
30+
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageobjectdetectionmodeldeploymentmetadata
31+
metadata = automl.types.ImageObjectDetectionModelDeploymentMetadata(
32+
node_count=2
33+
)
34+
response = client.deploy_model(
35+
model_full_id,
36+
image_object_detection_model_deployment_metadata=metadata,
37+
)
38+
39+
print("Model deployment finished. {}".format(response.result()))
40+
# [END automl_vision_object_detection_deploy_model_node_count]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import pytest
18+
19+
import vision_object_detection_deploy_model_node_count
20+
21+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
22+
MODEL_ID = "0000000000000000000000"
23+
24+
25+
@pytest.mark.slow
26+
def test_object_detection_deploy_model_with_node_count(capsys):
27+
# As model deployment can take a long time, instead try to deploy a
28+
# nonexistent model and confirm that the model was not found, but other
29+
# elements of the request were valid.
30+
try:
31+
vision_object_detection_deploy_model_node_count.deploy_model(
32+
PROJECT_ID, MODEL_ID
33+
)
34+
out, _ = capsys.readouterr()
35+
assert "The model does not exist" in out
36+
except Exception as e:
37+
assert "The model does not exist" in e.message
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def predict(project_id, model_id, file_path):
17+
"""Predict."""
18+
# [START automl_vision_object_detection_predict]
19+
from google.cloud import automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# model_id = "YOUR_MODEL_ID"
24+
# file_path = "path_to_local_file.jpg"
25+
26+
prediction_client = automl.PredictionServiceClient()
27+
28+
# Get the full path of the model.
29+
model_full_id = prediction_client.model_path(
30+
project_id, "us-central1", model_id
31+
)
32+
33+
# Read the file.
34+
with open(file_path, "rb") as content_file:
35+
content = content_file.read()
36+
37+
image = automl.types.Image(image_bytes=content)
38+
payload = automl.types.ExamplePayload(image=image)
39+
40+
# params is additional domain-specific parameters.
41+
# score_threshold is used to filter the result
42+
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest
43+
params = {"score_threshold": "0.8"}
44+
45+
response = prediction_client.predict(model_full_id, payload, params)
46+
print("Prediction results:")
47+
for result in response.payload:
48+
print("Predicted class name: {}".format(result.display_name))
49+
print(
50+
"Predicted class score: {}".format(
51+
result.image_object_detection.score
52+
)
53+
)
54+
bounding_box = result.image_object_detection.bounding_box
55+
print("Normalized Vertices:")
56+
for vertex in bounding_box.normalized_vertices:
57+
print("\tX: {}, Y: {}".format(vertex.x, vertex.y))
58+
# [END automl_vision_object_detection_predict]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl
18+
import pytest
19+
20+
import vision_object_detection_predict
21+
22+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
23+
MODEL_ID = os.environ["OBJECT_DETECTION_MODEL_ID"]
24+
25+
26+
@pytest.fixture(scope="function")
27+
def verify_model_state():
28+
client = automl.AutoMlClient()
29+
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)
30+
31+
model = client.get_model(model_full_id)
32+
if model.deployment_state == automl.enums.Model.DeploymentState.UNDEPLOYED:
33+
# Deploy model if it is not deployed
34+
response = client.deploy_model(model_full_id)
35+
response.result()
36+
37+
38+
def test_vision_object_detection_predict(capsys, verify_model_state):
39+
verify_model_state
40+
file_path = "resources/salad.jpg"
41+
vision_object_detection_predict.predict(PROJECT_ID, MODEL_ID, file_path)
42+
out, _ = capsys.readouterr()
43+
assert "Predicted class name:" in out

0 commit comments

Comments
 (0)