Skip to content

Commit 3b7acd9

Browse files
committed
feat: introduce db context in graph and put status cache there
1 parent ad213ee commit 3b7acd9

File tree

9 files changed

+80
-138
lines changed

9 files changed

+80
-138
lines changed

modules/importer/src/runner/csaf/mod.rs

+1-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::{
1212
},
1313
server::context::WalkerProgress,
1414
};
15-
use anyhow::anyhow;
1615
use csaf_walker::{
1716
metadata::MetadataRetriever,
1817
retrieve::RetrievingVisitor,
@@ -25,10 +24,7 @@ use reqwest::StatusCode;
2524
use std::collections::HashSet;
2625
use std::{sync::Arc, time::SystemTime};
2726
use tracing::instrument;
28-
use trustify_module_ingestor::{
29-
graph::Graph,
30-
service::{advisory::StatusCache, IngestorService},
31-
};
27+
use trustify_module_ingestor::{graph::Graph, service::IngestorService};
3228
use url::Url;
3329
use walker_common::fetcher::{Fetcher, FetcherOptions};
3430

@@ -67,18 +63,10 @@ impl super::ImportRunner {
6763
};
6864

6965
// storage (called by validator)
70-
71-
let mut status_cache = StatusCache::new();
72-
status_cache
73-
.load_statuses(&self.db)
74-
.await
75-
.map_err(|_| anyhow!("Failed to load statuses"))?;
76-
7766
let ingestor = IngestorService {
7867
graph: Graph::new(self.db.clone()),
7968
storage: self.storage.clone(),
8069
analysis: self.analysis.clone(),
81-
status_cache,
8270
};
8371

8472
let storage = storage::StorageVisitor {
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use std::collections::HashMap;
2+
3+
use crate::graph::error::Error;
4+
use sea_orm::ConnectionTrait;
5+
use sea_orm::EntityTrait;
6+
use trustify_entity::status;
7+
use uuid::Uuid;
8+
9+
#[derive(Debug, Clone)]
10+
pub struct DbContext {
11+
pub status_cache: HashMap<String, Uuid>,
12+
}
13+
14+
impl DbContext {
15+
pub fn new() -> Self {
16+
Self {
17+
status_cache: HashMap::new(),
18+
}
19+
}
20+
21+
pub async fn load_statuses(&mut self, connection: &impl ConnectionTrait) -> Result<(), Error> {
22+
self.status_cache.clear();
23+
let statuses = status::Entity::find().all(connection).await?;
24+
statuses
25+
.iter()
26+
.map(|s| self.status_cache.insert(s.slug.clone(), s.id))
27+
.for_each(drop);
28+
29+
Ok(())
30+
}
31+
32+
pub async fn get_status_id(
33+
&mut self,
34+
status: &str,
35+
connection: &impl ConnectionTrait,
36+
) -> Result<Uuid, Error> {
37+
if let Some(s) = self.status_cache.get(status) {
38+
return Ok(*s);
39+
}
40+
41+
// If not found, reload the cache and check again
42+
self.load_statuses(connection).await?;
43+
44+
self.status_cache
45+
.get(status)
46+
.cloned()
47+
.ok_or_else(|| crate::graph::error::Error::InvalidStatus(status.to_string()))
48+
}
49+
}
50+
51+
impl Default for DbContext {
52+
fn default() -> Self {
53+
Self::new()
54+
}
55+
}

modules/ingestor/src/graph/mod.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
pub mod advisory;
22
pub mod cpe;
3+
pub mod db_context;
34
pub mod error;
45
pub mod organization;
56
pub mod product;
67
pub mod purl;
78
pub mod sbom;
89
pub mod vulnerability;
910

11+
use db_context::DbContext;
1012
use sea_orm::DbErr;
1113
use std::fmt::Debug;
1214
use std::ops::{Deref, DerefMut};
15+
use std::sync::Arc;
16+
use tokio::sync::Mutex;
1317

1418
#[derive(Debug, Clone)]
1519
pub struct Graph {
1620
pub(crate) db: trustify_common::db::Database,
21+
pub(crate) db_context: Arc<Mutex<DbContext>>,
1722
}
1823

1924
#[derive(Debug, thiserror::Error)]
@@ -26,7 +31,10 @@ pub enum Error<E: Send> {
2631

2732
impl Graph {
2833
pub fn new(db: trustify_common::db::Database) -> Self {
29-
Self { db }
34+
Self {
35+
db,
36+
db_context: Arc::new(Mutex::new(DbContext::new())),
37+
}
3038
}
3139
}
3240

modules/ingestor/src/service/advisory/csaf/creator.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@ use crate::{
88
Graph,
99
},
1010
service::{
11-
advisory::{
12-
csaf::{product_status::ProductStatus, util::ResolveProductIdCache},
13-
StatusCache,
14-
},
11+
advisory::csaf::{product_status::ProductStatus, util::ResolveProductIdCache},
1512
Error,
1613
},
1714
};
@@ -96,7 +93,6 @@ impl<'a> StatusCreator<'a> {
9693
&mut self,
9794
graph: &Graph,
9895
connection: &C,
99-
mut status_cache: StatusCache,
10096
) -> Result<(), Error> {
10197
let mut product_status_models = Vec::new();
10298
let mut purls = PurlCreator::new();
@@ -108,11 +104,10 @@ impl<'a> StatusCreator<'a> {
108104
let mut product_version_ranges = Vec::new();
109105

110106
let product_statuses = self.products.clone();
107+
let mut db_context = graph.db_context.lock().await;
111108

112109
for product in product_statuses {
113-
let status_id = status_cache
114-
.get_status_id(product.status, connection)
115-
.await?;
110+
let status_id = db_context.get_status_id(product.status, connection).await?;
116111

117112
// There should be only a few organizations per document,
118113
// so simple caching should work here.
@@ -248,7 +243,7 @@ impl<'a> StatusCreator<'a> {
248243
purl,
249244
scheme,
250245
spec,
251-
status_cache
246+
db_context
252247
.get_status_id(&Status::Affected.to_string(), connection)
253248
.await?,
254249
);

modules/ingestor/src/service/advisory/csaf/loader.rs

+8-29
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@ use crate::{
88
},
99
model::IngestResult,
1010
service::{
11-
advisory::{
12-
csaf::{util::gen_identifier, StatusCreator},
13-
StatusCache,
14-
},
11+
advisory::csaf::{util::gen_identifier, StatusCreator},
1512
Error, Warnings,
1613
},
1714
};
@@ -91,7 +88,6 @@ impl<'g> CsafLoader<'g> {
9188
labels: impl Into<Labels> + Debug,
9289
csaf: Csaf,
9390
digests: &Digests,
94-
status_cache: StatusCache,
9591
) -> Result<IngestResult, Error> {
9692
let warnings = Warnings::new();
9793

@@ -116,7 +112,7 @@ impl<'g> CsafLoader<'g> {
116112
.await?;
117113

118114
for vuln in csaf.vulnerabilities.iter().flatten() {
119-
self.ingest_vulnerability(&csaf, &advisory, vuln, &warnings, &tx, status_cache.clone())
115+
self.ingest_vulnerability(&csaf, &advisory, vuln, &warnings, &tx)
120116
.await?;
121117
}
122118

@@ -142,7 +138,6 @@ impl<'g> CsafLoader<'g> {
142138
vulnerability: &Vulnerability,
143139
report: &dyn ReportSink,
144140
connection: &C,
145-
status_cache: StatusCache,
146141
) -> Result<(), Error> {
147142
let Some(cve_id) = &vulnerability.cve else {
148143
return Ok(());
@@ -173,14 +168,8 @@ impl<'g> CsafLoader<'g> {
173168
.await?;
174169

175170
if let Some(product_status) = &vulnerability.product_status {
176-
self.ingest_product_statuses(
177-
csaf,
178-
&advisory_vulnerability,
179-
product_status,
180-
connection,
181-
status_cache,
182-
)
183-
.await?;
171+
self.ingest_product_statuses(csaf, &advisory_vulnerability, product_status, connection)
172+
.await?;
184173
}
185174

186175
for score in vulnerability.scores.iter().flatten() {
@@ -211,7 +200,6 @@ impl<'g> CsafLoader<'g> {
211200
advisory_vulnerability: &AdvisoryVulnerabilityContext<'_>,
212201
product_status: &ProductStatus,
213202
connection: &C,
214-
status_cache: StatusCache,
215203
) -> Result<(), Error> {
216204
let mut creator = StatusCreator::new(
217205
csaf,
@@ -226,7 +214,7 @@ impl<'g> CsafLoader<'g> {
226214
creator.add_all(&product_status.known_not_affected, "not_affected");
227215
creator.add_all(&product_status.known_affected, "affected");
228216

229-
creator.create(self.graph, connection, status_cache).await?;
217+
creator.create(self.graph, connection).await?;
230218

231219
Ok(())
232220
}
@@ -250,12 +238,7 @@ mod test {
250238
let (csaf, digests): (Csaf, _) = document("csaf/CVE-2023-20862.json").await?;
251239
let loader = CsafLoader::new(&graph);
252240
loader
253-
.load(
254-
("file", "CVE-2023-20862.json"),
255-
csaf,
256-
&digests,
257-
StatusCache::default(),
258-
)
241+
.load(("file", "CVE-2023-20862.json"), csaf, &digests)
259242
.await?;
260243

261244
let loaded_vulnerability = graph.get_vulnerability("CVE-2023-20862", &ctx.db).await?;
@@ -329,9 +312,7 @@ mod test {
329312
let loader = CsafLoader::new(&graph);
330313

331314
let (csaf, digests): (Csaf, _) = document("csaf/rhsa-2024_3666.json").await?;
332-
loader
333-
.load(("source", "test"), csaf, &digests, StatusCache::default())
334-
.await?;
315+
loader.load(("source", "test"), csaf, &digests).await?;
335316

336317
let loaded_vulnerability = graph.get_vulnerability("CVE-2024-23672", &ctx.db).await?;
337318
assert!(loaded_vulnerability.is_some());
@@ -372,9 +353,7 @@ mod test {
372353
let loader = CsafLoader::new(&graph);
373354

374355
let (csaf, digests): (Csaf, _) = document("csaf/cve-2023-0044.json").await?;
375-
loader
376-
.load(("source", "test"), csaf, &digests, StatusCache::default())
377-
.await?;
356+
loader.load(("source", "test"), csaf, &digests).await?;
378357

379358
let loaded_vulnerability = graph.get_vulnerability("CVE-2023-0044", &ctx.db).await?;
380359
assert!(loaded_vulnerability.is_some());
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,3 @@
1-
use std::collections::HashMap;
2-
3-
use crate::graph::error::Error;
4-
use sea_orm::ConnectionTrait;
5-
use sea_orm::EntityTrait;
6-
use trustify_entity::status;
7-
use uuid::Uuid;
8-
91
pub mod csaf;
102
pub mod cve;
113
pub mod osv;
12-
13-
#[derive(Debug, Clone)]
14-
pub struct StatusCache {
15-
pub cache: HashMap<String, Uuid>,
16-
}
17-
18-
impl StatusCache {
19-
pub fn new() -> Self {
20-
Self {
21-
cache: HashMap::new(),
22-
}
23-
}
24-
25-
pub async fn load_statuses(&mut self, connection: &impl ConnectionTrait) -> Result<(), Error> {
26-
self.cache.clear();
27-
let statuses = status::Entity::find().all(connection).await?;
28-
statuses
29-
.iter()
30-
.map(|s| self.cache.insert(s.slug.clone(), s.id))
31-
.for_each(drop);
32-
33-
Ok(())
34-
}
35-
36-
pub async fn get_status_id(
37-
&mut self,
38-
status: &str,
39-
connection: &impl ConnectionTrait,
40-
) -> Result<Uuid, Error> {
41-
if let Some(s) = self.cache.get(status) {
42-
return Ok(*s);
43-
}
44-
45-
// If not found, reload the cache and check again
46-
self.load_statuses(connection).await?;
47-
48-
self.cache
49-
.get(status)
50-
.cloned()
51-
.ok_or_else(|| crate::graph::error::Error::InvalidStatus(status.to_string()))
52-
}
53-
}
54-
55-
impl Default for StatusCache {
56-
fn default() -> Self {
57-
Self::new()
58-
}
59-
}

modules/ingestor/src/service/dataset/mod.rs

+1-15
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ use trustify_common::hashing::Digests;
2121
use trustify_entity::labels::Labels;
2222
use trustify_module_storage::{service::dispatch::DispatchBackend, service::StorageBackend};
2323

24-
use super::advisory::StatusCache;
25-
2624
pub struct DatasetLoader<'g> {
2725
graph: &'g Graph,
2826
storage: &'g DispatchBackend,
@@ -45,9 +43,6 @@ impl<'g> DatasetLoader<'g> {
4543

4644
let mut zip = zip::ZipArchive::new(Cursor::new(buffer))?;
4745

48-
let mut status_cache = StatusCache::new();
49-
status_cache.load_statuses(&self.graph.db).await?;
50-
5146
for i in 0..zip.len() {
5247
let mut file = zip.by_index(i)?;
5348

@@ -111,20 +106,11 @@ impl<'g> DatasetLoader<'g> {
111106
.await
112107
.map_err(|err| Error::Storage(anyhow!("{err}")))?;
113108

114-
let sc = status_cache.clone();
115-
116109
// We need to box it, to work around async recursion limits
117110
let result = Box::pin({
118111
async move {
119112
format
120-
.load(
121-
self.graph,
122-
labels,
123-
None,
124-
&Digests::digest(&data),
125-
&data,
126-
sc,
127-
)
113+
.load(self.graph, labels, None, &Digests::digest(&data), &data)
128114
.await
129115
}
130116
})

0 commit comments

Comments
 (0)