Skip to content

Commit

Permalink
Preserve context in logger when calling FromContext. (#11)
Browse files Browse the repository at this point in the history
Modification of #5 

Allows for passing through the logger while still using default Info
methods. InfoContext remains to override the context with another value.

This changes the stored context value to store the underlying
slog.Logger instead of the wrapped logger so we're not storing a context
within a context.
  • Loading branch information
wlynch authored Feb 6, 2024
1 parent 4d1a940 commit f4145d2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 12 deletions.
14 changes: 14 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ func ExampleLogger() {
// level=INFO msg="hello world" a=b foo=bar
// level=ERROR msg=asdf a=b baz=true
}

func ExampleFromContext_preserveContext() {
log := clog.NewLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
// Remove time for repeatable results
ReplaceAttr: slogtest.RemoveTime,
}))).With("foo", "bar")
ctx := clog.WithLogger(context.Background(), log)

// Previous context values are preserved when using FromContext
clog.FromContext(ctx).Info("hello world")

// Output:
// level=INFO msg="hello world" foo=bar
}
48 changes: 36 additions & 12 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

// Logger implements a wrapper around [slog.Logger] that adds formatter functions (e.g. Infof, Errorf)
type Logger struct {
ctx context.Context
slog.Logger
}

Expand All @@ -19,19 +20,32 @@ func DefaultLogger() *Logger {
return NewLogger(slog.Default())
}

// NewLogger returns a new logger that wraps the given [slog.Logger].
// NewLogger returns a new logger that wraps the given [slog.Logger] with the default context.
func NewLogger(l *slog.Logger) *Logger {
return NewLoggerWithContext(context.Background(), l)
}

// NewLoggerWithContext returns a new logger that wraps the given [slog.Logger].
func NewLoggerWithContext(ctx context.Context, l *slog.Logger) *Logger {
if l == nil {
l = slog.Default()
}
return &Logger{Logger: *l}
return &Logger{
ctx: ctx,
Logger: *l,
}
}

// New returns a new logger that wraps the given [slog.Handler].
func New(h slog.Handler) *Logger {
return NewLogger(slog.New(h))
}

// New returns a new logger that wraps the given [slog.Handler].
func NewWithContext(ctx context.Context, h slog.Handler) *Logger {
return NewLoggerWithContext(ctx, slog.New(h))
}

// With calls [Logger.With] on the default logger.
func With(args ...any) *Logger {
return DefaultLogger().With(args...)
Expand All @@ -47,9 +61,16 @@ func (l *Logger) WithGroup(name string) *Logger {
return NewLogger(l.Logger.WithGroup(name))
}

func (l *Logger) context() context.Context {
if l.ctx == nil {
return context.Background()
}
return l.ctx
}

// Infof logs at LevelInfo with the given format and arguments.
func (l *Logger) Infof(format string, args ...any) {
wrapf(context.Background(), l, slog.LevelInfo, format, args...)
wrapf(l.context(), l, slog.LevelInfo, format, args...)
}

// InfoContextf logs at LevelInfo with the given context, format and arguments.
Expand All @@ -59,7 +80,7 @@ func (l *Logger) InfoContextf(ctx context.Context, format string, args ...any) {

// Warnf logs at LevelWarn with the given format and arguments.
func (l *Logger) Warnf(format string, args ...any) {
wrapf(context.Background(), l, slog.LevelWarn, format, args...)
wrapf(l.context(), l, slog.LevelWarn, format, args...)
}

// WarnContextf logs at LevelWarn with the given context, format and arguments.
Expand All @@ -69,7 +90,7 @@ func (l *Logger) WarnContextf(ctx context.Context, format string, args ...any) {

// Errorf logs at LevelError with the given format and arguments.
func (l *Logger) Errorf(format string, args ...any) {
wrapf(context.Background(), l, slog.LevelError, format, args...)
wrapf(l.context(), l, slog.LevelError, format, args...)
}

// ErrorContextf logs at LevelError with the given context, format and arguments.
Expand All @@ -79,7 +100,7 @@ func (l *Logger) ErrorContextf(ctx context.Context, format string, args ...any)

// Debugf logs at LevelDebug with the given format and arguments.
func (l *Logger) Debugf(format string, args ...any) {
wrapf(context.Background(), l, slog.LevelDebug, format, args...)
wrapf(l.context(), l, slog.LevelDebug, format, args...)
}

// DebugContextf logs at LevelDebug with the given context, format and arguments.
Expand All @@ -89,13 +110,13 @@ func (l *Logger) DebugContextf(ctx context.Context, format string, args ...any)

// Fatalf logs at LevelError with the given format and arguments, then exits.
func (l *Logger) Fatalf(format string, args ...any) {
wrapf(context.Background(), l, slog.LevelError, format, args...)
wrapf(l.context(), l, slog.LevelError, format, args...)
os.Exit(1)
}

// Fatal logs at LevelError with the given message, then exits.
func (l *Logger) Fatal(msg string, args ...any) {
wrap(context.Background(), l, slog.LevelError, msg, args...)
wrap(l.context(), l, slog.LevelError, msg, args...)
os.Exit(1)
}

Expand Down Expand Up @@ -149,12 +170,15 @@ func wrapf(ctx context.Context, logger *Logger, level slog.Level, format string,
type loggerKey struct{}

func WithLogger(ctx context.Context, logger *Logger) context.Context {
return context.WithValue(ctx, loggerKey{}, logger)
return context.WithValue(ctx, loggerKey{}, logger.Logger)
}

func FromContext(ctx context.Context) *Logger {
if logger, ok := ctx.Value(loggerKey{}).(*Logger); ok {
return logger
if logger, ok := ctx.Value(loggerKey{}).(slog.Logger); ok {
return &Logger{
ctx: ctx,
Logger: logger,
}
}
return DefaultLogger()
return NewLoggerWithContext(ctx, slog.Default())
}

0 comments on commit f4145d2

Please sign in to comment.