Skip to content

Commit

Permalink
fix: check slice length when adding annotations, refine and correct q…
Browse files Browse the repository at this point in the history
…ueries
  • Loading branch information
ashearin committed Nov 24, 2024
1 parent 0ea962d commit 28d4b73
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 38 deletions.
65 changes: 28 additions & 37 deletions backends/ent/annotations.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ import (
// AddAnnotationToDocuments applies a single named annotation value to multiple documents.
func (backend *Backend) AddAnnotationToDocuments(name, value string, documentIDs ...string) error {
data := ent.Annotations{}
predicates := []predicate.Metadata{}

if len(documentIDs) > 0 {
predicates = append(predicates, metadata.NativeIDIn(documentIDs...))
}

docUUIDs, err := backend.client.Metadata.Query().
Where(metadata.NativeIDIn(documentIDs...)).
Where(predicates...).
QueryDocument().
IDs(backend.ctx)
if err != nil {
Expand All @@ -45,12 +50,17 @@ func (backend *Backend) AddAnnotationToDocuments(name, value string, documentIDs
// AddAnnotationToNodes applies a single named annotation value to multiple nodes.
func (backend *Backend) AddAnnotationToNodes(name, value string, nodeIDs ...string) error {
data := ent.Annotations{}
predicates := []predicate.Node{}

if len(nodeIDs) > 0 {
predicates = append(predicates, node.NativeIDIn(nodeIDs...))
}

nodes, err := backend.client.Node.Query().
Where(node.NativeIDIn(nodeIDs...)).
Where(predicates...).
All(backend.ctx)
if err != nil {
return fmt.Errorf("querying Node IDs: %w", err)
return fmt.Errorf("querying nodes: %w", err)
}

for _, n := range nodes {
Expand Down Expand Up @@ -117,17 +127,17 @@ func (backend *Backend) ClearDocumentAnnotations(documentIDs ...string) error {
return nil
}

docUUIDs, err := backend.client.Document.Query().
QueryMetadata().
docUUIDs, err := backend.client.Metadata.Query().
Where(metadata.NativeIDIn(documentIDs...)).
QueryDocument().
IDs(backend.ctx)
if err != nil {
return fmt.Errorf("querying document IDs: %w", err)
}

return backend.withTx(func(tx *ent.Tx) error {
if _, err := tx.Annotation.Delete().
Where(annotation.HasDocumentWith(document.MetadataIDIn(docUUIDs...))).
Where(annotation.HasDocumentWith(document.IDIn(docUUIDs...))).
Exec(backend.ctx); err != nil {
return fmt.Errorf("clearing annotations: %w", err)
}
Expand Down Expand Up @@ -196,32 +206,22 @@ func (backend *Backend) GetDocumentsByAnnotation(name string, values ...string)
predicates = append(predicates, annotation.ValueIn(values...))
}

uniqueIDs, err := backend.client.Annotation.Query().
ids := []string{}

err := backend.client.Annotation.Query().
Where(predicates...).
QueryDocument().
QueryMetadata().
IDs(backend.ctx)
Select(metadata.FieldNativeID).
Scan(backend.ctx, &ids)
if err != nil {
return nil, fmt.Errorf("querying documents table: %w", err)
}

if len(uniqueIDs) == 0 {
if len(ids) == 0 {
return []*sbom.Document{}, nil
}

mds, err := backend.client.Document.Query().
QueryMetadata().
Where(metadata.IDIn(uniqueIDs...)).
All(backend.ctx)
if err != nil {
return nil, fmt.Errorf("querying document IDs: %w", err)
}

ids := []string{}
for _, md := range mds {
ids = append(ids, md.NativeID)
}

return backend.GetDocumentsByID(ids...)
}

Expand Down Expand Up @@ -287,30 +287,21 @@ func (backend *Backend) GetNodesByAnnotation(name string, values ...string) ([]*
predicates = append(predicates, annotation.ValueIn(values...))
}

uniqueIDs, err := backend.client.Annotation.Query().
ids := []string{}

err := backend.client.Annotation.Query().
Where(predicates...).
QueryNode().
IDs(backend.ctx)
Select(node.FieldNativeID).
Scan(backend.ctx, &ids)
if err != nil {
return nil, fmt.Errorf("querying nodes table: %w", err)
}

if len(uniqueIDs) == 0 {
if len(ids) == 0 {
return []*sbom.Node{}, nil
}

nodes, err := backend.client.Node.Query().
Where(node.IDIn(uniqueIDs...)).
All(backend.ctx)
if err != nil {
return nil, fmt.Errorf("querying node IDs: %w", err)
}

ids := []string{}
for _, md := range nodes {
ids = append(ids, md.NativeID)
}

return backend.GetNodesByID(ids...)
}

Expand Down
12 changes: 11 additions & 1 deletion backends/ent/annotations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ func (as *annotationsSuite) TestBackend_ClearDocumentAnnotations() {
documentIDs := []string{}

as.Require().NoError(as.Backend.AddAnnotationToDocuments(annotationName, "test-value", documentIDs...))

for _, document := range as.documents {
documentIDs = append(documentIDs, document.GetMetadata().GetId())
}

as.Require().NoError(as.Backend.ClearDocumentAnnotations(documentIDs...))

annotations := as.getTestResult(annotationName)
Expand All @@ -166,7 +171,12 @@ func (as *annotationsSuite) TestBackend_ClearNodeAnnotations() {
annotationName := "clear_node_annotations_test"
nodeIDs := []string{}

as.Require().NoError(as.Backend.AddAnnotationToDocuments(annotationName, "test-node-value", nodeIDs...))
as.Require().NoError(as.Backend.AddAnnotationToNodes(annotationName, "test-node-value", nodeIDs...))

for _, node := range as.nodes {
nodeIDs = append(nodeIDs, node.GetId())
}

as.Require().NoError(as.Backend.ClearNodeAnnotations(nodeIDs...))

annotations := as.getTestResult(annotationName)
Expand Down

0 comments on commit 28d4b73

Please sign in to comment.