Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize vulnerability host counts #24914

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/22364-vuln-cron
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* fixed issue where the vulnerabilities cron was failing in large environments due to large SQL queries
4 changes: 3 additions & 1 deletion cmd/fleet/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ func cronVulnerabilities(
return fmt.Errorf("scanning vulnerabilities: %w", err)
}

start := time.Now()
level.Info(logger).Log("msg", "updating vulnerability host counts")
if err := ds.UpdateVulnerabilityHostCounts(ctx); err != nil {
if err := ds.UpdateVulnerabilityHostCounts(ctx, config.MaxRoutines); err != nil {
return fmt.Errorf("updating vulnerability host counts: %w", err)
}
level.Info(logger).Log("msg", "vulnerability host counts updated", "took", time.Since(start).Seconds())
}

return nil
Expand Down
7 changes: 7 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ type VulnerabilitiesConfig struct {
DisableDataSync bool `json:"disable_data_sync" yaml:"disable_data_sync"`
RecentVulnerabilityMaxAge time.Duration `json:"recent_vulnerability_max_age" yaml:"recent_vulnerability_max_age"`
DisableWinOSVulnerabilities bool `json:"disable_win_os_vulnerabilities" yaml:"disable_win_os_vulnerabilities"`
MaxRoutines int `json:"max_routines" yaml:"max_routines"`
}

// UpgradesConfig defines configs related to fleet server upgrades.
Expand Down Expand Up @@ -1257,6 +1258,11 @@ func (man Manager) addConfigs() {
false,
"Don't sync installed Windows updates nor perform Windows OS vulnerability processing.",
)
man.addConfigInt(
"vulnerabilities.max_routines",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"vulnerabilities.max_routines",
"vulnerabilities.max_concurrency",

Seems like "concurrency" is more self-evident here. Guessing you cycled through that as a naming idea here, so it'd be useful to understand why this naming convention won.

5,
"Maximum number of concurrent database queries to use for processing vulnerabilities.",
)

// Upgrades
man.addConfigBool("upgrades.allow_missing_migrations", false,
Expand Down Expand Up @@ -1528,6 +1534,7 @@ func (man Manager) LoadConfig() FleetConfig {
DisableDataSync: man.getConfigBool("vulnerabilities.disable_data_sync"),
RecentVulnerabilityMaxAge: man.getConfigDuration("vulnerabilities.recent_vulnerability_max_age"),
DisableWinOSVulnerabilities: man.getConfigBool("vulnerabilities.disable_win_os_vulnerabilities"),
MaxRoutines: man.getConfigInt("vulnerabilities.max_routines"),
},
Upgrades: UpgradesConfig{
AllowMissingMigrations: man.getConfigBool("upgrades.allow_missing_migrations"),
Expand Down
251 changes: 185 additions & 66 deletions server/datastore/mysql/vulnerabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
Expand Down Expand Up @@ -342,30 +343,193 @@ func (ds *Datastore) CountVulnerabilities(ctx context.Context, opt fleet.VulnLis
return count, nil
}

func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
func (ds *Datastore) distinctCVEs(ctx context.Context) ([]string, error) {
uniqueCVEQuery := `
SELECT DISTINCT cve FROM (
SELECT cve FROM software_cve
UNION
SELECT cve FROM operating_system_vulnerabilities
) AS combined_cves;
`

var cves []string
err := sqlx.SelectContext(ctx, ds.reader(ctx), &cves, uniqueCVEQuery)
if err != nil {
return nil, err
}
return cves, nil
}

type CountScope int

const (
GlobalCount CountScope = iota
NoTeamCount
TeamCount
)

func (ds *Datastore) batchFetchVulnerabilityCounts(
ctx context.Context,
scope CountScope,
maxRoutines int,
) ([]hostCount, error) {
const (
batchSize = 20
)

// Fetch distinct CVEs
allCVEs, err := ds.distinctCVEs(ctx)
if err != nil {
return nil, err
}

query := getVulnHostCountQuery(scope)
if query == "" {
return nil, ctxerr.Errorf(ctx, "invalid scope: %d", scope)
}

var (
hostCounts []hostCount
mu sync.Mutex
wg sync.WaitGroup
sem = make(chan struct{}, maxRoutines)
errChan = make(chan error, len(allCVEs)/batchSize+1)
)

// Process CVEs in batches concurrently
for i := 0; i < len(allCVEs); i += batchSize {
end := i + batchSize
if end > len(allCVEs) {
end = len(allCVEs)
}

batchCVEs := allCVEs[i:end]
wg.Add(1)
sem <- struct{}{} // Acquire semaphore

go func(cves []string) {
defer wg.Done()
defer func() { <-sem }() // Release semaphore

counts, err := ds.fetchBatchCounts(ctx, cves, query)
if err != nil {
errChan <- err
return
}

mu.Lock()
hostCounts = append(hostCounts, counts...)
mu.Unlock()
}(batchCVEs)
}

wg.Wait()
close(errChan)

// Check for errors
for err := range errChan {
if err != nil {
return nil, err
}
}

return hostCounts, nil
}

// fetchBatchCounts executes the query for a batch of CVEs.
func (ds *Datastore) fetchBatchCounts(
ctx context.Context,
batchCVEs []string,
scopeConfig string,
) ([]hostCount, error) {
query, args, err := sqlx.In(scopeConfig, batchCVEs, batchCVEs)
if err != nil {
return nil, err
}

var counts []hostCount
err = sqlx.SelectContext(ctx, ds.reader(ctx), &counts, query, args...)
if err != nil {
return nil, err
}

return counts, nil
}

// getScopeConfig returns the query configuration for the given scope.
func getVulnHostCountQuery(scope CountScope) string {
switch scope {
case GlobalCount:
return `
SELECT 0 as team_id, 1 as global_stats, combined_results.cve, COUNT(*) AS host_count
FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
WHERE sc.cve IN (?)

UNION

SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
WHERE osv.cve IN (?)
) AS combined_results
GROUP BY cve
`
case NoTeamCount:
return `
SELECT 0 as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
WHERE sc.cve IN (?)

UNION

SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
WHERE osv.cve IN (?)
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NULL
GROUP BY cve
`
case TeamCount:
return `
SELECT h.team_id as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
WHERE sc.cve IN (?)

UNION

SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
WHERE osv.cve IN (?)
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NOT NULL
GROUP BY h.team_id, combined_results.cve
`
default:
return ""
}
}

func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error {
// set all counts to 0 to later identify rows to delete
_, err := ds.writer(ctx).ExecContext(ctx, "UPDATE vulnerability_host_counts SET host_count = 0")
if err != nil {
return ctxerr.Wrap(ctx, err, "initializing vulnerability host counts")
}

globalSelectStmt := `
SELECT 0 as team_id, 1 as global_stats, cve, COUNT(*) AS host_count
FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id

UNION

SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
GROUP BY cve;
`

globalHostCounts, err := ds.fetchHostCounts(ctx, globalSelectStmt)
globalHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, GlobalCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching global vulnerability host counts")
}
Expand All @@ -375,25 +539,7 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
return ctxerr.Wrap(ctx, err, "inserting global vulnerability host counts")
}

teamSelectStmt := `
SELECT h.team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
FROM (
SELECT hs.host_id, sc.cve
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id

UNION

SELECT hos.host_id, osv.cve
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NOT NULL
GROUP BY h.team_id, combined_results.cve
`

teamHostCounts, err := ds.fetchHostCounts(ctx, teamSelectStmt)
teamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, TeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
}
Expand All @@ -403,27 +549,9 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
}

noTeamSelectStmt := `
SELECT 0 as team_id, 0 as global_stats, cve, COUNT(*) AS host_count
FROM (
SELECT hs.host_id, sc.cve
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id

UNION

SELECT hos.host_id, osv.cve
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NULL
GROUP BY cve
`

noTeamHostCounts, err := ds.fetchHostCounts(ctx, noTeamSelectStmt)
noTeamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, NoTeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
return ctxerr.Wrap(ctx, err, "fetching no team vulnerability host counts")
}

err = ds.batchInsertHostCounts(ctx, noTeamHostCounts)
Expand Down Expand Up @@ -455,15 +583,6 @@ func (ds *Datastore) cleanupVulnerabilityHostCounts(ctx context.Context) error {
return nil
}

func (ds *Datastore) fetchHostCounts(ctx context.Context, query string) ([]hostCount, error) {
var hostCounts []hostCount
err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostCounts, query)
if err != nil {
return nil, err
}
return hostCounts, nil
}

func (ds *Datastore) batchInsertHostCounts(ctx context.Context, counts []hostCount) error {
if len(counts) == 0 {
return nil
Expand Down
Loading
Loading