Skip to content

Commit

Permalink
Ericbrehault/sc 11725/support catalog in sdk (#156)
Browse files Browse the repository at this point in the history
* rename add_labelset to set_labelset

* support catalog endpoint

* lint

* fix
  • Loading branch information
ebrehault authored Feb 19, 2025
1 parent 7ba49e7 commit f6028f5
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 9 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Changelog

## 4.5.3 (unreleased)
## 4.6.0 (unreleased)


- Nothing changed yet.
- Rename `add_labelset` to `set_labelset`
- Support `/catalog` endpoint


## 4.5.2 (2025-02-17)
Expand Down
6 changes: 3 additions & 3 deletions docs/09-manage.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ You can list all the labels in a Knowledge Box:
labelsets = kb.list_labelsets()
```

You can create a labelset in a Knowledge Box:
You can create or modify a labelset in a Knowledge Box:

- CLI:

```sh
nuclia kb add_labelset --labelset="heroes" --labels="['Batman','Catwoman']"
nuclia kb set_labelset --labelset="heroes" --labels="['Batman','Catwoman']"
```

- SDK:

```python
from nuclia import sdk
kb = sdk.NucliaKB()
kb.add_labelset(labelset="heroes", labels=["Batman", "Catwoman"])
kb.set_labelset(labelset="heroes", labels=["Batman", "Catwoman"])
```

You can get a labelset in a Knowledge Box:
Expand Down
47 changes: 47 additions & 0 deletions nuclia/sdk/kb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from datetime import datetime
from deprecated import deprecated
import os
import tempfile
import time
Expand Down Expand Up @@ -73,6 +74,7 @@ def get_labelset(
return ndb.ndb.get_labelset(kbid=ndb.kbid, labelset=labelset)

@kb
@deprecated(version="5.0.0", reason="You should use set_labelset")
def add_labelset(
self,
*,
Expand All @@ -83,6 +85,28 @@ def add_labelset(
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
self.set_labelset(
labelset=labelset,
kind=kind,
multiple=multiple,
title=title,
color=color,
labels=labels,
**kwargs,
)

@kb
def set_labelset(
self,
*,
labelset: str,
kind: LabelSetKind = LabelSetKind.RESOURCES,
multiple: bool = True,
title: Optional[str] = None,
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
ndb: NucliaDBClient = kwargs["ndb"]
if labels is None:
Expand Down Expand Up @@ -355,6 +379,7 @@ async def get_labelset(
return await ndb.ndb.get_labelset(kbid=ndb.kbid, labelset=labelset)

@kb
@deprecated(version="5.0.0", reason="You should use set_labelset")
async def add_labelset(
self,
*,
Expand All @@ -365,6 +390,28 @@ async def add_labelset(
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
await self.set_labelset(
labelset=labelset,
kind=kind,
multiple=multiple,
title=title,
color=color,
labels=labels,
**kwargs,
)

@kb
async def set_labelset(
self,
*,
labelset: str,
kind: LabelSetKind = LabelSetKind.RESOURCES,
multiple: bool = True,
title: Optional[str] = None,
color: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs,
):
ndb: AsyncNucliaDBClient = kwargs["ndb"]
if labels is None:
Expand Down
69 changes: 69 additions & 0 deletions nuclia/sdk/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nucliadb_models.search import (
AskRequest,
AskResponseItem,
CatalogRequest,
Filter,
FindRequest,
KnowledgeboxFindResults,
Expand Down Expand Up @@ -139,6 +140,40 @@ def find(

return ndb.ndb.find(req, kbid=ndb.kbid)

@kb
@pretty
def catalog(
    self,
    *,
    query: Union[str, dict, CatalogRequest] = "",
    filters: Optional[Union[List[str], List[Filter]]] = None,
    **kwargs,
):
    """
    Perform a catalog query.
    See https://docs.nuclia.dev/docs/api#tag/Search/operation/catalog_post_kb__kbid__catalog_post

    :param query: the query text, an already built ``CatalogRequest``, or a
        dict that is validated into a ``CatalogRequest``.
    :param filters: filters placed in the request. Only used when ``query``
        is a string; ignored when a request object or dict is supplied.
    :param kwargs: must contain ``ndb`` (the client used to send the
        request). When ``query`` is a string, the remaining keyword
        arguments are forwarded to ``CatalogRequest``.
    :return: the value returned by ``ndb.ndb.catalog``.
    :raises ValidationError: if a dict ``query`` fails validation (the error
        is logged, then re-raised).
    :raises TypeError: if ``query`` is not a str, dict or ``CatalogRequest``.
    """
    ndb: NucliaDBClient = kwargs["ndb"]
    if isinstance(query, str):
        # "ndb" is the client handle, not request data, so keep it out of
        # the request model instead of relying on the model to discard it.
        request_fields = {
            name: value for name, value in kwargs.items() if name != "ndb"
        }
        req = CatalogRequest(
            query=query,
            filters=filters or [],  # type: ignore
            **request_fields,
        )
    elif isinstance(query, CatalogRequest):
        req = query
    elif isinstance(query, dict):
        try:
            req = CatalogRequest.model_validate(query)
        except ValidationError:
            logger.exception("Error validating query")
            raise
    else:
        # TypeError is still an Exception, so existing handlers keep working.
        raise TypeError(
            "Invalid query: expected str, dict or CatalogRequest, "
            f"got {type(query).__name__}"
        )

    return ndb.ndb.catalog(req, kbid=ndb.kbid)

@kb
def ask(
self,
Expand Down Expand Up @@ -379,6 +414,40 @@ async def find(

return await ndb.ndb.find(req, kbid=ndb.kbid)

@kb
@pretty
async def catalog(
    self,
    *,
    query: Union[str, dict, CatalogRequest] = "",
    filters: Optional[Union[List[str], List[Filter]]] = None,
    **kwargs,
):
    """
    Perform a catalog query (asynchronous variant).
    See https://docs.nuclia.dev/docs/api#tag/Search/operation/catalog_post_kb__kbid__catalog_post

    :param query: the query text, an already built ``CatalogRequest``, or a
        dict that is validated into a ``CatalogRequest``.
    :param filters: filters placed in the request. Only used when ``query``
        is a string; ignored when a request object or dict is supplied.
    :param kwargs: must contain ``ndb`` (the async client used to send the
        request). When ``query`` is a string, the remaining keyword
        arguments are forwarded to ``CatalogRequest``.
    :return: the awaited result of ``ndb.ndb.catalog``.
    :raises ValidationError: if a dict ``query`` fails validation (the error
        is logged, then re-raised).
    :raises TypeError: if ``query`` is not a str, dict or ``CatalogRequest``.
    """
    ndb: AsyncNucliaDBClient = kwargs["ndb"]
    if isinstance(query, str):
        # "ndb" is the client handle, not request data, so keep it out of
        # the request model instead of relying on the model to discard it.
        request_fields = {
            name: value for name, value in kwargs.items() if name != "ndb"
        }
        req = CatalogRequest(
            query=query,
            filters=filters or [],  # type: ignore
            **request_fields,
        )
    elif isinstance(query, CatalogRequest):
        req = query
    elif isinstance(query, dict):
        try:
            req = CatalogRequest.model_validate(query)
        except ValidationError:
            logger.exception("Error validating query")
            raise
    else:
        # TypeError is still an Exception, so existing handlers keep working.
        raise TypeError(
            "Invalid query: expected str, dict or CatalogRequest, "
            f"got {type(query).__name__}"
        )

    return await ndb.ndb.catalog(req, kbid=ndb.kbid)

@kb
async def ask(
self,
Expand Down
2 changes: 1 addition & 1 deletion nuclia/tests/test_kb/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def test_labels(testing_config):
nkb = NucliaKB()
nkb.add_label(labelset="labelset1", label="label1")
nkb.add_labelset(labelset="labelset1")
nkb.set_labelset(labelset="labelset1")
nkb.add_label(labelset="labelset1", label="label1")
nkb.add_label(labelset="labelset1", label="label2")
nkb.del_labelset(labelset="labelset2")
Expand Down
6 changes: 6 additions & 0 deletions nuclia/tests/test_kb/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def test_search(testing_config):
assert "Lamarr Lesson plan.pdf" in titles


def test_catalog(testing_config):
    """Run a catalog query with default arguments and expect exactly two resources."""
    catalog_results = NucliaSearch().catalog()
    resource_ids = catalog_results.resources.keys()
    assert len(resource_ids) == 2


def test_search_object(testing_config):
search = NucliaSearch()
results = search.search(query={"query": "Who is hedy Lamarr?"})
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ requests
httpx
httpcore>=1.0.0
prompt_toolkit
nucliadb_sdk>=6.2.1.post2864,<7
nucliadb_models>=6.2.1.post2864,<7
nucliadb_protos>=6.2.1.post2864,<7
nucliadb_sdk>=6.2.1.post3247,<7
nucliadb_models>=6.2.1.post3247,<7
nucliadb_protos>=6.2.1.post3247,<7
nuclia-models>=0.25.0
tqdm
aiofiles
Expand Down

0 comments on commit f6028f5

Please sign in to comment.