From e69f4e2893cbbab34d79ff9653cdb99cbbc215e2 Mon Sep 17 00:00:00 2001 From: anemeth Date: Wed, 2 Jul 2025 15:48:06 -0700 Subject: [PATCH 1/3] fix: Add cache invalidation --- src/SharpLinks.cs | 93 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 72 insertions(+), 21 deletions(-) diff --git a/src/SharpLinks.cs b/src/SharpLinks.cs index baa1321..8f730a7 100644 --- a/src/SharpLinks.cs +++ b/src/SharpLinks.cs @@ -19,6 +19,7 @@ using System.DirectoryServices.ActiveDirectory; using System.IO; using System.Linq; +using System.Reflection; using System.Security.Principal; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -137,61 +138,106 @@ public IContext InitCommonLib(IContext context) { context.Logger.LogTrace("Getting cache path"); var path = context.GetCachePath(); context.Logger.LogTrace("Cache Path: {Path}", path); + Cache cache; - if (!File.Exists(path)) { + if (!File.Exists(path)) + { context.Logger.LogTrace("Cache file does not exist"); cache = null; - } else - try { + } + else if (context.Flags.InvalidateCache) + { + context.Logger.LogTrace($"Skipping cache load per option {nameof(Options.RebuildCache)}"); + cache = null; + } + else + { + try + { context.Logger.LogTrace("Loading cache from disk"); var json = File.ReadAllText(path); cache = JsonConvert.DeserializeObject(json, CacheContractResolver.Settings); context.Logger.LogInformation("Loaded cache with stats: {stats}", cache?.GetCacheStats()); - } catch (Exception e) { + } + catch (Exception e) + { context.Logger.LogError("Error loading cache: {exception}, creating new", e); cache = null; } + var version = Assembly.GetExecutingAssembly().GetName().Version; + if (CacheNeedsInvalidation(cache, version)) + { + context.Logger.LogInformation("Old cache found, ignoring"); + cache = null; + } + } + CommonLib.InitializeCommonLib(context.Logger, cache); context.Logger.LogTrace("Exiting InitCommonLib"); return context; } - public async Task GetDomainsForEnumeration(IContext context) { + private bool CacheNeedsInvalidation(Cache cache, Version version) + { + var threshold = DateTime.Now.Subtract(TimeSpan.FromDays(30)); + if (cache.CacheCreationDate < threshold) { + return true; + } + + if (cache.CacheCreationVersion == null || version > cache.CacheCreationVersion) + { + return true; + } + + return false; + } + + public async Task GetDomainsForEnumeration(IContext context) + { context.Logger.LogTrace("Entering GetDomainsForEnumeration"); - if (context.Flags.RecurseDomains) { + if (context.Flags.RecurseDomains) + { context.Logger.LogInformation( "[RecurseDomains] Cross-domain enumeration may result in reduced data quality"); context.Domains = await BuildRecursiveDomainList(context).ToArrayAsync(); return context; } - if (context.Flags.SearchForest) { + if (context.Flags.SearchForest) + { context.Logger.LogInformation( "[SearchForest] Cross-domain enumeration may result in reduced data quality"); - if (!context.LDAPUtils.GetDomain(context.DomainName, out var dObj)) { + if (!context.LDAPUtils.GetDomain(context.DomainName, out var dObj)) + { context.Logger.LogError("Unable to get domain object for SearchForest"); context.Flags.IsFaulted = true; return context; } Forest forest; - try { + try + { forest = dObj.Forest; - } catch (Exception e) { + } + catch (Exception e) + { context.Logger.LogError("Unable to get forest object for SearchForest: {Message}", e.Message); context.Flags.IsFaulted = true; return context; } var temp = new List(); - foreach (Domain d in forest.Domains) { + foreach (Domain d in forest.Domains) + { var entry = d.GetDirectoryEntry().ToDirectoryObject(); - if (!entry.TryGetSecurityIdentifier(out var domainSid)) { + if (!entry.TryGetSecurityIdentifier(out var domainSid)) + { continue; } - temp.Add(new EnumerationDomain() { + temp.Add(new EnumerationDomain() + { Name = d.Name, DomainSid = domainSid }); @@ -203,28 +249,33 @@ public async Task GetDomainsForEnumeration(IContext context) { return context; } - if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) { + if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) + { context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); context.Flags.IsFaulted = true; return context; } var domain = domainObject?.Name ?? context.DomainName; - if (domain == null) { + if (domain == null) + { context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); context.Flags.IsFaulted = true; return context; } if (domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject() - .TryGetSecurityIdentifier(out var sid)) { + .TryGetSecurityIdentifier(out var sid)) + { context.Domains = new[] { new EnumerationDomain { Name = domain, DomainSid = sid } }; - } else { + } + else + { context.Domains = new[] { new EnumerationDomain { Name = domain, @@ -315,11 +366,11 @@ public IContext Finish(IContext context) { } public IContext SaveCacheFile(IContext context) { - if (context.Flags.MemCache) - return context; - // 15. Program exit started. Save the cache file + // if (context.Flags.MemCache) + // return context; + // // 15. Program exit started. Save the cache file var cache = Cache.GetCacheInstance(); - context.Logger.LogInformation("Saving cache with stats: {stats}", cache.GetCacheStats()); + // context.Logger.LogInformation("Saving cache with stats: {stats}", cache.GetCacheStats()); var serialized = JsonConvert.SerializeObject(cache, CacheContractResolver.Settings); using var stream = new StreamWriter(context.GetCachePath()); From fe40ec8ce08e6d8e2ebf9e5ecfc5dbef451e8341 Mon Sep 17 00:00:00 2001 From: anemeth Date: Wed, 2 Jul 2025 15:51:10 -0700 Subject: [PATCH 2/3] chore: Formatting --- src/SharpLinks.cs | 72 +++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 46 deletions(-) diff --git a/src/SharpLinks.cs b/src/SharpLinks.cs index 8f730a7..d3fcca0 100644 --- a/src/SharpLinks.cs +++ b/src/SharpLinks.cs @@ -30,8 +30,7 @@ using SharpHoundCommonLib.Processors; using Timer = System.Timers.Timer; -namespace Sharphound -{ +namespace Sharphound { internal class SharpLinks : Links { /// /// Define methods that SharpHound executes as part of operation pipeline. @@ -63,7 +62,8 @@ public IContext Initialize(IContext context, LdapConfig options) { if (!context.LDAPUtils.GetDomain(out var d)) { context.Logger.LogCritical("unable to get current domain"); context.Flags.IsFaulted = true; - } else { + } + else { context.DomainName = d.Name; context.Logger.LogInformation("Resolved current domain to {Domain}", d.Name); } @@ -91,7 +91,8 @@ public IContext Initialize(IContext context, LdapConfig options) { } File.Delete(filename); - } catch (Exception e) { + } + catch (Exception e) { context.Logger.LogCritical(e, "unable to write to target directory"); context.Flags.IsFaulted = true; } @@ -140,34 +141,28 @@ public IContext InitCommonLib(IContext context) { context.Logger.LogTrace("Cache Path: {Path}", path); Cache cache; - if (!File.Exists(path)) - { + if (!File.Exists(path)) { context.Logger.LogTrace("Cache file does not exist"); cache = null; } - else if (context.Flags.InvalidateCache) - { + else if (context.Flags.InvalidateCache) { context.Logger.LogTrace($"Skipping cache load per option {nameof(Options.RebuildCache)}"); cache = null; } - else - { - try - { + else { + try { context.Logger.LogTrace("Loading cache from disk"); var json = File.ReadAllText(path); cache = JsonConvert.DeserializeObject(json, CacheContractResolver.Settings); context.Logger.LogInformation("Loaded cache with stats: {stats}", cache?.GetCacheStats()); } - catch (Exception e) - { + catch (Exception e) { context.Logger.LogError("Error loading cache: {exception}, creating new", e); cache = null; } var version = Assembly.GetExecutingAssembly().GetName().Version; - if (CacheNeedsInvalidation(cache, version)) - { + if (CacheNeedsInvalidation(cache, version)) { context.Logger.LogInformation("Old cache found, ignoring"); cache = null; } @@ -178,66 +173,55 @@ public IContext InitCommonLib(IContext context) { return context; } - private bool CacheNeedsInvalidation(Cache cache, Version version) - { + private bool CacheNeedsInvalidation(Cache cache, Version version) { var threshold = DateTime.Now.Subtract(TimeSpan.FromDays(30)); if (cache.CacheCreationDate < threshold) { return true; } - if (cache.CacheCreationVersion == null || version > cache.CacheCreationVersion) - { + if (cache.CacheCreationVersion == null || version > cache.CacheCreationVersion) { return true; } return false; } - public async Task GetDomainsForEnumeration(IContext context) - { + public async Task GetDomainsForEnumeration(IContext context) { context.Logger.LogTrace("Entering GetDomainsForEnumeration"); - if (context.Flags.RecurseDomains) - { + if (context.Flags.RecurseDomains) { context.Logger.LogInformation( "[RecurseDomains] Cross-domain enumeration may result in reduced data quality"); context.Domains = await BuildRecursiveDomainList(context).ToArrayAsync(); return context; } - if (context.Flags.SearchForest) - { + if (context.Flags.SearchForest) { context.Logger.LogInformation( "[SearchForest] Cross-domain enumeration may result in reduced data quality"); - if (!context.LDAPUtils.GetDomain(context.DomainName, out var dObj)) - { + if (!context.LDAPUtils.GetDomain(context.DomainName, out var dObj)) { context.Logger.LogError("Unable to get domain object for SearchForest"); context.Flags.IsFaulted = true; return context; } Forest forest; - try - { + try { forest = dObj.Forest; } - catch (Exception e) - { + catch (Exception e) { context.Logger.LogError("Unable to get forest object for SearchForest: {Message}", e.Message); context.Flags.IsFaulted = true; return context; } var temp = new List(); - foreach (Domain d in forest.Domains) - { + foreach (Domain d in forest.Domains) { var entry = d.GetDirectoryEntry().ToDirectoryObject(); - if (!entry.TryGetSecurityIdentifier(out var domainSid)) - { + if (!entry.TryGetSecurityIdentifier(out var domainSid)) { continue; } - temp.Add(new EnumerationDomain() - { + temp.Add(new EnumerationDomain() { Name = d.Name, DomainSid = domainSid }); @@ -249,24 +233,21 @@ public async Task GetDomainsForEnumeration(IContext context) return context; } - if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) - { + if (!context.LDAPUtils.GetDomain(context.DomainName, out var domainObject)) { context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); context.Flags.IsFaulted = true; return context; } var domain = domainObject?.Name ?? context.DomainName; - if (domain == null) - { + if (domain == null) { context.Logger.LogError("Unable to resolve a domain to use, manually specify one or check spelling"); context.Flags.IsFaulted = true; return context; } if (domainObject != null && domainObject.GetDirectoryEntry().ToDirectoryObject() - .TryGetSecurityIdentifier(out var sid)) - { + .TryGetSecurityIdentifier(out var sid)) { context.Domains = new[] { new EnumerationDomain { Name = domain, @@ -274,8 +255,7 @@ public async Task GetDomainsForEnumeration(IContext context) } }; } - else - { + else { context.Domains = new[] { new EnumerationDomain { Name = domain, From 3f147fea8a96770d09768e52ded8f2ff52410e7c Mon Sep 17 00:00:00 2001 From: anemeth Date: Wed, 2 Jul 2025 15:56:50 -0700 Subject: [PATCH 3/3] fix: Create new cache with version when invalidating --- src/SharpLinks.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/SharpLinks.cs b/src/SharpLinks.cs index d3fcca0..7600012 100644 --- a/src/SharpLinks.cs +++ b/src/SharpLinks.cs @@ -140,14 +140,15 @@ public IContext InitCommonLib(IContext context) { var path = context.GetCachePath(); context.Logger.LogTrace("Cache Path: {Path}", path); + var version = Assembly.GetExecutingAssembly().GetName().Version; Cache cache; if (!File.Exists(path)) { context.Logger.LogTrace("Cache file does not exist"); - cache = null; + cache = Cache.CreateNewCache(version); } else if (context.Flags.InvalidateCache) { context.Logger.LogTrace($"Skipping cache load per option {nameof(Options.RebuildCache)}"); - cache = null; + cache = Cache.CreateNewCache(version); } else { try { @@ -158,13 +159,12 @@ public IContext InitCommonLib(IContext context) { } catch (Exception e) { context.Logger.LogError("Error loading cache: {exception}, creating new", e); - cache = null; + cache = Cache.CreateNewCache(version); } - var version = Assembly.GetExecutingAssembly().GetName().Version; if (CacheNeedsInvalidation(cache, version)) { context.Logger.LogInformation("Old cache found, ignoring"); - cache = null; + cache = Cache.CreateNewCache(version); } }