Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions src/replicate/resources/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ def delete(

def get(
self,
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
*,
model_owner: str,
model_name: str,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -384,15 +385,39 @@ def get(
The `latest_version` object is the model's most recently pushed
[version](#models.versions.get).

Supports both legacy and new formats:
- Legacy: models.get("owner/name")
- New: models.get(model_owner="owner", model_name="name")

Args:
model_or_owner: Legacy format string "owner/name" (positional argument)
model_owner: Model owner (keyword argument)
model_name: Model name (keyword argument)
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
# Handle legacy format: models.get("owner/name")
if model_or_owner is not NOT_GIVEN:
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
raise ValueError(
"Cannot specify both positional and keyword arguments. "
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
)

# Parse the owner/name format
if "/" not in model_or_owner:
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")

parts = model_or_owner.split("/", 1)
model_owner = parts[0]
model_name = parts[1]

# Validate required parameters
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
raise ValueError("model_owner and model_name are required")

if not model_owner:
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
if not model_name:
Expand Down Expand Up @@ -698,9 +723,10 @@ async def delete(

async def get(
self,
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
*,
model_owner: str,
model_name: str,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -783,15 +809,39 @@ async def get(
The `latest_version` object is the model's most recently pushed
[version](#models.versions.get).

Supports both legacy and new formats:
- Legacy: models.get("owner/name")
- New: models.get(model_owner="owner", model_name="name")

Args:
model_or_owner: Legacy format string "owner/name" (positional argument)
model_owner: Model owner (keyword argument)
model_name: Model name (keyword argument)
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
# Handle legacy format: models.get("owner/name")
if model_or_owner is not NOT_GIVEN:
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
raise ValueError(
"Cannot specify both positional and keyword arguments. "
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
)

# Parse the owner/name format
if "/" not in model_or_owner:
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")

parts = model_or_owner.split("/", 1)
model_owner = parts[0]
model_name = parts[1]

# Validate required parameters
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
raise ValueError("model_owner and model_name are required")

if not model_owner:
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
if not model_name:
Expand Down
156 changes: 156 additions & 0 deletions tests/test_models_backward_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Tests for backward compatibility in models.get() method.
"""

from unittest.mock import Mock, patch

import pytest

from replicate import Replicate, AsyncReplicate
from replicate.types.model_get_response import ModelGetResponse


@pytest.fixture
def mock_model_response():
"""Mock response for model.get requests."""
return ModelGetResponse(
url="https://replicate.com/stability-ai/stable-diffusion",
owner="stability-ai",
name="stable-diffusion",
description="A model for generating images from text prompts",
visibility="public",
github_url=None,
paper_url=None,
license_url=None,
run_count=0,
cover_image_url=None,
default_example=None,
latest_version=None,
)


class TestModelGetBackwardCompatibility:
"""Test backward compatibility for models.get() method."""

@pytest.fixture
def client(self):
"""Create a Replicate client with mocked token."""
return Replicate(bearer_token="test-token")

def test_legacy_format_owner_name(self, client, mock_model_response):
"""Test legacy format: models.get('owner/name')."""
# Mock the underlying _get method
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
# Call with legacy format
result = client.models.get("stability-ai/stable-diffusion")

# Verify underlying method was called with correct parameters
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
assert result == mock_model_response

def test_new_format_keyword_args(self, client, mock_model_response):
"""Test new format: models.get(model_owner='owner', model_name='name')."""
# Mock the underlying _get method
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
# Call with new format
result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion")

# Verify underlying method was called with correct parameters
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
assert result == mock_model_response

def test_legacy_format_with_extra_params(self, client, mock_model_response):
"""Test legacy format with extra parameters."""
# Mock the underlying _get method
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
# Call with legacy format and extra parameters
result = client.models.get(
"stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0
)

# Verify underlying method was called
mock_get.assert_called_once()
assert result == mock_model_response

def test_error_mixed_formats(self, client):
"""Test error when mixing legacy and new formats."""
with pytest.raises(ValueError) as exc_info:
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")

assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)

def test_error_invalid_legacy_format(self, client):
"""Test error for invalid legacy format (no slash)."""
with pytest.raises(ValueError) as exc_info:
client.models.get("invalid-format")

assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
assert "Expected format: 'owner/name'" in str(exc_info.value)

def test_error_missing_parameters(self, client):
"""Test error when no parameters are provided."""
with pytest.raises(ValueError) as exc_info:
client.models.get()

assert "model_owner and model_name are required" in str(exc_info.value)

def test_legacy_format_with_complex_names(self, client, mock_model_response):
"""Test legacy format with complex owner/model names."""
# Mock the underlying _get method
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
# Test with hyphenated names and numbers
result = client.models.get("black-forest-labs/flux-1.1-pro")

# Verify parsing
mock_get.assert_called_once_with("/models/black-forest-labs/flux-1.1-pro", options=Mock())

def test_legacy_format_multiple_slashes(self, client):
"""Test legacy format with multiple slashes (should split on first slash only)."""
# Mock the underlying _get method
with patch.object(client.models, "_get", return_value=Mock()) as mock_get:
# This should work - split on first slash only
client.models.get("owner/name/with/slashes")

# Verify it was parsed correctly
mock_get.assert_called_once_with("/models/owner/name/with/slashes", options=Mock())


class TestAsyncModelGetBackwardCompatibility:
"""Test backward compatibility for async models.get() method."""

@pytest.fixture
async def async_client(self):
"""Create an async Replicate client with mocked token."""
return AsyncReplicate(bearer_token="test-token")

@pytest.mark.asyncio
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
"""Test async legacy format: models.get('owner/name')."""
# Mock the underlying _get method
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
# Call with legacy format
result = await async_client.models.get("stability-ai/stable-diffusion")

# Verify underlying method was called with correct parameters
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
assert result == mock_model_response

@pytest.mark.asyncio
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
# Mock the underlying _get method
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
# Call with new format
result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion")

# Verify underlying method was called with correct parameters
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
assert result == mock_model_response

@pytest.mark.asyncio
async def test_async_error_mixed_formats(self, async_client):
"""Test async error when mixing legacy and new formats."""
with pytest.raises(ValueError) as exc_info:
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")

assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
Loading