Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ericbrehault/sc 11725/support catalog in sdk #156

Merged
merged 4 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Changelog

## 4.5.3 (unreleased)
## 4.6.0 (unreleased)


- Nothing changed yet.
- Rename `add_labelset` to `set_labelset`
- Support `/catalog` endpoint


## 4.5.2 (2025-02-17)
Expand Down
6 changes: 3 additions & 3 deletions docs/09-manage.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ You can list all the labels in a Knowledge Box:
labelsets = kb.list_labelsets()
```

You can create a labelset in a Knowledge Box:
You can create or modify a labelset in a Knowledge Box:

- CLI:

```sh
nuclia kb add_labelset --labelset="heroes" --labels="['Batman','Catwoman']"
nuclia kb set_labelset --labelset="heroes" --labels="['Batman','Catwoman']"
```

- SDK:

```python
from nuclia import sdk
kb = sdk.NucliaKB()
kb.add_labelset(labelset="heroes", labels=["Batman", "Catwoman"])
kb.set_labelset(labelset="heroes", labels=["Batman", "Catwoman"])
```

You can get a labelset in a Knowledge Box:
Expand Down
47 changes: 47 additions & 0 deletions nuclia/sdk/kb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from datetime import datetime
from deprecated import deprecated
import os
import tempfile
import time
Expand Down Expand Up @@ -73,6 +74,7 @@ def get_labelset(
return ndb.ndb.get_labelset(kbid=ndb.kbid, labelset=labelset)

@kb
@deprecated(version="5.0.0", reason="You should use set_labelset")
def add_labelset(
self,
*,
Expand All @@ -83,6 +85,28 @@ def add_labelset(
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
self.set_labelset(
labelset=labelset,
kind=kind,
multiple=multiple,
title=title,
color=color,
labels=labels,
**kwargs,
)

@kb
def set_labelset(
self,
*,
labelset: str,
kind: LabelSetKind = LabelSetKind.RESOURCES,
multiple: bool = True,
title: Optional[str] = None,
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
ndb: NucliaDBClient = kwargs["ndb"]
if labels is None:
Expand Down Expand Up @@ -355,6 +379,7 @@ async def get_labelset(
return await ndb.ndb.get_labelset(kbid=ndb.kbid, labelset=labelset)

@kb
@deprecated(version="5.0.0", reason="You should use set_labelset")
async def add_labelset(
self,
*,
Expand All @@ -365,6 +390,28 @@ async def add_labelset(
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
self.set_labelset(
labelset=labelset,
kind=kind,
multiple=multiple,
title=title,
color=color,
labels=labels,
**kwargs,
)

@kb
async def set_labelset(
self,
*,
labelset: str,
kind: LabelSetKind = LabelSetKind.RESOURCES,
multiple: bool = True,
title: Optional[str] = None,
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
ndb: AsyncNucliaDBClient = kwargs["ndb"]
if labels is None:
Expand Down
69 changes: 69 additions & 0 deletions nuclia/sdk/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nucliadb_models.search import (
AskRequest,
AskResponseItem,
CatalogRequest,
Filter,
FindRequest,
KnowledgeboxFindResults,
Expand Down Expand Up @@ -139,6 +140,40 @@ def find(

return ndb.ndb.find(req, kbid=ndb.kbid)

@kb
@pretty
def catalog(
self,
*,
query: Union[str, FindRequest] = "",
filters: Optional[Union[List[str], List[Filter]]] = None,
**kwargs,
):
"""
Perform a catalog query.

See https://docs.nuclia.dev/docs/api#tag/Search/operation/catalog_post_kb__kbid__catalog_post
"""
ndb: NucliaDBClient = kwargs["ndb"]
if isinstance(query, str):
req = CatalogRequest(
query=query,
filters=filters or [], # type: ignore
**kwargs,
)
elif isinstance(query, CatalogRequest):
req = query
elif isinstance(query, dict):
try:
req = CatalogRequest.model_validate(query)
except ValidationError:
logger.exception("Error validating query")
raise
else:
raise Exception("Invalid Query either str or FindRequest")

return ndb.ndb.catalog(req, kbid=ndb.kbid)

@kb
def ask(
self,
Expand Down Expand Up @@ -379,6 +414,40 @@ async def find(

return await ndb.ndb.find(req, kbid=ndb.kbid)

@kb
@pretty
async def catalog(
self,
*,
query: Union[str, FindRequest] = "",
filters: Optional[Union[List[str], List[Filter]]] = None,
**kwargs,
):
"""
Perform a catalog query.

See https://docs.nuclia.dev/docs/api#tag/Search/operation/catalog_post_kb__kbid__catalog_post
"""
ndb: AsyncNucliaDBClient = kwargs["ndb"]
if isinstance(query, str):
req = CatalogRequest(
query=query,
filters=filters or [], # type: ignore
**kwargs,
)
elif isinstance(query, CatalogRequest):
req = query
elif isinstance(query, dict):
try:
req = CatalogRequest.model_validate(query)
except ValidationError:
logger.exception("Error validating query")
raise
else:
raise Exception("Invalid Query either str or FindRequest")

return await ndb.ndb.catalog(req, kbid=ndb.kbid)

@kb
async def ask(
self,
Expand Down
2 changes: 1 addition & 1 deletion nuclia/tests/test_kb/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def test_labels(testing_config):
nkb = NucliaKB()
nkb.add_label(labelset="labelset1", label="label1")
nkb.add_labelset(labelset="labelset1")
nkb.set_labelset(labelset="labelset1")
nkb.add_label(labelset="labelset1", label="label1")
nkb.add_label(labelset="labelset1", label="label2")
nkb.del_labelset(labelset="labelset2")
Expand Down
6 changes: 6 additions & 0 deletions nuclia/tests/test_kb/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def test_search(testing_config):
assert "Lamarr Lesson plan.pdf" in titles


def test_catalog(testing_config):
search = NucliaSearch()
results = search.catalog()
assert len(results.resources.keys()) == 2


def test_search_object(testing_config):
search = NucliaSearch()
results = search.search(query={"query": "Who is hedy Lamarr?"})
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ requests
httpx
httpcore>=1.0.0
prompt_toolkit
nucliadb_sdk>=6.2.1.post2864,<7
nucliadb_models>=6.2.1.post2864,<7
nucliadb_protos>=6.2.1.post2864,<7
nucliadb_sdk>=6.2.1.post3247,<7
nucliadb_models>=6.2.1.post3247,<7
nucliadb_protos>=6.2.1.post3247,<7
nuclia-models>=0.25.0
tqdm
aiofiles
Expand Down