|
5 | 5 | from agno.utils.log import logger
|
6 | 6 |
|
7 | 7 | try:
|
8 |
| - import pkg_resources |
| 8 | + from ollama import Client as OllamaClient |
| 9 | + import importlib.metadata as metadata |
9 | 10 | from packaging import version
|
10 | 11 |
|
11 |
| - ollama_version = pkg_resources.get_distribution("ollama").version |
12 |
| - if version.parse(ollama_version).major == 0 and version.parse(ollama_version).minor < 3: |
| 12 | + # Get installed Ollama version |
| 13 | + ollama_version = metadata.version("ollama") |
| 14 | + |
| 15 | + # Check version compatibility (requires v0.3.x or higher) |
| 16 | + parsed_version = version.parse(ollama_version) |
| 17 | + if parsed_version.major == 0 and parsed_version.minor < 3: |
13 | 18 | import warnings
|
| 19 | + warnings.warn("Only Ollama v0.3.x and above are supported", UserWarning) |
| 20 | + raise RuntimeError("Incompatible Ollama version detected") |
14 | 21 |
|
15 |
| - warnings.warn( |
16 |
| - "We only support Ollama v0.3.x and above.", |
17 |
| - UserWarning, |
18 |
| - ) |
19 |
| - raise RuntimeError("Incompatible Ollama version detected. Execution halted.") |
| 22 | +except ImportError as e: |
| 23 | + # Handle different import error scenarios |
| 24 | + if "ollama" in str(e): |
| 25 | + raise ImportError( |
| 26 | + "Ollama not installed. Install with `pip install ollama`" |
| 27 | + ) from e |
| 28 | + else: |
| 29 | + raise ImportError( |
| 30 | + "Missing dependencies. Install with `pip install packaging importlib-metadata`" |
| 31 | + ) from e |
20 | 32 |
|
21 |
| - from ollama import Client as OllamaClient |
22 |
| -except (ModuleNotFoundError, ImportError): |
23 |
| - raise ImportError("`ollama` not installed. Please install using `pip install ollama`") |
| 33 | +except Exception as e: |
| 34 | + # Catch-all for unexpected errors |
| 35 | + print(f"An unexpected error occurred: {e}") |
24 | 36 |
|
25 | 37 |
|
26 | 38 | @dataclass
|
@@ -53,14 +65,23 @@ def _response(self, text: str) -> Dict[str, Any]:
|
53 | 65 | if self.options is not None:
|
54 | 66 | kwargs["options"] = self.options
|
55 | 67 |
|
56 |
| - return self.client.embed(input=text, model=self.id, **kwargs) # type: ignore |
| 68 | + response = self.client.embed(input=text, model=self.id, **kwargs) |
| 69 | + if response and "embeddings" in response: |
| 70 | + embeddings = response["embeddings"] |
| 71 | + if isinstance(embeddings, list) and len(embeddings) > 0 and isinstance(embeddings[0], list): |
| 72 | + return {"embeddings": embeddings[0]} # Use the first element |
| 73 | + elif isinstance(embeddings, list) and all(isinstance(x, (int, float)) for x in embeddings): |
| 74 | + return {"embeddings": embeddings} # Return as-is if already flat |
| 75 | + return {"embeddings": []} # Return an empty list if no valid embedding is found |
57 | 76 |
|
58 | 77 | def get_embedding(self, text: str) -> List[float]:
|
59 | 78 | try:
|
60 | 79 | response = self._response(text=text)
|
61 |
| - if response is None: |
| 80 | + embedding = response.get("embeddings", []) |
| 81 | + if len(embedding) != self.dimensions: |
| 82 | + logger.warning(f"Expected embedding dimension {self.dimensions}, but got {len(embedding)}") |
62 | 83 | return []
|
63 |
| - return response.get("embeddings", []) |
| 84 | + return embedding |
64 | 85 | except Exception as e:
|
65 | 86 | logger.warning(e)
|
66 | 87 | return []
|
|
0 commit comments