diff --git a/example_test.go b/example_test.go index 393b7c5..8755a73 100644 --- a/example_test.go +++ b/example_test.go @@ -42,3 +42,17 @@ func ExampleLogger() { // level=INFO msg="hello world" a=b foo=bar // level=ERROR msg=asdf a=b baz=true } + +func ExampleFromContext_preserveContext() { + log := clog.NewLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + // Remove time for repeatable results + ReplaceAttr: slogtest.RemoveTime, + }))).With("foo", "bar") + ctx := clog.WithLogger(context.Background(), log) + + // Previous context values are preserved when using FromContext + clog.FromContext(ctx).Info("hello world") + + // Output: + // level=INFO msg="hello world" foo=bar +} diff --git a/logger.go b/logger.go index f64e2c2..853a4ed 100644 --- a/logger.go +++ b/logger.go @@ -11,6 +11,7 @@ import ( // Logger implements a wrapper around [slog.Logger] that adds formatter functions (e.g. Infof, Errorf) type Logger struct { + ctx context.Context slog.Logger } @@ -19,12 +20,20 @@ func DefaultLogger() *Logger { return NewLogger(slog.Default()) } -// NewLogger returns a new logger that wraps the given [slog.Logger]. +// NewLogger returns a new logger that wraps the given [slog.Logger] with the default context. func NewLogger(l *slog.Logger) *Logger { + return NewLoggerWithContext(context.Background(), l) +} + +// NewLoggerWithContext returns a new logger that wraps the given [slog.Logger]. +func NewLoggerWithContext(ctx context.Context, l *slog.Logger) *Logger { if l == nil { l = slog.Default() } - return &Logger{Logger: *l} + return &Logger{ + ctx: ctx, + Logger: *l, + } } // New returns a new logger that wraps the given [slog.Handler]. @@ -32,6 +41,11 @@ func New(h slog.Handler) *Logger { return NewLogger(slog.New(h)) } +// New returns a new logger that wraps the given [slog.Handler]. +func NewWithContext(ctx context.Context, h slog.Handler) *Logger { + return NewLoggerWithContext(ctx, slog.New(h)) +} + // With calls [Logger.With] on the default logger. func With(args ...any) *Logger { return DefaultLogger().With(args...) @@ -47,9 +61,16 @@ func (l *Logger) WithGroup(name string) *Logger { return NewLogger(l.Logger.WithGroup(name)) } +func (l *Logger) context() context.Context { + if l.ctx == nil { + return context.Background() + } + return l.ctx +} + // Infof logs at LevelInfo with the given format and arguments. func (l *Logger) Infof(format string, args ...any) { - wrapf(context.Background(), l, slog.LevelInfo, format, args...) + wrapf(l.context(), l, slog.LevelInfo, format, args...) } // InfoContextf logs at LevelInfo with the given context, format and arguments. @@ -59,7 +80,7 @@ func (l *Logger) InfoContextf(ctx context.Context, format string, args ...any) { // Warnf logs at LevelWarn with the given format and arguments. func (l *Logger) Warnf(format string, args ...any) { - wrapf(context.Background(), l, slog.LevelWarn, format, args...) + wrapf(l.context(), l, slog.LevelWarn, format, args...) } // WarnContextf logs at LevelWarn with the given context, format and arguments. @@ -69,7 +90,7 @@ func (l *Logger) WarnContextf(ctx context.Context, format string, args ...any) { // Errorf logs at LevelError with the given format and arguments. func (l *Logger) Errorf(format string, args ...any) { - wrapf(context.Background(), l, slog.LevelError, format, args...) + wrapf(l.context(), l, slog.LevelError, format, args...) } // ErrorContextf logs at LevelError with the given context, format and arguments. @@ -79,7 +100,7 @@ func (l *Logger) ErrorContextf(ctx context.Context, format string, args ...any) // Debugf logs at LevelDebug with the given format and arguments. func (l *Logger) Debugf(format string, args ...any) { - wrapf(context.Background(), l, slog.LevelDebug, format, args...) + wrapf(l.context(), l, slog.LevelDebug, format, args...) } // DebugContextf logs at LevelDebug with the given context, format and arguments. @@ -89,13 +110,13 @@ func (l *Logger) DebugContextf(ctx context.Context, format string, args ...any) // Fatalf logs at LevelError with the given format and arguments, then exits. func (l *Logger) Fatalf(format string, args ...any) { - wrapf(context.Background(), l, slog.LevelError, format, args...) + wrapf(l.context(), l, slog.LevelError, format, args...) os.Exit(1) } // Fatal logs at LevelError with the given message, then exits. func (l *Logger) Fatal(msg string, args ...any) { - wrap(context.Background(), l, slog.LevelError, msg, args...) + wrap(l.context(), l, slog.LevelError, msg, args...) os.Exit(1) } @@ -149,12 +170,15 @@ func wrapf(ctx context.Context, logger *Logger, level slog.Level, format string, type loggerKey struct{} func WithLogger(ctx context.Context, logger *Logger) context.Context { - return context.WithValue(ctx, loggerKey{}, logger) + return context.WithValue(ctx, loggerKey{}, logger.Logger) } func FromContext(ctx context.Context) *Logger { - if logger, ok := ctx.Value(loggerKey{}).(*Logger); ok { - return logger + if logger, ok := ctx.Value(loggerKey{}).(slog.Logger); ok { + return &Logger{ + ctx: ctx, + Logger: logger, + } } - return DefaultLogger() + return NewLoggerWithContext(ctx, slog.Default()) }