Skip to content

Commit 262217a

Browse files
committed
Stub base model class and CLI download-model command
Why these changes are being introduced:
For this application to create embeddings, it needs some structure to handle the downloading, loading, and use of particular models. Additionally, the application should provide opinionated functionality to download and zip up a model, either for use within a Docker image or for local testing.

How this addresses that need:
* Class 'BaseEmbeddingModel' is created. This base class will be extended by the actual models we intend to use, and it expects every child class to define methods for downloading, loading, and using the model.
* Stubs a CLI command 'download-model', which we will use very soon to build a Docker image that contains the model weights.

Side effects of this change:
* None at this time.

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-112
* https://mitlibraries.atlassian.net/browse/USE-113
1 parent 60f5faf commit 262217a

File tree

10 files changed

+283
-1
lines changed

10 files changed

+283
-1
lines changed

embeddings/cli.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import logging
22
import time
33
from datetime import timedelta
4+
from pathlib import Path
45

56
import click
67

78
from embeddings.config import configure_logger, configure_sentry
9+
from embeddings.models.registry import get_model_class
810

911
logger = logging.getLogger(__name__)
1012

@@ -46,6 +48,54 @@ def ping() -> None:
4648
click.echo("pong")
4749

4850

51+
@main.command()
@click.option(
    "--model-uri",
    required=True,
    help="HuggingFace model URI (e.g., 'org/model-name')",
)
@click.option(
    "--output",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for zipped model (e.g., '/path/to/model.zip')",
)
def download_model(model_uri: str, output: Path) -> None:
    """Download a model from HuggingFace and save as zip file."""
    # NOTE: the docstring above doubles as the command's --help text.

    # Resolve the model class first so an unregistered URI is reported as a
    # regular CLI error before any model object is created.
    try:
        model_cls = get_model_class(model_uri)
    except ValueError as exc:
        logger.exception("Unknown model URI: %s", model_uri)
        raise click.ClickException(str(exc)) from exc

    logger.info("Downloading model: %s", model_uri)
    embedding_model = model_cls(model_uri)

    # Every failure is logged with its traceback and re-raised as a
    # ClickException chained to the original error.
    try:
        saved_to = embedding_model.download(output)
        logger.info("Model downloaded successfully to: %s", saved_to)
        click.echo(f"Model saved to: {saved_to}")
    except NotImplementedError as exc:
        # Handled before the generic case so the message is passed through
        # without the "Download failed" prefix.
        logger.exception("Download not yet implemented for model: %s", model_uri)
        raise click.ClickException(str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to download model: %s", model_uri)
        failure = f"Download failed: {exc}"
        raise click.ClickException(failure) from exc
85+
86+
87+
@main.command()
@click.option(
    "--model-uri",
    required=True,
    help="HuggingFace model URI (e.g., 'org/model-name')",
)
def create_embeddings(model_uri: str) -> None:
    """Create embeddings."""
    # Stub command: it records the invocation and then signals that the
    # embedding logic has not been written yet.
    logger.info("create-embeddings command called with model: %s", model_uri)
    raise NotImplementedError
97+
98+
4999
if __name__ == "__main__": # pragma: no cover
50100
logger = logging.getLogger("embeddings.main")
51101
main()

embeddings/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

embeddings/models/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Base class for embedding models."""
2+
3+
from abc import ABC, abstractmethod
4+
from pathlib import Path
5+
6+
7+
class BaseEmbeddingModel(ABC):
    """Abstract base class for embedding models.

    Subclasses must implement every abstract method; the base class itself
    cannot be instantiated.

    Args:
        model_uri: HuggingFace model identifier (e.g., 'org/model-name').
    """

    def __init__(self, model_uri: str) -> None:
        self.model_uri = model_uri

    @abstractmethod
    def download(self, output_path: Path) -> Path:
        """Download and prepare model, saving to output_path.

        Args:
            output_path: Path where the model zip should be saved.

        Returns:
            Path the model zip was saved to.
        """
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""OpenSearch Neural Sparse Doc v3 GTE model."""
2+
3+
import logging
4+
from pathlib import Path
5+
6+
from embeddings.models.base import BaseEmbeddingModel
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
    """OpenSearch Neural Sparse Encoding Doc v3 GTE model.

    HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
    """

    def download(self, output_path: Path) -> Path:
        """Download and prepare model, saving to output_path.

        Not implemented yet: the request is logged and NotImplementedError
        is raised.

        Args:
            output_path: Path where the model zip should be saved.

        Raises:
            NotImplementedError: Always, until the download logic is written.
        """
        # Lazy %-style arguments: the message is only formatted when an INFO
        # handler actually emits it, and the text matches the previous output.
        logger.info(
            "Downloading model: %s, saving to: %s.", self.model_uri, output_path
        )
        raise NotImplementedError

embeddings/models/registry.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Registry mapping model URIs to model classes."""
2+
3+
from embeddings.models.base import BaseEmbeddingModel
4+
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE
5+
6+
MODEL_REGISTRY: dict[str, type[BaseEmbeddingModel]] = {
7+
"opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte": (
8+
OSNeuralSparseDocV3GTE
9+
),
10+
}
11+
12+
13+
def get_model_class(model_uri: str) -> type[BaseEmbeddingModel]:
    """Get model class for given URI.

    Args:
        model_uri: HuggingFace model identifier.

    Returns:
        Model class for the given URI.

    Raises:
        ValueError: If model_uri is not in MODEL_REGISTRY; the message lists
            the registered URIs in sorted order.
    """
    try:
        return MODEL_REGISTRY[model_uri]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        msg = f"Unknown model URI: {model_uri}. Available models: {available}"
        # "from None" hides the internal KeyError so callers only see the
        # ValueError they are expected to handle.
        raise ValueError(msg) from None

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ requires-python = ">=3.12"
99

1010
dependencies = [
1111
"click>=8.2.1",
12+
"huggingface-hub>=0.26.0",
1213
"sentry-sdk>=2.34.1",
1314
"timdex-dataset-api",
1415
]
@@ -56,7 +57,8 @@ ignore = [
5657
"D101",
5758
"D102",
5859
"D103",
59-
"D104",
60+
"D104",
61+
"G004",
6062
"PLR0912",
6163
"PLR0913",
6264
"PLR0915",

tests/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import zipfile
2+
from pathlib import Path
3+
14
import pytest
25
from click.testing import CliRunner
36

7+
from embeddings.models.base import BaseEmbeddingModel
8+
49

510
@pytest.fixture(autouse=True)
611
def _test_env(monkeypatch):
@@ -11,3 +16,22 @@ def _test_env(monkeypatch):
1116
@pytest.fixture
def runner():
    """Provide a Click CliRunner for invoking CLI commands in tests."""
    cli_runner = CliRunner()
    return cli_runner
19+
20+
21+
class MockEmbeddingModel(BaseEmbeddingModel):
    """Test double for BaseEmbeddingModel that never touches external APIs."""

    def download(self, output_path: Path) -> Path:
        """Write a small placeholder model zip to output_path and return it."""
        # Archive members are written in this order: (name, content).
        members = (
            ("config.json", '{"model": "mock", "vocab_size": 30000}'),
            ("pytorch_model.bin", b"fake model weights"),
            ("tokenizer.json", '{"version": "1.0"}'),
        )
        # Make sure the destination directory exists before opening the zip.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output_path, "w") as archive:
            for member_name, payload in members:
                archive.writestr(member_name, payload)
        return output_path
32+
33+
34+
@pytest.fixture
def mock_model():
    """Return a MockEmbeddingModel constructed with URI 'test/mock-model'."""
    model = MockEmbeddingModel("test/mock-model")
    return model

tests/test_cli.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,30 @@ def test_cli_debug_logging(caplog, runner):
1414
assert "Logger 'root' configured with level=DEBUG" in caplog.text
1515
assert "pong" in caplog.text
1616
assert "pong" in result.output
17+
18+
19+
def test_download_model_unknown_uri(runner):
    """An unregistered model URI makes download-model exit with an error."""
    cli_args = [
        "download-model",
        "--model-uri",
        "unknown/model",
        "--output",
        "out.zip",
    ]
    outcome = runner.invoke(main, cli_args)
    assert outcome.exit_code != 0
    assert "Unknown model URI" in outcome.output
24+
assert "Unknown model URI" in result.output
25+
26+
27+
def test_download_model_not_implemented(caplog, runner):
    """The registered model logs the download attempt, then the command fails."""
    model_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    caplog.set_level("INFO")

    cli_args = ["download-model", "--model-uri", model_uri, "--output", "out.zip"]
    outcome = runner.invoke(main, cli_args)

    expected_log = f"Downloading model: {model_uri}, saving to: out.zip."
    assert expected_log in caplog.text
    assert outcome.exit_code != 0

tests/test_models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import zipfile
2+
3+
import pytest
4+
5+
from embeddings.models.registry import MODEL_REGISTRY, get_model_class
6+
7+
8+
def test_mock_model_instantiation(mock_model):
    """The fixture model exposes the URI it was constructed with."""
    expected_uri = "test/mock-model"
    assert mock_model.model_uri == expected_uri
10+
11+
12+
def test_mock_model_download_creates_zip(mock_model, tmp_path):
    """download() returns the target path and leaves a valid zip file there."""
    target = tmp_path / "test_model.zip"

    returned = mock_model.download(target)

    assert returned == target
    assert target.exists()
    assert zipfile.is_zipfile(target)
19+
20+
21+
def test_mock_model_download_contains_expected_files(mock_model, tmp_path):
    """The archive written by download() contains the three expected members."""
    target = tmp_path / "test_model.zip"
    mock_model.download(target)

    with zipfile.ZipFile(target, "r") as archive:
        members = archive.namelist()

    for expected in ("config.json", "pytorch_model.bin", "tokenizer.json"):
        assert expected in members
30+
31+
32+
def test_registry_contains_opensearch_model():
    """The OpenSearch neural sparse model URI is a key of MODEL_REGISTRY."""
    opensearch_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    assert opensearch_uri in MODEL_REGISTRY
37+
38+
39+
def test_get_model_class_returns_correct_class():
    """Looking up the OpenSearch URI yields the OSNeuralSparseDocV3GTE class."""
    opensearch_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    resolved = get_model_class(opensearch_uri)
    assert resolved.__name__ == "OSNeuralSparseDocV3GTE"
44+
45+
46+
def test_get_model_class_raises_for_unknown_uri():
    """An unregistered URI raises ValueError mentioning 'Unknown model URI'."""
    unregistered_uri = "unknown/model-uri"
    with pytest.raises(ValueError, match="Unknown model URI"):
        get_model_class(unregistered_uri)

uv.lock

Lines changed: 57 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)