-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathmodel_testing.py
114 lines (94 loc) · 3.38 KB
/
model_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
# and limitations under the License.
#
import boto3
import botocore
from PIL import Image
from io import BytesIO
import base64
import json
from diffusers.utils import export_to_video
from time import sleep
import os
# AWS clients shared by the functions and the script body in this module.
s3_client = boto3.client("s3")
s3_resource = boto3.resource("s3")
sagemaker_client = boto3.client("sagemaker-runtime")

# Configuration
# S3_BUCKET and ENDPOINT_NAME may be overridden through the environment
# (SVD_S3_BUCKET / SVD_ENDPOINT_NAME) so the script can target a deployment
# without editing this file; the defaults below are used otherwise.
S3_BUCKET = os.environ.get("SVD_S3_BUCKET", "XXXXX")  # replace "XXXXX" with your bucket name
S3_BUCKET_KEY = "svd-hf-1"  # key prefix under which the async output folder lives
S3_BUCKET_OUTPUT_KEY = "output"  # output folder in the S3 bucket (configured in the AsyncInference construct settings)
ENDPOINT_NAME = os.environ.get("SVD_ENDPOINT_NAME", "svdendpoint")  # endpoint name configured in the construct
FRAME_PER_SECOND = 7  # frame rate of the exported video
INPUT_IMAGE_URL = 'https://raw.githubusercontent.com/Stability-AI/generative-models/main/assets/test_image.png'  # sent as the "inputs" field of the request
# Local File Paths
REQUEST_PAYLOAD = "input.json"  # request body written locally, then uploaded to S3
RESPONSE_PAYLOAD = "output.json"  # downloaded async-inference result
def decode_base64_image(image_string):
    """Turn a base64-encoded image string into a PIL Image.

    The decoded bytes are wrapped in an in-memory buffer so that PIL can read
    them as if they were a file.
    """
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))
def upload_file(file_path, key="input.json", bucket=None):
    """Upload a local JSON file to S3.

    Args:
        file_path: Path of the local file to upload.
        key: Destination object key. Defaults to "input.json", which is the
            key invoke_async_endpoint() passes to the endpoint as its input.
        bucket: Destination bucket. When None, S3_BUCKET is read at call
            time (not at definition time), so later changes are honoured.

    The object is always stored with ContentType "application/json".
    """
    s3_client.upload_file(
        Filename=file_path,
        Bucket=S3_BUCKET if bucket is None else bucket,
        Key=key,
        ExtraArgs={"ContentType": "application/json"},
    )
def invoke_async_endpoint():
    """Submit one asynchronous inference request for INPUT_IMAGE_URL.

    The request body {"inputs": INPUT_IMAGE_URL} is written to the local
    REQUEST_PAYLOAD file and uploaded to S3 with upload_file(). The SageMaker
    async endpoint is then pointed at that S3 object.

    Returns:
        The "OutputLocation" value from the invoke_endpoint_async response,
        or None if the response does not contain that key. The script body
        treats it as an S3 URI whose last path segment is the result object.
    """
    payload = {"inputs": INPUT_IMAGE_URL}
    with open(REQUEST_PAYLOAD, "w", encoding="utf-8") as request_file:
        json.dump(payload, request_file, ensure_ascii=False, indent=4)

    upload_file(REQUEST_PAYLOAD)

    # The InputLocation key must match the key upload_file() writes to.
    async_response = sagemaker_client.invoke_endpoint_async(
        EndpointName=ENDPOINT_NAME,
        InputLocation=f"s3://{S3_BUCKET}/input.json",
        InvocationTimeoutSeconds=3600,  # one hour
    )
    return async_response.get("OutputLocation")
# ---------------------------------------------------------------------------
# Script body: submit the request, wait for the async result to appear in S3,
# then decode the returned frames and write them out as a video.
# ---------------------------------------------------------------------------

# Polling schedule for the async output object: one immediate check, then up
# to OUTPUT_POLL_RETRIES further checks spaced OUTPUT_POLL_DELAY_SECONDS apart.
OUTPUT_POLL_RETRIES = 10
OUTPUT_POLL_DELAY_SECONDS = 30


def _output_object_exists(bucket, key):
    """Return True if s3://<bucket>/<key> exists, False if S3 answers 404.

    Any other ClientError (for example 403 Forbidden) is re-raised, because
    waiting longer will not fix it.
    """
    try:
        s3_resource.Object(bucket, key).load()
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] == "404":
            return False
        raise
    return True


def _wait_for_output_object(bucket, key):
    """Block until s3://<bucket>/<key> exists.

    Raises:
        TimeoutError: if the object is still missing after all retries.
    """
    if _output_object_exists(bucket, key):
        return
    for retries_left in range(OUTPUT_POLL_RETRIES, 0, -1):
        print(f"Waiting for output object: {key}")
        print(f"Retries left: {retries_left}")
        sleep(OUTPUT_POLL_DELAY_SECONDS)
        if _output_object_exists(bucket, key):
            return
    raise TimeoutError(
        f"Output object s3://{bucket}/{key} not found after "
        f"{OUTPUT_POLL_RETRIES} retries"
    )


try:
    invoke_response = invoke_async_endpoint()
except Exception as e:  # top-level boundary: report the failure and exit non-zero
    print(e)
    raise SystemExit(1)

if not invoke_response:
    # invoke_async_endpoint() returns None when the response had no OutputLocation.
    print("Async invocation returned no OutputLocation")
    raise SystemExit(1)

output_object_name = invoke_response.split("/")[-1]
# S3 keys always use "/" separators; os.path.join would emit "\" on Windows.
output_key = "/".join((S3_BUCKET_KEY, S3_BUCKET_OUTPUT_KEY, output_object_name))
print(f"Output Key: {output_key}")

_wait_for_output_object(S3_BUCKET, output_key)

# download output object from s3
s3_client.download_file(
    Bucket=S3_BUCKET, Key=output_key, Filename=RESPONSE_PAYLOAD
)
with open(RESPONSE_PAYLOAD, encoding="utf-8") as f:
    frames = json.load(f)
decoded_images = [decode_base64_image(image) for image in frames["frames"]]
export_to_video(decoded_images, f"{output_object_name}.mp4", fps=FRAME_PER_SECOND)