diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index ee422e1f..273448eb 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -11,6 +11,7 @@ import logging import time import mimetypes +import magic import grpc from argparse import ArgumentParser, Namespace @@ -86,7 +87,8 @@ def process_artifacts_from_answers( for artifact in resp.artifacts: artifact_start = time.time() if artifact.type == generation.ARTIFACT_IMAGE: - ext = mimetypes.guess_extension(artifact.mime) + mime_type = magic.from_buffer(artifact.binary, mime=True) + ext = mime_type.split("/")[1] contents = artifact.binary elif artifact.type == generation.ARTIFACT_CLASSIFICATIONS: ext = ".pb.json"