Skip to content

Commit 00d1151

Browse files
committed
initial integrated inf impl
1 parent afc4345 commit 00d1151

File tree

6 files changed

+245
-16
lines changed

6 files changed

+245
-16
lines changed

src/models/cloud.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use serde::{Deserialize, Serialize};
2+
use strum::{Display, EnumString};
3+
4+
/// The public cloud where you would like your index hosted.
5+
#[derive(
6+
Debug, Default, Clone, Copy, PartialEq, Eq, Display, EnumString, Serialize, Deserialize,
7+
)]
8+
#[serde(rename_all = "lowercase")]
9+
#[strum(serialize_all = "lowercase")]
10+
pub enum Cloud {
11+
/// GCP
12+
#[default]
13+
Gcp,
14+
/// AWS
15+
Aws,
16+
/// Azure
17+
Azure,
18+
}
19+
20+
impl From<crate::openapi::models::serverless_spec::Cloud> for Cloud {
21+
fn from(cloud: crate::openapi::models::serverless_spec::Cloud) -> Self {
22+
match cloud {
23+
crate::openapi::models::serverless_spec::Cloud::Gcp => Cloud::Gcp,
24+
crate::openapi::models::serverless_spec::Cloud::Aws => Cloud::Aws,
25+
crate::openapi::models::serverless_spec::Cloud::Azure => Cloud::Azure,
26+
}
27+
}
28+
}
29+
30+
impl From<Cloud> for crate::openapi::models::serverless_spec::Cloud {
31+
fn from(cloud: Cloud) -> Self {
32+
match cloud {
33+
Cloud::Gcp => crate::openapi::models::serverless_spec::Cloud::Gcp,
34+
Cloud::Aws => crate::openapi::models::serverless_spec::Cloud::Aws,
35+
Cloud::Azure => crate::openapi::models::serverless_spec::Cloud::Azure,
36+
}
37+
}
38+
}
39+
40+
impl From<crate::openapi::models::create_index_for_model_request::Cloud> for Cloud {
41+
fn from(cloud: crate::openapi::models::create_index_for_model_request::Cloud) -> Self {
42+
match cloud {
43+
crate::openapi::models::create_index_for_model_request::Cloud::Gcp => Cloud::Gcp,
44+
crate::openapi::models::create_index_for_model_request::Cloud::Aws => Cloud::Aws,
45+
crate::openapi::models::create_index_for_model_request::Cloud::Azure => Cloud::Azure,
46+
}
47+
}
48+
}
49+
50+
impl From<Cloud> for crate::openapi::models::create_index_for_model_request::Cloud {
51+
fn from(cloud: Cloud) -> Self {
52+
match cloud {
53+
Cloud::Gcp => crate::openapi::models::create_index_for_model_request::Cloud::Gcp,
54+
Cloud::Aws => crate::openapi::models::create_index_for_model_request::Cloud::Aws,
55+
Cloud::Azure => crate::openapi::models::create_index_for_model_request::Cloud::Azure,
56+
}
57+
}
58+
}

src/models/index_model.rs

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
use super::{DeletionProtection, IndexModelSpec, IndexModelStatus, Metric};
1+
use super::{DeletionProtection, IndexModelSpec, IndexModelStatus, Metric, VectorType};
22
use crate::openapi::models::index_model::IndexModel as OpenApiIndexModel;
3+
use crate::openapi::models::{CreateIndexForModelRequestEmbed, ModelIndexEmbed};
4+
use serde_with::serde_derive::Serialize;
5+
use std::collections::HashMap;
36

47
/// IndexModel : The IndexModel describes the configuration and status of a Pinecone index.
58
#[derive(Clone, Default, Debug, PartialEq)]
@@ -18,6 +21,12 @@ pub struct IndexModel {
1821
pub spec: IndexModelSpec,
1922
/// Index model specs
2023
pub status: IndexModelStatus,
24+
/// Index tags
25+
pub tags: Option<HashMap<String, String>>,
26+
/// Index embedding configuration
27+
pub embed: Option<ModelIndexEmbed>,
28+
/// Index vector type
29+
pub vector_type: VectorType,
2130
}
2231

2332
impl From<OpenApiIndexModel> for IndexModel {
@@ -30,6 +39,84 @@ impl From<OpenApiIndexModel> for IndexModel {
3039
deletion_protection: openapi_index_model.deletion_protection,
3140
spec: *openapi_index_model.spec,
3241
status: *openapi_index_model.status,
42+
tags: openapi_index_model.tags,
43+
embed: openapi_index_model.embed.map(|emb| *emb),
44+
vector_type: openapi_index_model.vector_type,
45+
}
46+
}
47+
}
48+
49+
/// A field mapping entry by type.
50+
#[derive(Clone, Debug)]
51+
pub enum FieldMapEntry {
52+
/// The name of the text field from your document model that is embedded.
53+
TextField(String),
54+
}
55+
56+
/// A model parameter value of a specific type.
57+
#[derive(Clone, Debug, Serialize)]
58+
pub enum ModelParameterValue {
59+
/// A string value type
60+
StringVal(String),
61+
/// An integer value type
62+
IntVal(i32),
63+
/// A floating point value type
64+
FloatVal(f32),
65+
/// A boolean value type.
66+
BoolVal(bool),
67+
}
68+
69+
/// Configuration options for the index with integrated embedding.
70+
#[derive(Clone, Debug)]
71+
pub struct CreateIndexForModelOptions {
72+
/// The name of the embedding model to use for the index.
73+
pub model: String,
74+
/// Identifies the name of the field from your document model that will be embedded. (Only one
75+
/// field is supported for now.)
76+
pub field_map: Vec<FieldMapEntry>,
77+
/// The distance metric to be used for similarity search. You can use 'euclidean', 'cosine', or 'dotproduct'. If not specified, the metric will be defaulted according to the model. Cannot be updated once set.
78+
pub metric: Option<Metric>,
79+
/// The desired vector dimension, if supported by the model.
80+
pub dimension: Option<i32>,
81+
/// The read parameters for the embedding model.
82+
pub read_parameters: Option<HashMap<String, ModelParameterValue>>,
83+
/// The write parameters for the embedding model.
84+
pub write_parameters: Option<HashMap<String, ModelParameterValue>>,
85+
}
86+
87+
impl From<CreateIndexForModelOptions> for CreateIndexForModelRequestEmbed {
88+
fn from(options: CreateIndexForModelOptions) -> Self {
89+
let field_map = options
90+
.field_map
91+
.into_iter()
92+
.map(|entry| match entry {
93+
FieldMapEntry::TextField(field_name) => {
94+
("text", serde_json::Value::String(field_name))
95+
}
96+
})
97+
.collect();
98+
99+
let read_parameters = options.read_parameters.map(|params| {
100+
params
101+
.into_iter()
102+
.map(|(key, value)| (key, serde_json::to_value(value).unwrap()))
103+
.collect()
104+
});
105+
106+
let write_parameters = options.write_parameters.map(|params| {
107+
params
108+
.into_iter()
109+
.map(|(key, value)| (key, serde_json::to_value(value).unwrap()))
110+
.collect()
111+
});
112+
113+
CreateIndexForModelRequestEmbed {
114+
model: options.model,
115+
field_map,
116+
metric: options.metric.map(|m| m.into()),
117+
read_parameters,
118+
write_parameters,
119+
dimension: options.dimension,
33120
}
34121
}
35122
}

src/models/metric.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,35 @@ impl From<Metric> for ResponseMetric {
5252
}
5353
}
5454
}
55+
56+
impl From<Metric> for crate::openapi::models::create_index_for_model_request_embed::Metric {
57+
fn from(model: Metric) -> Self {
58+
match model {
59+
Metric::Cosine => {
60+
crate::openapi::models::create_index_for_model_request_embed::Metric::Cosine
61+
}
62+
Metric::Euclidean => {
63+
crate::openapi::models::create_index_for_model_request_embed::Metric::Euclidean
64+
}
65+
Metric::Dotproduct => {
66+
crate::openapi::models::create_index_for_model_request_embed::Metric::Dotproduct
67+
}
68+
}
69+
}
70+
}
71+
72+
impl From<crate::openapi::models::create_index_for_model_request_embed::Metric> for Metric {
73+
fn from(model: crate::openapi::models::create_index_for_model_request_embed::Metric) -> Self {
74+
match model {
75+
crate::openapi::models::create_index_for_model_request_embed::Metric::Cosine => {
76+
Metric::Cosine
77+
}
78+
crate::openapi::models::create_index_for_model_request_embed::Metric::Euclidean => {
79+
Metric::Euclidean
80+
}
81+
crate::openapi::models::create_index_for_model_request_embed::Metric::Dotproduct => {
82+
Metric::Dotproduct
83+
}
84+
}
85+
}
86+
}

src/models/mod.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ mod namespace;
1111
pub use self::namespace::Namespace;
1212

1313
mod index_model;
14-
pub use self::index_model::IndexModel;
14+
pub use self::index_model::{
15+
CreateIndexForModelOptions, FieldMapEntry, IndexModel, ModelParameterValue,
16+
};
1517

1618
mod index_list;
1719
pub use self::index_list::IndexList;
@@ -25,11 +27,14 @@ pub use self::embedding::Embedding;
2527
mod vector_type;
2628
pub use self::vector_type::VectorType;
2729

30+
mod cloud;
31+
pub use self::cloud::Cloud;
32+
2833
pub use crate::openapi::models::{
29-
index_model_status::State, serverless_spec::Cloud, CollectionList, CollectionModel,
30-
ConfigureIndexRequest, ConfigureIndexRequestSpec, ConfigureIndexRequestSpecPod,
31-
CreateCollectionRequest, DeletionProtection, EmbedRequestParameters, IndexModelSpec,
32-
IndexModelStatus, IndexSpec, PodSpec, PodSpecMetadataConfig, ServerlessSpec,
34+
index_model_status::State, CollectionList, CollectionModel, ConfigureIndexRequest,
35+
ConfigureIndexRequestSpec, ConfigureIndexRequestSpecPod, CreateCollectionRequest,
36+
DeletionProtection, EmbedRequestParameters, IndexModelSpec, IndexModelStatus, IndexSpec,
37+
PodSpec, PodSpecMetadataConfig, ServerlessSpec,
3338
};
3439

3540
pub use crate::protos::{

src/pinecone/control.rs

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ use std::collections::HashMap;
33
use std::time::Duration;
44

55
use crate::openapi::apis::manage_indexes_api;
6-
use crate::openapi::models::{ByocSpec, CreateIndexRequest};
6+
use crate::openapi::models::{ByocSpec, CreateIndexForModelRequestEmbed, CreateIndexRequest};
77
use crate::pinecone::PineconeClient;
88
use crate::utils::errors::PineconeError;
99

1010
use crate::models::{
1111
Cloud, CollectionList, CollectionModel, ConfigureIndexRequest, ConfigureIndexRequestSpec,
12-
ConfigureIndexRequestSpecPod, CreateCollectionRequest, DeletionProtection, IndexList,
13-
IndexModel, IndexSpec, Metric, PodSpec, PodSpecMetadataConfig, ServerlessSpec, VectorType,
14-
WaitPolicy,
12+
ConfigureIndexRequestSpecPod, CreateCollectionRequest, CreateIndexForModelOptions,
13+
DeletionProtection, IndexList, IndexModel, IndexSpec, Metric, PodSpec, PodSpecMetadataConfig,
14+
ServerlessSpec, VectorType, WaitPolicy,
1515
};
1616
use crate::openapi::models;
1717

@@ -71,7 +71,7 @@ impl PineconeClient {
7171
// create request specs
7272
let create_index_request_spec = IndexSpec {
7373
serverless: Some(Box::new(ServerlessSpec {
74-
cloud,
74+
cloud: cloud.into(),
7575
region: region.to_string(),
7676
})),
7777
pod: None,
@@ -250,6 +250,7 @@ impl PineconeClient {
250250
/// # Ok(())
251251
/// # }
252252
/// ```
253+
#[allow(clippy::too_many_arguments)]
253254
pub async fn create_byoc_index(
254255
&self,
255256
name: &str,
@@ -294,6 +295,40 @@ impl PineconeClient {
294295
}
295296
}
296297

298+
/// Creates an integrated index for a model.
299+
#[allow(clippy::too_many_arguments)]
300+
pub async fn create_index_for_model(
301+
&self,
302+
name: &str,
303+
cloud: Cloud,
304+
region: &str,
305+
embed: CreateIndexForModelOptions,
306+
deletion_protection: Option<DeletionProtection>,
307+
tags: Option<HashMap<String, String>>,
308+
timeout: WaitPolicy,
309+
) -> Result<IndexModel, PineconeError> {
310+
let embed: CreateIndexForModelRequestEmbed = embed.into();
311+
312+
let request = models::CreateIndexForModelRequest {
313+
name: name.to_string(),
314+
cloud: cloud.into(),
315+
region: region.to_string(),
316+
embed: Box::new(embed),
317+
deletion_protection,
318+
tags,
319+
};
320+
321+
let res = manage_indexes_api::create_index_for_model(&self.openapi_config, request)
322+
.await
323+
.map_err(PineconeError::from)?;
324+
325+
// poll index status
326+
match self.handle_poll_index(name, timeout).await {
327+
Ok(_) => Ok(res.into()),
328+
Err(e) => Err(e),
329+
}
330+
}
331+
297332
// Checks if the index is ready by polling the index status
298333
async fn handle_poll_index(
299334
&self,
@@ -908,14 +943,14 @@ mod tests {
908943
then.status(422)
909944
.header("content-type", "application/json")
910945
.body(
911-
r#"{
946+
r#"{
912947
"error": {
913948
"code": "INVALID_ARGUMENT",
914949
"message": "Failed to deserialize the JSON body into the target type: missing field `metric` at line 1 column 16"
915950
},
916951
"status": 422
917952
}"#,
918-
);
953+
);
919954
});
920955

921956
let config = PineconeClientConfig {
@@ -1051,6 +1086,9 @@ mod tests {
10511086
pod: None,
10521087
byoc: None,
10531088
},
1089+
vector_type: VectorType::Dense,
1090+
tags: None,
1091+
embed: None,
10541092
};
10551093

10561094
assert_eq!(index, expected);
@@ -1191,6 +1229,9 @@ mod tests {
11911229
deletion_protection: None,
11921230
spec: models::IndexModelSpec::default(),
11931231
status: models::IndexModelStatus::default(),
1232+
embed: None,
1233+
tags: None,
1234+
vector_type: VectorType::Dense,
11941235
},
11951236
IndexModel {
11961237
name: "index2".to_string(),
@@ -1200,6 +1241,9 @@ mod tests {
12001241
deletion_protection: None,
12011242
spec: models::IndexModelSpec::default(),
12021243
status: models::IndexModelStatus::default(),
1244+
embed: None,
1245+
tags: None,
1246+
vector_type: VectorType::Dense,
12031247
},
12041248
]),
12051249
};

tests/integration_test_control.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ async fn test_create_list_indexes() -> Result<(), PineconeError> {
8888
assert_eq!(index1.dimension, Some(2));
8989
assert_eq!(index1.metric, Metric::Cosine);
9090
let spec1 = index1.spec.serverless.as_ref().unwrap();
91-
assert_eq!(spec1.cloud, Cloud::Aws);
91+
let spec1_cloud: Cloud = spec1.cloud.into();
92+
assert_eq!(spec1_cloud, Cloud::Aws);
9293
assert_eq!(spec1.region, "us-west-2");
9394

9495
let index2 = indexes
@@ -100,7 +101,8 @@ async fn test_create_list_indexes() -> Result<(), PineconeError> {
100101
assert_eq!(index2.dimension, Some(2));
101102
assert_eq!(index2.metric, Metric::Dotproduct);
102103
let spec2 = index2.spec.serverless.as_ref().unwrap();
103-
assert_eq!(spec2.cloud, Cloud::Aws);
104+
let spec2_cloud: Cloud = spec2.cloud.into();
105+
assert_eq!(spec2_cloud, Cloud::Aws);
104106
assert_eq!(spec2.region, "us-west-2");
105107

106108
pinecone
@@ -142,7 +144,8 @@ async fn test_create_delete_index() -> Result<(), PineconeError> {
142144
assert_eq!(response.metric, Metric::Euclidean);
143145

144146
let spec = response.spec.serverless.unwrap();
145-
assert_eq!(spec.cloud, Cloud::Aws);
147+
let spec_cloud: Cloud = spec.cloud.into();
148+
assert_eq!(spec_cloud, Cloud::Aws);
146149
assert_eq!(spec.region, "us-west-2");
147150

148151
pinecone

0 commit comments

Comments
 (0)