Skip to content

Commit 99a3e26

Browse files
committed
feat(domain): implement suggested fix version for packages
1 parent 29affcc commit 99a3e26

File tree

1 file changed

+167
-1
lines changed

1 file changed

+167
-1
lines changed

src/domain/scanresult/package.rs

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use crate::domain::scanresult::accepted_risk::AcceptedRisk;
22
use crate::domain::scanresult::layer::Layer;
33
use crate::domain::scanresult::package_type::PackageType;
4+
use crate::domain::scanresult::severity::Severity;
45
use crate::domain::scanresult::vulnerability::Vulnerability;
56
use crate::domain::scanresult::weak_hash::WeakHash;
67
use semver::Version;
7-
use std::collections::HashSet;
8+
use std::collections::{HashMap, HashSet};
89
use std::fmt::Debug;
910
use std::hash::{Hash, Hasher};
1011
use std::sync::{Arc, RwLock};
@@ -109,6 +110,69 @@ impl Package {
109110
.filter_map(|r| r.0.upgrade())
110111
.collect()
111112
}
113+
114+
pub fn suggested_fix_version(&self) -> Option<Version> {
115+
let vulnerabilities = self.vulnerabilities();
116+
if vulnerabilities.is_empty() {
117+
return None;
118+
}
119+
120+
let candidate_versions: Vec<Version> = vulnerabilities
121+
.iter()
122+
.filter_map(|vuln| vuln.fix_version().cloned())
123+
.collect::<HashSet<_>>()
124+
.into_iter()
125+
.collect();
126+
127+
if candidate_versions.is_empty() {
128+
return None;
129+
}
130+
131+
let severity_order = [
132+
Severity::Critical,
133+
Severity::High,
134+
Severity::Medium,
135+
Severity::Low,
136+
Severity::Negligible,
137+
Severity::Unknown,
138+
];
139+
140+
let mut scores: HashMap<Version, HashMap<Severity, usize>> = HashMap::new();
141+
142+
for candidate in &candidate_versions {
143+
let mut score: HashMap<Severity, usize> = HashMap::new();
144+
for severity in &severity_order {
145+
score.insert(*severity, 0);
146+
}
147+
for vuln in &vulnerabilities {
148+
if let Some(fix_version) = vuln.fix_version()
149+
&& fix_version == candidate
150+
{
151+
*score.entry(vuln.severity()).or_insert(0) += 1;
152+
}
153+
}
154+
scores.insert(candidate.clone(), score);
155+
}
156+
157+
let mut sorted_candidates = candidate_versions;
158+
sorted_candidates.sort_by(|a, b| {
159+
let score_a = scores.get(a).unwrap();
160+
let score_b = scores.get(b).unwrap();
161+
162+
for severity in &severity_order {
163+
let count_a = score_a.get(severity).unwrap();
164+
let count_b = score_b.get(severity).unwrap();
165+
if count_a != count_b {
166+
return count_b.cmp(count_a); // Higher count is better
167+
}
168+
}
169+
170+
// If scores are identical, lower version is better
171+
a.cmp(b)
172+
});
173+
174+
sorted_candidates.first().cloned()
175+
}
112176
}
113177

114178
impl PartialEq for Package {
@@ -144,3 +208,105 @@ impl Clone for Package {
144208
}
145209
}
146210
}
211+
212+
#[cfg(test)]
213+
mod tests {
214+
use super::*;
215+
use crate::domain::scanresult::layer::Layer;
216+
use crate::domain::scanresult::package_type::PackageType;
217+
use crate::domain::scanresult::severity::Severity;
218+
use crate::domain::scanresult::vulnerability::Vulnerability;
219+
use chrono::NaiveDate;
220+
use rstest::{fixture, rstest};
221+
use semver::Version;
222+
use std::sync::Arc;
223+
224+
#[fixture]
225+
fn layer() -> Arc<Layer> {
226+
Arc::new(Layer::new(
227+
"a_digest".to_string(),
228+
0,
229+
None,
230+
"a_command".to_string(),
231+
))
232+
}
233+
234+
#[fixture]
235+
fn package(#[default("")] version: &str, layer: Arc<Layer>) -> Arc<Package> {
236+
Arc::new(Package::new(
237+
PackageType::Os,
238+
"a_name".to_string(),
239+
Version::parse(version).unwrap(),
240+
"a_path".to_string(),
241+
layer,
242+
))
243+
}
244+
245+
fn a_vulnerability(
246+
cve: &str,
247+
severity: Severity,
248+
fix_version: Option<&str>,
249+
) -> Arc<Vulnerability> {
250+
Arc::new(Vulnerability::new(
251+
cve.to_string(),
252+
severity,
253+
NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(),
254+
None,
255+
false,
256+
fix_version.map(|v| Version::parse(v).unwrap()),
257+
))
258+
}
259+
260+
#[rstest]
261+
#[case("is_none_when_no_vulnerabilities", "1.0.0", vec![], None)]
262+
#[case("is_none_when_no_fixable_vulnerabilities", "1.0.0", vec![a_vulnerability("CVE-1", Severity::High, None)], None)]
263+
#[case("returns_only_available_fix", "1.0.0", vec![a_vulnerability("CVE-1", Severity::High, Some("1.0.1"))], Some("1.0.1"))]
264+
#[case("chooses_version_with_more_critical_fixes", "1.0.0", vec![
265+
a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")),
266+
a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")),
267+
a_vulnerability("CVE-3", Severity::High, Some("1.0.2")),
268+
], Some("1.0.2"))]
269+
#[case("chooses_version_with_more_high_fixes_when_criticals_tied", "1.0.0", vec![
270+
a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")),
271+
a_vulnerability("CVE-5", Severity::Medium, Some("1.0.1")),
272+
a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")),
273+
a_vulnerability("CVE-3", Severity::High, Some("1.0.2")),
274+
a_vulnerability("CVE-4", Severity::High, Some("1.0.2")),
275+
], Some("1.0.2"))]
276+
#[case("chooses_lower_version_when_counts_are_tied", "1.0.0", vec![
277+
a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")),
278+
a_vulnerability("CVE-3", Severity::High, Some("1.0.1")),
279+
a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")),
280+
a_vulnerability("CVE-4", Severity::High, Some("1.0.2")),
281+
], Some("1.0.1"))]
282+
#[case("handles_complex_scenario", "2.8.1", vec![
283+
a_vulnerability("CVE-2022-25857", Severity::High, Some("2.8.2")),
284+
a_vulnerability("CVE-2022-39253", Severity::High, Some("2.8.2")),
285+
a_vulnerability("CVE-2022-0536", Severity::Medium, Some("2.8.2")),
286+
a_vulnerability("CVE-2022-41724", Severity::Medium, Some("2.8.2")),
287+
a_vulnerability("CVE-2022-41725", Severity::Medium, Some("2.8.2")),
288+
289+
a_vulnerability("CVE-2021-33574", Severity::Critical, Some("2.9.0")),
290+
a_vulnerability("CVE-2022-25857", Severity::High, Some("2.9.0")),
291+
a_vulnerability("CVE-2022-39253", Severity::High, Some("2.9.0")),
292+
a_vulnerability("CVE-2022-0536", Severity::Medium, Some("2.9.0")),
293+
a_vulnerability("CVE-2022-41724", Severity::Medium, Some("2.9.0")),
294+
a_vulnerability("CVE-2022-41725", Severity::Medium, Some("2.9.0")),
295+
], Some("2.9.0"))]
296+
fn test_suggested_fix_version(
297+
#[case] _description: &str,
298+
#[case] version: &str,
299+
#[with(version)] package: Arc<Package>,
300+
#[case] vulnerabilities: Vec<Arc<Vulnerability>>,
301+
#[case] expected_fix: Option<&str>,
302+
) {
303+
assert_eq!(package.version(), &Version::parse(version).unwrap());
304+
305+
for vuln in &vulnerabilities {
306+
package.add_vulnerability_found(vuln.clone());
307+
}
308+
309+
let expected = expected_fix.map(|v| Version::parse(v).unwrap());
310+
assert_eq!(package.suggested_fix_version(), expected);
311+
}
312+
}

0 commit comments

Comments
 (0)