Skip to content

Commit

Permalink
Convert lib/auth to use slog (#50577)
Browse files Browse the repository at this point in the history
* Convert auth_with_roles to use slog

* allow benchmark tests to log

* remove last use of logrus in github.go

* remove last use of logrus in rotate.go

* remove auth package logrus logger

* convert auth.go to exclusively use slog

* convert db.go to exclusively use slog

* convert methods.go to exclusively use slog

* convert password.go to exclusively use slog

* convert join_ec2.go to exclusively use slog

* convert join_iam.go to exclusively use slog

* convert kube.go to exclusively use slog

* convert oidc.go to exclusively use slog

* convert saml.go to exclusively use slog

* convert sso_diag_context.go to exclusively use slog

* convert transport_credentials.go to exclusively use slog

* convert trustedcluster.go to exclusively use slog

* convert usertoken.go to exclusively use slog

* convert user.go to exclusively use slog

* remove logrus import

* fix: pass in context

* fix: remove new use of package logrus logger
  • Loading branch information
rosstimothy authored Jan 6, 2025
1 parent ffacd99 commit ec2718b
Show file tree
Hide file tree
Showing 21 changed files with 403 additions and 289 deletions.
8 changes: 4 additions & 4 deletions lib/auth/accountrecovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (a *Server) VerifyAccountRecovery(ctx context.Context, req *proto.VerifyAcc
return nil, trace.AccessDenied(verifyRecoveryBadAuthnErrMsg)
}

if err := a.verifyUserToken(startToken, authclient.UserTokenTypeRecoveryStart); err != nil {
if err := a.verifyUserToken(ctx, startToken, authclient.UserTokenTypeRecoveryStart); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -304,7 +304,7 @@ func (a *Server) CompleteAccountRecovery(ctx context.Context, req *proto.Complet
return trace.AccessDenied(completeRecoveryGenericErrMsg)
}

if err := a.verifyUserToken(approvedToken, authclient.UserTokenTypeRecoveryApproved); err != nil {
if err := a.verifyUserToken(ctx, approvedToken, authclient.UserTokenTypeRecoveryApproved); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -403,7 +403,7 @@ func (a *Server) CreateAccountRecoveryCodes(ctx context.Context, req *proto.Crea
return nil, trace.AccessDenied("only local users may create recovery codes")
}

if err := a.verifyUserToken(token, authclient.UserTokenTypeRecoveryApproved, authclient.UserTokenTypePrivilege); err != nil {
if err := a.verifyUserToken(ctx, token, authclient.UserTokenTypeRecoveryApproved, authclient.UserTokenTypePrivilege); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -428,7 +428,7 @@ func (a *Server) GetAccountRecoveryToken(ctx context.Context, req *proto.GetAcco
return nil, trace.AccessDenied("access denied")
}

if err := a.verifyUserToken(token, authclient.UserTokenTypeRecoveryStart, authclient.UserTokenTypeRecoveryApproved); err != nil {
if err := a.verifyUserToken(ctx, token, authclient.UserTokenTypeRecoveryStart, authclient.UserTokenTypeRecoveryApproved); err != nil {
return nil, trace.Wrap(err)
}

Expand Down
230 changes: 133 additions & 97 deletions lib/auth/auth.go

Large diffs are not rendered by default.

245 changes: 159 additions & 86 deletions lib/auth/auth_with_roles.go

Large diffs are not rendered by default.

21 changes: 0 additions & 21 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"crypto/tls"
"crypto/x509/pkix"
"fmt"
"io"
"net/url"
"slices"
"strconv"
Expand All @@ -38,7 +37,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pquerna/otp/totp"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -77,7 +75,6 @@ import (
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/srv/discovery/common"
"github.com/gravitational/teleport/lib/tlsca"
logutils "github.com/gravitational/teleport/lib/utils/log"
"github.com/gravitational/teleport/lib/utils/pagination"
)

Expand Down Expand Up @@ -1824,12 +1821,6 @@ func BenchmarkListNodes(b *testing.B) {
const nodeCount = 50_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down Expand Up @@ -6124,12 +6115,6 @@ func BenchmarkListUnifiedResourcesFilter(b *testing.B) {
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.PanicLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down Expand Up @@ -6257,12 +6242,6 @@ func BenchmarkListUnifiedResources(b *testing.B) {
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

Expand Down
11 changes: 7 additions & 4 deletions lib/auth/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (a *Server) SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequ
"this Teleport cluster is not licensed for database access, please contact the cluster administrator")
}

log.Debugf("Signing database CSR for cluster %v.", req.ClusterName)
a.logger.DebugContext(ctx, "Signing database CSR for cluster", "cluster", req.ClusterName)

clusterName, err := a.GetClusterName()
if err != nil {
Expand Down Expand Up @@ -348,7 +348,7 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ
return nil, trace.Wrap(err)
}

subject, issuer := getSnowflakeJWTParams(req.AccountName, req.UserName, pubKey)
subject, issuer := getSnowflakeJWTParams(ctx, req.AccountName, req.UserName, pubKey)

_, signer, err := a.GetKeyStore().GetTLSCertAndSigner(ctx, ca)
if err != nil {
Expand All @@ -371,7 +371,7 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ
}, nil
}

func getSnowflakeJWTParams(accountName, userName string, publicKey []byte) (string, string) {
func getSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) {
// Use only the first part of the account name to generate JWT
// Based on:
// https://github.com/snowflakedb/snowflake-connector-python/blob/f2f7e6f35a162484328399c8a50a5015825a5573/src/snowflake/connector/auth_keypair.py#L83
Expand All @@ -383,7 +383,10 @@ func getSnowflakeJWTParams(accountName, userName string, publicKey []byte) (stri
accnToken, _, _ := strings.Cut(accountName, accNameSeparator)
accnTokenCap := strings.ToUpper(accnToken)
userNameCap := strings.ToUpper(userName)
log.Debugf("Signing database JWT token for %s %s", accnTokenCap, userNameCap)
logger.DebugContext(ctx, "Signing database JWT token",
"account_name", accnTokenCap,
"user_name", userNameCap,
)

subject := fmt.Sprintf("%s.%s", accnTokenCap, userNameCap)

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func Test_getSnowflakeJWTParams(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, issuer := getSnowflakeJWTParams(tt.args.accountName, tt.args.userName, tt.args.publicKey)
subject, issuer := getSnowflakeJWTParams(context.Background(), tt.args.accountName, tt.args.userName, tt.args.publicKey)

require.Equal(t, tt.wantSubject, subject)
require.Equal(t, tt.wantIssuer, issuer)
Expand Down
3 changes: 1 addition & 2 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -349,7 +348,7 @@ func orgUsesExternalSSO(ctx context.Context, endpointURL, org string, client htt
if resp != nil {
io.Copy(io.Discard, resp.Body)
if bodyErr := resp.Body.Close(); bodyErr != nil {
logrus.WithError(bodyErr).Error("Error closing response body.")
logger.ErrorContext(ctx, "Error closing response body", "error", bodyErr)
}
}
// Handle makeHTTPGetReq errors.
Expand Down
2 changes: 0 additions & 2 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
oteltrace "go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -75,7 +74,6 @@ import (
)

var (
log = logrus.WithField(teleport.ComponentKey, teleport.ComponentAuth)
logger = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentAuth)
)

Expand Down
10 changes: 7 additions & 3 deletions lib/auth/join_ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,12 @@ func (a *Server) tryToDetectIdentityReuse(ctx context.Context, req *types.Regist
return trace.Wrap(err)
}
if instanceExists {
log.Warnf("Server with ID %q and role %q is attempting to join the cluster with a Simplified Node Joining request, but"+
" a server with this ID is already present in the cluster.", req.HostID, req.Role)
const msg = "Server is attempting to join the cluster with a Simplified Node Joining request, but" +
" a server with this ID is already present in the cluster"
a.logger.WarnContext(ctx, msg,
"host_id", req.HostID,
"role", req.Role,
)
return trace.AccessDenied("server with host ID %q and role %q already exists", req.HostID, req.Role)
}
return nil
Expand All @@ -363,7 +367,7 @@ func (a *Server) checkEC2JoinRequest(ctx context.Context, req *types.RegisterUsi
return trace.Wrap(err)
}

log.Debugf("Received Simplified Node Joining request for host %q", req.HostID)
a.logger.DebugContext(ctx, "Received Simplified Node Joining request", "host_id", req.HostID)

if len(req.EC2IdentityDocument) == 0 {
return trace.AccessDenied("this token is only valid for the EC2 join " +
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func validateSTSIdentityRequest(req *http.Request, challenge string, cfg *iamReg
// invalid sts:GetCallerIdentity request, it's either going to be caused
// by a node in a unknown region or an attacker.
if err != nil {
log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.")
logger.WarnContext(req.Context(), "Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method", "error", err)
}
}()

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (a *Server) ProcessKubeCSR(req authclient.KubeCSR) (*authclient.KubeCSRResp

// Certificate for remote cluster is a user certificate
// with special provisions.
log.Debugf("Generating certificate to access remote Kubernetes clusters.")
a.logger.DebugContext(ctx, "Generating certificate to access remote Kubernetes clusters")

hostCA, err := a.GetCertAuthority(ctx, types.CertAuthID{
Type: types.HostCA,
Expand Down
41 changes: 23 additions & 18 deletions lib/auth/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"bytes"
"context"
"errors"
"log/slog"
"net"
"time"

Expand Down Expand Up @@ -71,16 +70,19 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
clientMetadata: req.ClientMetadata,
authErr: err,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}
return nil, nil, trace.Wrap(err)
}

switch {
case username != "" && actualUsername != "" && username != actualUsername:
log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", username, actualUsername, username)
a.logger.WarnContext(ctx, "Authenticate user mismatch, using request user",
"username", username,
"request_user", actualUsername,
)
case username == "" && actualUsername != "":
log.Debugf("User %q authenticated via passwordless", actualUsername)
a.logger.DebugContext(ctx, "User authenticated via passwordless", "username", actualUsername)
username = actualUsername
}

Expand Down Expand Up @@ -123,7 +125,7 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
checker: checker,
authErr: err,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}
return nil, nil, trace.Wrap(err)
}
Expand All @@ -135,7 +137,7 @@ func (a *Server) authenticateUserLogin(ctx context.Context, req authclient.Authe
mfaDevice: mfaDev,
checker: checker,
}); err != nil {
log.WithError(err).Warn("Failed to emit login event")
a.logger.WarnContext(ctx, "Failed to emit login event", "error", err)
}

return userState, checker, trace.Wrap(err)
Expand Down Expand Up @@ -303,7 +305,7 @@ func (a *Server) authenticateUserInternal(
if req.HeadlessAuthenticationID != "" {
mfaDev, err = a.authenticateHeadless(ctx, req)
if err != nil {
slog.DebugContext(ctx, "Headless authenticate failed while waiting for approval",
a.logger.DebugContext(ctx, "Headless authenticate failed while waiting for approval",
"user", user,
"error", err,
)
Expand All @@ -330,7 +332,7 @@ func (a *Server) authenticateUserInternal(
case err != nil:
return nil, "", trace.Wrap(err)
case u.GetUserType() != types.UserTypeLocal:
slog.WarnContext(ctx, "Non-local user attempted local authentication",
a.logger.WarnContext(ctx, "Non-local user attempted local authentication",
"user", user,
"user_type", u.GetUserType(),
)
Expand Down Expand Up @@ -381,7 +383,7 @@ func (a *Server) authenticateUserInternal(
})
switch {
case err != nil:
slog.DebugContext(ctx, "User failed to authenticate.",
a.logger.DebugContext(ctx, "User failed to authenticate",
"user", user,
"error", err,
)
Expand All @@ -391,7 +393,7 @@ func (a *Server) authenticateUserInternal(

return nil, "", trace.Wrap(authErr)
case mfaDev == nil:
slog.DebugContext(ctx, "MFA authentication returned nil device.",
a.logger.DebugContext(ctx, "MFA authentication returned nil device",
"webauthn", req.Webauthn != nil,
"totp", req.OTP != nil,
"headless", req.HeadlessAuthenticationID != "",
Expand Down Expand Up @@ -420,7 +422,7 @@ func (a *Server) authenticateUserInternal(
// Some form of MFA is required but none provided. Either client is
// buggy (didn't send MFA response) or someone is trying to bypass
// MFA.
slog.WarnContext(ctx, "MFA bypass attempt, access denied.", "user", user)
a.logger.WarnContext(ctx, "MFA bypass attempt, access denied", "user", user)
return nil, "", trace.AccessDenied("missing second factor")
case authPreference.IsSecondFactorEnabled():
// 2FA is optional. Make sure that a user does not have MFA devices
Expand All @@ -430,7 +432,7 @@ func (a *Server) authenticateUserInternal(
return nil, "", trace.Wrap(err)
}
if len(devs) != 0 {
slog.WarnContext(ctx, "MFA bypass attempt, access denied.", "user", user)
a.logger.WarnContext(ctx, "MFA bypass attempt, access denied", "user", user)
return nil, "", trace.AccessDenied("missing second factor authentication")
}
default:
Expand All @@ -444,7 +446,7 @@ func (a *Server) authenticateUserInternal(
}
// provide obscure message on purpose, while logging the real
// error server side
slog.DebugContext(ctx, "User failed to authenticate.",
a.logger.DebugContext(ctx, "User failed to authenticate",
"user", user,
"error", err,
)
Expand All @@ -467,15 +469,18 @@ func (a *Server) authenticatePasswordless(ctx context.Context, req authclient.Au
case errors.Is(err, types.ErrPassswordlessLoginBySSOUser):
return nil, "", trace.Wrap(err)
case err != nil:
log.Debugf("Passwordless authentication failed: %v", err)
a.logger.DebugContext(ctx, "Passwordless authentication failed", "error", err)
return nil, "", trace.Wrap(authenticateWebauthnError)
}

// A distinction between passwordless and "plain" MFA is that we can't
// acquire the user lock beforehand (or at all on failures!)
// We do grab it here so successful logins go through the regular process.
if err := a.WithUserLock(ctx, mfaData.User, func() error { return nil }); err != nil {
log.Debugf("WithUserLock for user %q failed during passwordless authentication: %v", mfaData.User, err)
a.logger.DebugContext(ctx, "WithUserLock failed during passwordless authentication",
"user", mfaData.User,
"error", err,
)
return nil, mfaData.User, trace.Wrap(authenticateWebauthnError)
}

Expand All @@ -487,7 +492,7 @@ func (a *Server) authenticateHeadless(ctx context.Context, req authclient.Authen
defer func() {
if err != nil {
if err := a.DeleteHeadlessAuthentication(a.CloseContext(), req.Username, req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) {
log.Debugf("Failed to delete headless authentication: %v", err)
a.logger.DebugContext(ctx, "Failed to delete headless authentication", "error", err)
}
}
}()
Expand Down Expand Up @@ -773,7 +778,7 @@ func (a *Server) emitNoLocalAuthEvent(username string) {
Error: noLocalAuth,
},
}); err != nil {
log.WithError(err).Warn("Failed to emit no local auth event.")
a.logger.WarnContext(a.closeCtx, "Failed to emit no local auth event", "error", err)
}
}

Expand All @@ -794,7 +799,7 @@ func getErrorByTraceField(err error) error {
ok := errors.As(err, &traceErr)
switch {
case !ok:
log.WithError(err).Warn("Unexpected error type, wanted TraceError")
logger.WarnContext(context.Background(), "Unexpected error type, wanted TraceError", "error", err)
return trace.AccessDenied("an error has occurred")
case traceErr.GetFields()[ErrFieldKeyUserMaxedAttempts] != nil:
return trace.AccessDenied(MaxFailedAttemptsErrMsg)
Expand Down
Loading

0 comments on commit ec2718b

Please sign in to comment.