Skip to content

Commit 206cecc

Browse files
committed
Hide lockdown logic behind shouldFilter function
1 parent 2de28f7 commit 206cecc

File tree

4 files changed

+27
-35
lines changed

4 files changed

+27
-35
lines changed

pkg/github/issues.go

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc
331331
}
332332
login := issue.GetUser().GetLogin()
333333
if login != "" {
334-
info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
334+
shouldFilter, err := cache.ShouldFilterContent(ctx, login, owner, repo)
335335
if err != nil {
336336
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
337337
}
338-
if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess {
338+
if shouldFilter {
339339
return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil
340340
}
341341
}
@@ -394,16 +394,11 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow
394394
if login == "" {
395395
continue
396396
}
397-
info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
397+
shouldFilter, err := cache.ShouldFilterContent(ctx, login, owner, repo)
398398
if err != nil {
399399
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
400400
}
401-
// Do not filter content for private repositories or if the comment author is the viewer
402-
if info.IsPrivate || info.ViewerLogin == login {
403-
filteredComments = comments
404-
break
405-
}
406-
if info.HasPushAccess {
401+
if !shouldFilter {
407402
filteredComments = append(filteredComments, comment)
408403
}
409404
}
@@ -459,16 +454,11 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re
459454
if login == "" {
460455
continue
461456
}
462-
info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
457+
shouldFilter, err := cache.ShouldFilterContent(ctx, login, owner, repo)
463458
if err != nil {
464459
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
465460
}
466-
// Repo is private or the comment author is the viewer, do not filter content
467-
if info.IsPrivate || info.ViewerLogin == login {
468-
filteredSubIssues = subIssues
469-
break
470-
}
471-
if info.HasPushAccess {
461+
if !shouldFilter {
472462
filteredSubIssues = append(filteredSubIssues, subIssue)
473463
}
474464
}

pkg/github/pullrequests.go

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.
141141
}
142142
login := pr.GetUser().GetLogin()
143143
if login != "" {
144-
info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
144+
shouldFilter, err := cache.ShouldFilterContent(ctx, login, owner, repo)
145145
if err != nil {
146146
return nil, fmt.Errorf("failed to check content removal: %w", err)
147147
}
148148

149-
if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess {
149+
if shouldFilter {
150150
return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil
151151
}
152152
}
@@ -303,16 +303,11 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ca
303303
if user == nil {
304304
continue
305305
}
306-
info, err := cache.GetRepoAccessInfo(ctx, user.GetLogin(), owner, repo)
306+
shouldFilter, err := cache.ShouldFilterContent(ctx, user.GetLogin(), owner, repo)
307307
if err != nil {
308308
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
309309
}
310-
// Do not filter content for private repositories or if the comment author is the viewer
311-
if info.IsPrivate || info.ViewerLogin == user.GetLogin() {
312-
filteredComments = comments
313-
break
314-
}
315-
if info.HasPushAccess {
310+
if !shouldFilter {
316311
filteredComments = append(filteredComments, comment)
317312
}
318313
}
@@ -354,14 +349,11 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo
354349
for _, review := range reviews {
355350
login := review.GetUser().GetLogin()
356351
if login != "" {
357-
info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
352+
shouldFilter, err := cache.ShouldFilterContent(ctx, login, owner, repo)
358353
if err != nil {
359354
return nil, fmt.Errorf("failed to check lockdown mode: %w", err)
360355
}
361-
if info.IsPrivate || info.ViewerLogin == login {
362-
filteredReviews = reviews
363-
}
364-
if info.HasPushAccess {
356+
if !shouldFilter {
365357
filteredReviews = append(filteredReviews, review)
366358
}
367359
reviews = filteredReviews

pkg/lockdown/lockdown.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,19 @@ type CacheStats struct {
109109
Evictions int64
110110
}
111111

112-
// GetRepoAccessInfo returns repository access metadata for the provided user.
113-
// Results are cached per repository to avoid repeated GraphQL round-trips.
114-
func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
112+
func (c *RepoAccessCache) ShouldFilterContent(ctx context.Context, username, owner, repo string) (bool, error) {
113+
repoInfo, err := c.getRepoAccessInfo(ctx, username, owner, repo)
114+
if err != nil {
115+
c.logDebug("error checking repo access info for content filtering", "owner", owner, "repo", repo, "user", username, "error", err)
116+
return false, err
117+
}
118+
if repoInfo.IsPrivate || repoInfo.ViewerLogin == username {
119+
return false, nil
120+
}
121+
return !repoInfo.HasPushAccess, nil
122+
}
123+
124+
func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
115125
if c == nil {
116126
return RepoAccessInfo{}, fmt.Errorf("nil repo access cache")
117127
}

pkg/lockdown/lockdown_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,15 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) {
9696
ctx := t.Context()
9797

9898
cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond)
99-
info, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo)
99+
info, err := cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo)
100100
require.NoError(t, err)
101101
require.Equal(t, testUser, info.ViewerLogin)
102102
require.True(t, info.HasPushAccess)
103103
require.EqualValues(t, 1, transport.CallCount())
104104

105105
time.Sleep(20 * time.Millisecond)
106106

107-
info, err = cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo)
107+
info, err = cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo)
108108
require.NoError(t, err)
109109
require.Equal(t, testUser, info.ViewerLogin)
110110
require.True(t, info.HasPushAccess)

0 commit comments

Comments
 (0)