|
1 | 1 | use crate::domain::scanresult::accepted_risk::AcceptedRisk; |
2 | 2 | use crate::domain::scanresult::layer::Layer; |
3 | 3 | use crate::domain::scanresult::package_type::PackageType; |
| 4 | +use crate::domain::scanresult::severity::Severity; |
4 | 5 | use crate::domain::scanresult::vulnerability::Vulnerability; |
5 | 6 | use crate::domain::scanresult::weak_hash::WeakHash; |
6 | 7 | use semver::Version; |
7 | | -use std::collections::HashSet; |
| 8 | +use std::collections::{HashMap, HashSet}; |
8 | 9 | use std::fmt::Debug; |
9 | 10 | use std::hash::{Hash, Hasher}; |
10 | 11 | use std::sync::{Arc, RwLock}; |
@@ -109,6 +110,69 @@ impl Package { |
109 | 110 | .filter_map(|r| r.0.upgrade()) |
110 | 111 | .collect() |
111 | 112 | } |
| 113 | + |
| 114 | + pub fn suggested_fix_version(&self) -> Option<Version> { |
| 115 | + let vulnerabilities = self.vulnerabilities(); |
| 116 | + if vulnerabilities.is_empty() { |
| 117 | + return None; |
| 118 | + } |
| 119 | + |
| 120 | + let candidate_versions: Vec<Version> = vulnerabilities |
| 121 | + .iter() |
| 122 | + .filter_map(|vuln| vuln.fix_version().cloned()) |
| 123 | + .collect::<HashSet<_>>() |
| 124 | + .into_iter() |
| 125 | + .collect(); |
| 126 | + |
| 127 | + if candidate_versions.is_empty() { |
| 128 | + return None; |
| 129 | + } |
| 130 | + |
| 131 | + let severity_order = [ |
| 132 | + Severity::Critical, |
| 133 | + Severity::High, |
| 134 | + Severity::Medium, |
| 135 | + Severity::Low, |
| 136 | + Severity::Negligible, |
| 137 | + Severity::Unknown, |
| 138 | + ]; |
| 139 | + |
| 140 | + let mut scores: HashMap<Version, HashMap<Severity, usize>> = HashMap::new(); |
| 141 | + |
| 142 | + for candidate in &candidate_versions { |
| 143 | + let mut score: HashMap<Severity, usize> = HashMap::new(); |
| 144 | + for severity in &severity_order { |
| 145 | + score.insert(*severity, 0); |
| 146 | + } |
| 147 | + for vuln in &vulnerabilities { |
| 148 | + if let Some(fix_version) = vuln.fix_version() |
| 149 | + && fix_version == candidate |
| 150 | + { |
| 151 | + *score.entry(vuln.severity()).or_insert(0) += 1; |
| 152 | + } |
| 153 | + } |
| 154 | + scores.insert(candidate.clone(), score); |
| 155 | + } |
| 156 | + |
| 157 | + let mut sorted_candidates = candidate_versions; |
| 158 | + sorted_candidates.sort_by(|a, b| { |
| 159 | + let score_a = scores.get(a).unwrap(); |
| 160 | + let score_b = scores.get(b).unwrap(); |
| 161 | + |
| 162 | + for severity in &severity_order { |
| 163 | + let count_a = score_a.get(severity).unwrap(); |
| 164 | + let count_b = score_b.get(severity).unwrap(); |
| 165 | + if count_a != count_b { |
| 166 | + return count_b.cmp(count_a); // Higher count is better |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + // If scores are identical, lower version is better |
| 171 | + a.cmp(b) |
| 172 | + }); |
| 173 | + |
| 174 | + sorted_candidates.first().cloned() |
| 175 | + } |
112 | 176 | } |
113 | 177 |
|
114 | 178 | impl PartialEq for Package { |
@@ -144,3 +208,105 @@ impl Clone for Package { |
144 | 208 | } |
145 | 209 | } |
146 | 210 | } |
| 211 | + |
| 212 | +#[cfg(test)] |
| 213 | +mod tests { |
| 214 | + use super::*; |
| 215 | + use crate::domain::scanresult::layer::Layer; |
| 216 | + use crate::domain::scanresult::package_type::PackageType; |
| 217 | + use crate::domain::scanresult::severity::Severity; |
| 218 | + use crate::domain::scanresult::vulnerability::Vulnerability; |
| 219 | + use chrono::NaiveDate; |
| 220 | + use rstest::{fixture, rstest}; |
| 221 | + use semver::Version; |
| 222 | + use std::sync::Arc; |
| 223 | + |
| 224 | + #[fixture] |
| 225 | + fn layer() -> Arc<Layer> { |
| 226 | + Arc::new(Layer::new( |
| 227 | + "a_digest".to_string(), |
| 228 | + 0, |
| 229 | + None, |
| 230 | + "a_command".to_string(), |
| 231 | + )) |
| 232 | + } |
| 233 | + |
| 234 | + #[fixture] |
| 235 | + fn package(#[default("")] version: &str, layer: Arc<Layer>) -> Arc<Package> { |
| 236 | + Arc::new(Package::new( |
| 237 | + PackageType::Os, |
| 238 | + "a_name".to_string(), |
| 239 | + Version::parse(version).unwrap(), |
| 240 | + "a_path".to_string(), |
| 241 | + layer, |
| 242 | + )) |
| 243 | + } |
| 244 | + |
| 245 | + fn a_vulnerability( |
| 246 | + cve: &str, |
| 247 | + severity: Severity, |
| 248 | + fix_version: Option<&str>, |
| 249 | + ) -> Arc<Vulnerability> { |
| 250 | + Arc::new(Vulnerability::new( |
| 251 | + cve.to_string(), |
| 252 | + severity, |
| 253 | + NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(), |
| 254 | + None, |
| 255 | + false, |
| 256 | + fix_version.map(|v| Version::parse(v).unwrap()), |
| 257 | + )) |
| 258 | + } |
| 259 | + |
| 260 | + #[rstest] |
| 261 | + #[case("is_none_when_no_vulnerabilities", "1.0.0", vec![], None)] |
| 262 | + #[case("is_none_when_no_fixable_vulnerabilities", "1.0.0", vec![a_vulnerability("CVE-1", Severity::High, None)], None)] |
| 263 | + #[case("returns_only_available_fix", "1.0.0", vec![a_vulnerability("CVE-1", Severity::High, Some("1.0.1"))], Some("1.0.1"))] |
| 264 | + #[case("chooses_version_with_more_critical_fixes", "1.0.0", vec![ |
| 265 | + a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")), |
| 266 | + a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")), |
| 267 | + a_vulnerability("CVE-3", Severity::High, Some("1.0.2")), |
| 268 | + ], Some("1.0.2"))] |
| 269 | + #[case("chooses_version_with_more_high_fixes_when_criticals_tied", "1.0.0", vec![ |
| 270 | + a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")), |
| 271 | + a_vulnerability("CVE-5", Severity::Medium, Some("1.0.1")), |
| 272 | + a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")), |
| 273 | + a_vulnerability("CVE-3", Severity::High, Some("1.0.2")), |
| 274 | + a_vulnerability("CVE-4", Severity::High, Some("1.0.2")), |
| 275 | + ], Some("1.0.2"))] |
| 276 | + #[case("chooses_lower_version_when_counts_are_tied", "1.0.0", vec![ |
| 277 | + a_vulnerability("CVE-1", Severity::Critical, Some("1.0.1")), |
| 278 | + a_vulnerability("CVE-3", Severity::High, Some("1.0.1")), |
| 279 | + a_vulnerability("CVE-2", Severity::Critical, Some("1.0.2")), |
| 280 | + a_vulnerability("CVE-4", Severity::High, Some("1.0.2")), |
| 281 | + ], Some("1.0.1"))] |
| 282 | + #[case("handles_complex_scenario", "2.8.1", vec![ |
| 283 | + a_vulnerability("CVE-2022-25857", Severity::High, Some("2.8.2")), |
| 284 | + a_vulnerability("CVE-2022-39253", Severity::High, Some("2.8.2")), |
| 285 | + a_vulnerability("CVE-2022-0536", Severity::Medium, Some("2.8.2")), |
| 286 | + a_vulnerability("CVE-2022-41724", Severity::Medium, Some("2.8.2")), |
| 287 | + a_vulnerability("CVE-2022-41725", Severity::Medium, Some("2.8.2")), |
| 288 | +
|
| 289 | + a_vulnerability("CVE-2021-33574", Severity::Critical, Some("2.9.0")), |
| 290 | + a_vulnerability("CVE-2022-25857", Severity::High, Some("2.9.0")), |
| 291 | + a_vulnerability("CVE-2022-39253", Severity::High, Some("2.9.0")), |
| 292 | + a_vulnerability("CVE-2022-0536", Severity::Medium, Some("2.9.0")), |
| 293 | + a_vulnerability("CVE-2022-41724", Severity::Medium, Some("2.9.0")), |
| 294 | + a_vulnerability("CVE-2022-41725", Severity::Medium, Some("2.9.0")), |
| 295 | + ], Some("2.9.0"))] |
| 296 | + fn test_suggested_fix_version( |
| 297 | + #[case] _description: &str, |
| 298 | + #[case] version: &str, |
| 299 | + #[with(version)] package: Arc<Package>, |
| 300 | + #[case] vulnerabilities: Vec<Arc<Vulnerability>>, |
| 301 | + #[case] expected_fix: Option<&str>, |
| 302 | + ) { |
| 303 | + assert_eq!(package.version(), &Version::parse(version).unwrap()); |
| 304 | + |
| 305 | + for vuln in &vulnerabilities { |
| 306 | + package.add_vulnerability_found(vuln.clone()); |
| 307 | + } |
| 308 | + |
| 309 | + let expected = expected_fix.map(|v| Version::parse(v).unwrap()); |
| 310 | + assert_eq!(package.suggested_fix_version(), expected); |
| 311 | + } |
| 312 | +} |
0 commit comments