Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions llama_index/vector_stores/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast

from packaging import version
from pkg_resources import get_distribution

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
Expand Down Expand Up @@ -69,7 +72,8 @@ def _transform_pinecone_filter_operator(operator: str) -> str:


def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
"""Build a list of sparse dictionaries from a batch of input_ids.
"""
Build a list of sparse dictionaries from a batch of input_ids.

NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

Expand All @@ -93,7 +97,8 @@ def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
def generate_sparse_vectors(
context_batch: List[str], tokenizer: Callable
) -> List[Dict[str, Any]]:
"""Generate sparse vectors from a batch of contexts.
"""
Generate sparse vectors from a batch of contexts.

NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

Expand All @@ -105,7 +110,8 @@ def generate_sparse_vectors(


def get_default_tokenizer() -> Callable:
"""Get default tokenizer.
"""
Get default tokenizer.

NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

Expand Down Expand Up @@ -157,7 +163,8 @@ def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:


class PineconeVectorStore(BasePydanticVectorStore):
"""Pinecone Vector Store.
"""
Pinecone Vector Store.

In this vector store, embeddings and docs are stored within a
Pinecone index.
Expand Down Expand Up @@ -217,14 +224,24 @@ def __init__(
if pinecone_index is not None:
self._pinecone_index = cast(pinecone.Index, pinecone_index)
else:
if index_name is None or environment is None:
raise ValueError(
"Must specify index_name and environment "
"if not directly passing in client."
)
pinecone_client_version = get_distribution("pinecone-client").version

if version.parse(pinecone_client_version) >= version.parse("3.0.0"):
if index_name is None:
raise ValueError(
"Must specify index_name if not directly passing in client."
)
pinecone_instance = pinecone.Pinecone(api_key=api_key)
self._pinecone_index = pinecone_instance.Index(index_name)
else:
if index_name is None or environment is None:
raise ValueError(
"Must specify index_name and environment "
"if not directly passing in client."
)

pinecone.init(api_key=api_key, environment=environment)
self._pinecone_index = pinecone.Index(index_name)
pinecone.init(api_key=api_key, environment=environment)
self._pinecone_index = pinecone.Index(index_name)

insert_kwargs = insert_kwargs or {}

Expand Down Expand Up @@ -265,8 +282,14 @@ def from_params(
except ImportError:
raise ImportError(import_err_msg)

pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)
pinecone_client_version = get_distribution("pinecone-client").version

if version.parse(pinecone_client_version) >= version.parse("3.0.0"):
pinecone_instance = pinecone.Pinecone(api_key=api_key)
pinecone_index = pinecone_instance.Index(index_name)
else:
pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)

return cls(
pinecone_index=pinecone_index,
Expand All @@ -286,14 +309,15 @@ def from_params(

@classmethod
def class_name(cls) -> str:
return "PinconeVectorStore"
return "PineconeVectorStore"

def add(
self,
nodes: List[BaseNode],
**add_kwargs: Any,
) -> List[str]:
"""Add nodes to index.
"""
Add nodes to index.

Args:
nodes: List[BaseNode]: list of nodes with embeddings
Expand Down Expand Up @@ -353,7 +377,8 @@ def client(self) -> Any:
return self._pinecone_index

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
"""
Query index for top k most similar nodes.

Args:
query_embedding (List[float]): query embedding
Expand Down