diff --git a/README.md b/README.md index a3f4a8a..fff31f8 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,13 @@ proxy: tunnel_listen: ":8080" # Optional CONNECT/SOCKS5 listener max_request_body_bytes: 1048576 # 1 MiB (default) max_response_body_bytes: 0 # uncapped (default) + auth: + required: false # Default: false + # users: + # - login: "ci" + # password: + # type: env + # var: "IRON_PROXY_CI_PASSWORD" tls: ca_cert: "/etc/iron-proxy/ca.crt" # Required @@ -251,6 +258,10 @@ transforms: - "*.anthropic.com" cidrs: - "10.0.0.0/8" + rules: + - host: "api.openai.com" + proxy_logins: ["ci"] + source_cidrs: ["10.16.0.0/16"] - name: secrets config: @@ -269,6 +280,33 @@ log: level: "info" # debug, info, warn, error ``` +### Proxy auth + +`proxy.auth.required` enables client authentication for HTTP proxy requests, +HTTP `CONNECT`, and SOCKS5. Default is `false`. + +HTTP and `CONNECT` clients use `Proxy-Authorization: Basic ...`. SOCKS5 clients +use username/password auth. In `tls.mode: sni-only`, raw HTTPS connections do +not carry auth metadata; when auth is required, use the tunnel listener. +Each user password is a secret source (`env`, `file`, AWS, 1Password, etc.) +resolved on startup and management reload. + +```yaml +proxy: + tunnel_listen: ":8080" + auth: + required: true + users: + - login: "ci" + password: + type: env + var: "IRON_PROXY_CI_PASSWORD" + - login: "dev" + password: + type: file + path: "/run/secrets/iron_proxy_dev_password" +``` + ### DNS Everything resolves to `proxy_ip` by default, which is what routes traffic @@ -286,6 +324,19 @@ Unmatched requests get a `403 Forbidden`. Domain patterns use glob matching: `*.example.com` matches any subdomain and `example.com` itself. +Rules can also be scoped by authenticated proxy login and client source CIDR: + +```yaml +transforms: + - name: allowlist + config: + rules: + - host: "api.openai.com" + methods: ["POST"] + proxy_logins: ["ci", "dev"] + source_cidrs: ["10.16.0.0/16"] +``` + **Warn mode:** Set `warn: true` to observe what the allowlist would block without actually enforcing it. Requests that would be rejected are allowed through but annotated with `"action": "warn"` in the transform trace. This is useful for diff --git a/cmd/iron-proxy/main.go b/cmd/iron-proxy/main.go index 3645e78..32c7841 100644 --- a/cmd/iron-proxy/main.go +++ b/cmd/iron-proxy/main.go @@ -204,7 +204,7 @@ func main() { } // Initialize proxy. - p := proxy.New(proxy.Options{ + p, err := proxy.New(proxy.Options{ HTTPAddr: cfg.Proxy.HTTPListen, HTTPSAddr: cfg.Proxy.HTTPSListen, TunnelAddr: cfg.Proxy.TunnelListen, @@ -214,10 +214,15 @@ func main() { Resolver: resolver, Guard: guard, MCPPolicy: mcpHolder, + Auth: cfg.Proxy.Auth, Logger: logger, UpstreamResponseHeaderTimeout: time.Duration(cfg.Proxy.UpstreamResponseHeaderTimeout), UpstreamProxy: cfg.Proxy.UpstreamProxy.ProxyFunc(), }) + if err != nil { + logger.Error("initializing proxy", slog.String("error", err.Error())) + os.Exit(1) + } // Initialize metrics server. metricsServer := metrics.New(cfg.Metrics.Listen, logger) @@ -228,7 +233,7 @@ func main() { mgmtServer = management.New(management.Options{ Addr: cfg.Management.Listen, APIKey: os.Getenv(cfg.Management.APIKeyEnv), - Reload: newReloadFunc(*configPath, holder, mcpHolder, pgManager, bodyLimits, logger), + Reload: newReloadFunc(*configPath, holder, mcpHolder, pgManager, p, bodyLimits, logger), Logger: logger, Ctx: ctx, }) @@ -587,7 +592,7 @@ func applyPostgresSync(ctx context.Context, mgr *postgres.Manager, local *postgr // wrapped in *management.ValidationError so the management server returns // 422 and the existing state is left untouched. Validation runs for every // component before any state is mutated. -func newReloadFunc(configPath string, holder *transform.PipelineHolder, mcpHolder *mcp.PolicyHolder, pgManager *postgres.Manager, bodyLimits transform.BodyLimits, logger *slog.Logger) management.ReloadFunc { +func newReloadFunc(configPath string, holder *transform.PipelineHolder, mcpHolder *mcp.PolicyHolder, pgManager *postgres.Manager, p *proxy.Proxy, bodyLimits transform.BodyLimits, logger *slog.Logger) management.ReloadFunc { return func(ctx context.Context) error { newCfg, err := config.LoadConfig(configPath) if err != nil { @@ -608,6 +613,9 @@ func newReloadFunc(configPath string, holder *transform.PipelineHolder, mcpHolde if err != nil { return &management.ValidationError{Err: err} } + if err := p.ReloadAuth(ctx, newCfg.Proxy.Auth); err != nil { + return &management.ValidationError{Err: err} + } newPipeline.SetAuditFunc(holder.Load().AuditFunc()) holder.Store(newPipeline) mcpHolder.Store(newPolicy) diff --git a/internal/config/config.go b/internal/config/config.go index f6103e8..f71493c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -75,11 +75,12 @@ type DNSRecord struct { // Proxy configures the HTTP/HTTPS listener addresses. type Proxy struct { - HTTPListen string `yaml:"http_listen"` - HTTPSListen string `yaml:"https_listen"` - TunnelListen string `yaml:"tunnel_listen"` - MaxRequestBodyBytes int64 `yaml:"max_request_body_bytes"` - MaxResponseBodyBytes int64 `yaml:"max_response_body_bytes"` + HTTPListen string `yaml:"http_listen"` + HTTPSListen string `yaml:"https_listen"` + TunnelListen string `yaml:"tunnel_listen"` + MaxRequestBodyBytes int64 `yaml:"max_request_body_bytes"` + MaxResponseBodyBytes int64 `yaml:"max_response_body_bytes"` + Auth ProxyAuth `yaml:"auth"` // UpstreamResponseHeaderTimeout caps how long the proxy waits for an // upstream response's headers before returning 502. Accepts Go duration // syntax: "30s" (default), "5m", "2h". Useful for upstream endpoints @@ -96,6 +97,20 @@ type Proxy struct { UpstreamProxy UpstreamProxy `yaml:"upstream_proxy"` } +// ProxyAuth configures optional client authentication for HTTP proxy, +// CONNECT, and SOCKS5 proxy modes. Empty config preserves the historical +// no-auth behavior. +type ProxyAuth struct { + Required bool `yaml:"required"` + Users []ProxyAuthUser `yaml:"users"` +} + +// ProxyAuthUser is one proxy login. Password is a secrets source. +type ProxyAuthUser struct { + Login string `yaml:"login"` + Password yaml.Node `yaml:"password"` +} + // CIDRList is a list of CIDR strings whose presence in YAML is distinguishable // from absence: an explicit empty list opts out of any default population, // while an unset field signals "apply the default". @@ -280,6 +295,9 @@ func Validate(cfg *Config) error { if err := dnsguard.ValidateCIDRs(cfg.Proxy.UpstreamDenyCIDRs.Values); err != nil { return fmt.Errorf("proxy.upstream_deny_cidrs: %w", err) } + if err := validateProxyAuth(cfg.Proxy.Auth); err != nil { + return err + } if cfg.Management.Listen != "" { if cfg.Management.APIKeyEnv == "" { @@ -305,3 +323,26 @@ func Validate(cfg *Config) error { return nil } + +func validateProxyAuth(auth ProxyAuth) error { + seen := make(map[string]struct{}, len(auth.Users)) + for i, user := range auth.Users { + if user.Login == "" { + return fmt.Errorf("proxy.auth.users[%d].login is required", i) + } + if _, ok := seen[user.Login]; ok { + return fmt.Errorf("proxy.auth.users[%d].login %q is duplicated", i, user.Login) + } + seen[user.Login] = struct{}{} + if user.Password.Kind == 0 { + return fmt.Errorf("proxy.auth.users[%d].password is required", i) + } + if user.Password.Kind != yaml.MappingNode { + return fmt.Errorf("proxy.auth.users[%d].password must be a secret source mapping", i) + } + } + if auth.Required && len(auth.Users) == 0 { + return fmt.Errorf("proxy.auth.users is required when proxy.auth.required is true") + } + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ff4bf93..a84a185 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -58,6 +58,8 @@ func TestLoad_Defaults(t *testing.T) { require.Equal(t, 1000, cfg.TLS.CertCacheSize) require.Equal(t, ":9090", cfg.Metrics.Listen) require.Equal(t, "info", cfg.Log.Level) + require.False(t, cfg.Proxy.Auth.Required) + require.Empty(t, cfg.Proxy.Auth.Users) } func TestLoad_OverrideDefaults(t *testing.T) { @@ -179,6 +181,38 @@ tls: `, wantErr: "dns.records[0].value is required", }, + { + name: "proxy auth required without users", + yaml: ` +dns: + proxy_ip: "10.0.0.1" +proxy: + auth: + required: true +tls: + ca_cert: "/tmp/ca.crt" + ca_key: "/tmp/ca.key" +`, + wantErr: "proxy.auth.users is required", + }, + { + name: "proxy auth duplicate login", + yaml: ` +dns: + proxy_ip: "10.0.0.1" +proxy: + auth: + users: + - login: ci + password: {type: env, var: ONE} + - login: ci + password: {type: env, var: TWO} +tls: + ca_cert: "/tmp/ca.crt" + ca_key: "/tmp/ca.key" +`, + wantErr: "duplicated", + }, } for _, tt := range tests { @@ -190,6 +224,28 @@ tls: } } +func TestLoad_ProxyAuthPasswordSource(t *testing.T) { + cfg, err := Load(strings.NewReader(` +dns: + proxy_ip: "10.0.0.1" +proxy: + auth: + required: true + users: + - login: ci + password: + type: env + var: IRON_PROXY_CI_PASSWORD +tls: + ca_cert: "/tmp/ca.crt" + ca_key: "/tmp/ca.key" +`)) + require.NoError(t, err) + require.True(t, cfg.Proxy.Auth.Required) + require.Equal(t, "ci", cfg.Proxy.Auth.Users[0].Login) + require.NotZero(t, cfg.Proxy.Auth.Users[0].Password.Kind) +} + func TestLoad_UnknownFields(t *testing.T) { yaml := ` dns: diff --git a/internal/hostmatch/rule.go b/internal/hostmatch/rule.go index a75ada2..95431f5 100644 --- a/internal/hostmatch/rule.go +++ b/internal/hostmatch/rule.go @@ -2,27 +2,43 @@ package hostmatch import ( "fmt" + "net" "net/http" "strings" ) // RuleConfig is the YAML-decoded form of a host/method/path matching rule. type RuleConfig struct { - Host string `yaml:"host,omitempty"` - CIDR string `yaml:"cidr,omitempty"` - Methods []string `yaml:"methods,omitempty"` - Paths []string `yaml:"paths,omitempty"` + Host string `yaml:"host,omitempty"` + CIDR string `yaml:"cidr,omitempty"` + Methods []string `yaml:"methods,omitempty"` + Paths []string `yaml:"paths,omitempty"` + ProxyLogins []string `yaml:"proxy_logins,omitempty"` + SourceCIDRs []string `yaml:"source_cidrs,omitempty"` } // Rule is a compiled matching rule ready for use. type Rule struct { - Matcher *Matcher - Methods map[string]bool // nil = all methods - Paths []string // nil = all paths + Matcher *Matcher + Methods map[string]bool // nil = all methods + Paths []string // nil = all paths + ProxyLogins map[string]bool // nil = all proxy logins + SourceCIDRs []*net.IPNet // nil = all client source IPs +} + +// MatchContext carries connection metadata used by optional rule filters. +type MatchContext struct { + ProxyLogin string + SourceIP string } // Matches returns true if the request matches this rule. func (r *Rule) Matches(host, method, path string) bool { + return r.MatchesContext(host, method, path, MatchContext{}) +} + +// MatchesContext returns true if the request and connection metadata match. +func (r *Rule) MatchesContext(host, method, path string, ctx MatchContext) bool { if !r.Matcher.Matches(host) { return false } @@ -32,6 +48,12 @@ func (r *Rule) Matches(host, method, path string) bool { if r.Paths != nil && !MatchAnyPath(r.Paths, path) { return false } + if r.ProxyLogins != nil && !r.ProxyLogins[ctx.ProxyLogin] { + return false + } + if len(r.SourceCIDRs) > 0 && !matchSourceCIDR(r.SourceCIDRs, ctx.SourceIP) { + return false + } return true } @@ -76,6 +98,25 @@ func CompileRules(configs []RuleConfig, prefix string) ([]Rule, error) { if len(rc.Paths) > 0 { r.Paths = rc.Paths } + if len(rc.ProxyLogins) > 0 { + r.ProxyLogins = make(map[string]bool, len(rc.ProxyLogins)) + for _, login := range rc.ProxyLogins { + if login == "" { + return nil, fmt.Errorf("%s: rules[%d]: proxy_logins contains empty login", prefix, i) + } + r.ProxyLogins[login] = true + } + } + if len(rc.SourceCIDRs) > 0 { + r.SourceCIDRs = make([]*net.IPNet, 0, len(rc.SourceCIDRs)) + for _, cidr := range rc.SourceCIDRs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("%s: rules[%d]: parsing source_cidrs %q: %w", prefix, i, cidr, err) + } + r.SourceCIDRs = append(r.SourceCIDRs, ipNet) + } + } rules = append(rules, r) } @@ -88,9 +129,28 @@ func isWildcard(methods []string) bool { // MatchAnyRule returns true if the request matches any rule in the list. func MatchAnyRule(rules []Rule, req *http.Request) bool { + return MatchAnyRuleContext(rules, req, MatchContext{}) +} + +// MatchAnyRuleContext returns true if the request and connection metadata +// match any rule in the list. +func MatchAnyRuleContext(rules []Rule, req *http.Request, ctx MatchContext) bool { host := StripPort(req.Host) for _, r := range rules { - if r.Matches(host, req.Method, req.URL.Path) { + if r.MatchesContext(host, req.Method, req.URL.Path, ctx) { + return true + } + } + return false +} + +func matchSourceCIDR(cidrs []*net.IPNet, sourceIP string) bool { + ip := net.ParseIP(sourceIP) + if ip == nil { + return false + } + for _, cidr := range cidrs { + if cidr.Contains(ip) { return true } } diff --git a/internal/hostmatch/rule_test.go b/internal/hostmatch/rule_test.go index 95507d7..3825f2c 100644 --- a/internal/hostmatch/rule_test.go +++ b/internal/hostmatch/rule_test.go @@ -43,3 +43,27 @@ func TestCompileRules_NoMethodsMatchesAll(t *testing.T) { require.True(t, rules[0].Matches("example.com", "GET", "/")) require.True(t, rules[0].Matches("example.com", "DELETE", "/")) } + +func TestCompileRules_ProxyLoginFilter(t *testing.T) { + rules, err := CompileRules([]RuleConfig{ + {Host: "example.com", ProxyLogins: []string{"ci", "dev"}}, + }, "test") + require.NoError(t, err) + + require.True(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{ProxyLogin: "ci"})) + require.True(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{ProxyLogin: "dev"})) + require.False(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{ProxyLogin: "prod"})) + require.False(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{})) +} + +func TestCompileRules_SourceCIDRFilter(t *testing.T) { + rules, err := CompileRules([]RuleConfig{ + {Host: "example.com", SourceCIDRs: []string{"10.0.0.0/8", "192.168.1.0/24"}}, + }, "test") + require.NoError(t, err) + + require.True(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{SourceIP: "10.1.2.3"})) + require.True(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{SourceIP: "192.168.1.10"})) + require.False(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{SourceIP: "172.16.0.1"})) + require.False(t, rules[0].MatchesContext("example.com", "GET", "/", MatchContext{})) +} diff --git a/internal/proxy/auth.go b/internal/proxy/auth.go new file mode 100644 index 0000000..f57b66a --- /dev/null +++ b/internal/proxy/auth.go @@ -0,0 +1,144 @@ +package proxy + +import ( + "context" + "crypto/subtle" + "encoding/base64" + "fmt" + "log/slog" + "net" + "net/http" + "strings" + "sync/atomic" + + "github.com/ironsh/iron-proxy/internal/config" + "github.com/ironsh/iron-proxy/internal/transform/secrets" +) + +const proxyAuthRealm = `Basic realm="iron-proxy"` + +type authenticator struct { + required bool + passwords map[string]string +} + +type authenticatorHolder struct { + value atomic.Value // *authenticator +} + +type proxyAuth struct { + Login string +} + +func newAuthenticatorHolder(ctx context.Context, cfg config.ProxyAuth, logger *slog.Logger) (*authenticatorHolder, error) { + h := &authenticatorHolder{} + if err := h.Store(ctx, cfg, logger); err != nil { + return nil, err + } + return h, nil +} + +func (h *authenticatorHolder) Store(ctx context.Context, cfg config.ProxyAuth, logger *slog.Logger) error { + a, err := newAuthenticator(ctx, cfg, logger) + if err != nil { + return err + } + h.value.Store(a) + return nil +} + +func (h *authenticatorHolder) Load() *authenticator { + v := h.value.Load() + if v == nil { + return emptyAuthenticator() + } + return v.(*authenticator) +} + +func emptyAuthenticator() *authenticator { + return &authenticator{passwords: map[string]string{}} +} + +func newAuthenticator(ctx context.Context, cfg config.ProxyAuth, logger *slog.Logger) (*authenticator, error) { + passwords := make(map[string]string, len(cfg.Users)) + for i, user := range cfg.Users { + source, err := secrets.BuildSource(user.Password, logger) + if err != nil { + return nil, fmt.Errorf("proxy.auth.users[%d].password: %w", i, err) + } + password, err := source.Get(ctx) + if err != nil { + return nil, fmt.Errorf("proxy.auth.users[%d].password from %q: %w", i, source.Name(), err) + } + passwords[user.Login] = password + } + return &authenticator{ + required: cfg.Required, + passwords: passwords, + }, nil +} + +func (a *authenticator) enabled() bool { + return a.required || len(a.passwords) > 0 +} + +func (a *authenticator) authenticateHeader(header string) (proxyAuth, bool) { + if header == "" { + return proxyAuth{}, !a.required + } + login, password, ok := parseBasicProxyAuth(header) + if !ok { + return proxyAuth{}, false + } + if !a.authenticateLoginPassword(login, password) { + return proxyAuth{}, false + } + return proxyAuth{Login: login}, true +} + +func (a *authenticator) authenticateLoginPassword(login, password string) bool { + want, ok := a.passwords[login] + if !ok { + _ = subtle.ConstantTimeCompare([]byte(password), []byte("")) + return false + } + return subtle.ConstantTimeCompare([]byte(password), []byte(want)) == 1 +} + +func parseBasicProxyAuth(header string) (string, string, bool) { + scheme, encoded, ok := strings.Cut(strings.TrimSpace(header), " ") + if !ok || !strings.EqualFold(scheme, "Basic") { + return "", "", false + } + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", "", false + } + login, password, ok := strings.Cut(string(decoded), ":") + if !ok || login == "" { + return "", "", false + } + return login, password, true +} + +func proxyAuthRequiredResponse(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: http.StatusProxyAuthRequired, + Status: "407 Proxy Authentication Required", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{"Proxy-Authenticate": {proxyAuthRealm}}, + Body: http.NoBody, + Request: req, + ContentLength: 0, + } +} + +func sourceIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + return host + } + return remoteAddr +} diff --git a/internal/proxy/dnsguard_test.go b/internal/proxy/dnsguard_test.go index 095bfe2..bdecf2f 100644 --- a/internal/proxy/dnsguard_test.go +++ b/internal/proxy/dnsguard_test.go @@ -54,12 +54,13 @@ func TestUpstreamDenyGuard_HTTP(t *testing.T) { guard, err := dnsguard.New(denyCIDRs) require.NoError(t, err) - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", Pipeline: transform.NewPipelineHolder(pipeline), Guard: guard, Logger: logger, }) + require.NoError(t, err) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -147,13 +148,14 @@ func TestUpstreamDenyGuard_SNIPassthrough(t *testing.T) { guard, err := dnsguard.New([]string{"127.0.0.0/8"}) require.NoError(t, err) - p := New(Options{ + p, err := New(Options{ HTTPSAddr: "127.0.0.1:0", TLSMode: "sni-only", Pipeline: transform.NewPipelineHolder(pipeline), Guard: guard, Logger: logger, }) + require.NoError(t, err) p.sniUpstreamPort = upstreamPort ln, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/internal/proxy/integration_test.go b/internal/proxy/integration_test.go index 956521c..b12585f 100644 --- a/internal/proxy/integration_test.go +++ b/internal/proxy/integration_test.go @@ -54,7 +54,7 @@ func startTunnelIntegrationProxy(t *testing.T, allowedHosts []string, logger *sl pipeline := transform.NewPipeline([]transform.Transformer{al}, transform.BodyLimits{}, logger) holder := transform.NewPipelineHolder(pipeline) - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", HTTPSAddr: "127.0.0.1:0", TunnelAddr: "127.0.0.1:0", @@ -62,6 +62,7 @@ func startTunnelIntegrationProxy(t *testing.T, allowedHosts []string, logger *sl Pipeline: holder, Logger: logger, }) + require.NoError(t, err) tunnelLn, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -118,13 +119,14 @@ func TestIntegration_DNSToProxyToUpstream(t *testing.T) { holder := transform.NewPipelineHolder(pipeline) // 4. Start proxy with HTTPS - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", HTTPSAddr: "127.0.0.1:0", CertCache: ca.certCache, Pipeline: holder, Logger: logger, }) + require.NoError(t, err) // Start HTTP listener httpLn, err := net.Listen("tcp", "127.0.0.1:0") @@ -493,4 +495,3 @@ func TestIntegration_SOCKS5(t *testing.T) { require.ErrorIs(t, err, io.EOF, "expected proxy to close the connection") }) } - diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 1900d38..e6356b2 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -45,6 +45,7 @@ type Proxy struct { resolver *net.Resolver guard *dnsguard.Guard mcpPolicy *mcp.PolicyHolder + auth *authenticatorHolder logger *slog.Logger // shutdownCtx is canceled by Shutdown to unblock in-flight TCP-passthrough @@ -69,6 +70,7 @@ type Options struct { Resolver *net.Resolver Guard *dnsguard.Guard // nil is treated as an empty (no-op) guard MCPPolicy *mcp.PolicyHolder // optional MCP-aware policy interceptor; nil disables MCP handling + Auth config.ProxyAuth Logger *slog.Logger // UpstreamResponseHeaderTimeout overrides the upstream HTTP transport's // ResponseHeaderTimeout. Zero falls back to @@ -82,7 +84,7 @@ type Options struct { // New creates a new Proxy. In TLSModeMITM, certCache must be non-nil. In // TLSModeSNIOnly, certCache is unused and may be nil. -func New(opts Options) *Proxy { +func New(opts Options) (*Proxy, error) { if opts.TLSMode == "" { opts.TLSMode = config.TLSModeMITM } @@ -91,6 +93,11 @@ func New(opts Options) *Proxy { if guard == nil { guard, _ = dnsguard.New(nil) } + auth, err := newAuthenticatorHolder(context.Background(), opts.Auth, opts.Logger) + if err != nil { + shutdownCancel() + return nil, err + } p := &Proxy{ httpsAddr: opts.HTTPSAddr, tlsMode: opts.TLSMode, @@ -102,6 +109,7 @@ func New(opts Options) *Proxy { resolver: opts.Resolver, guard: guard, mcpPolicy: opts.MCPPolicy, + auth: auth, logger: opts.Logger, shutdownCtx: shutdownCtx, shutdownCancel: shutdownCancel, @@ -120,7 +128,14 @@ func New(opts Options) *Proxy { }, } - return p + return p, nil +} + +// ReloadAuth atomically swaps proxy client authentication config for new +// HTTP/CONNECT/SOCKS5 handshakes. Existing tunnels keep their authenticated +// login in TunnelInfo. +func (p *Proxy) ReloadAuth(ctx context.Context, auth config.ProxyAuth) error { + return p.auth.Store(ctx, auth, p.logger) } // ListenAndServe starts the HTTP, HTTPS, and (optionally) tunnel listeners. @@ -293,9 +308,16 @@ func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request, tunnelInfo *t // Clone tunnelInfo so a transform that mutates the annotations map can't // leak state into sibling requests that share the same tunnel. tctx := &transform.TransformContext{ - Logger: p.logger, - Mode: transform.ModeMITM, - Tunnel: cloneTunnelInfo(tunnelInfo), + Logger: p.logger, + Mode: transform.ModeMITM, + Tunnel: cloneTunnelInfo(tunnelInfo), + SourceIP: sourceIP(r.RemoteAddr), + } + if tctx.Tunnel != nil { + tctx.ProxyLogin = tctx.Tunnel.ProxyLogin + if tctx.Tunnel.SourceIP != "" { + tctx.SourceIP = tctx.Tunnel.SourceIP + } } if r.TLS != nil { tctx.SNI = r.TLS.ServerName @@ -306,6 +328,8 @@ func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request, tunnelInfo *t Method: r.Method, Path: r.URL.Path, RemoteAddr: r.RemoteAddr, + ProxyLogin: tctx.ProxyLogin, + SourceIP: tctx.SourceIP, SNI: tctx.SNI, Mode: transform.ModeMITM, Tunnel: tctx.Tunnel, @@ -313,6 +337,20 @@ func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request, tunnelInfo *t pl, finish := p.beginPipelineRun(result) defer finish() + authSnapshot := p.auth.Load() + if tunnelInfo == nil && authSnapshot.enabled() { + auth, ok := authSnapshot.authenticateHeader(r.Header.Get("Proxy-Authorization")) + if !ok { + result.Action = transform.ActionReject + result.StatusCode = http.StatusProxyAuthRequired + p.writeResponse(w, proxyAuthRequiredResponse(r)) + return + } + tctx.ProxyLogin = auth.Login + result.ProxyLogin = auth.Login + r.Header.Del("Proxy-Authorization") + } + bodyLimits := pl.BodyLimits() // Wrap request body for lazy buffering by transforms. r.Body = transform.NewBufferedBody(r.Body, bodyLimits.MaxRequestBodyBytes) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2ca3c48..adfe987 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "fmt" "io" "log/slog" @@ -17,14 +18,17 @@ import ( "net/http/httptest" "net/url" "os" + "path/filepath" "strings" "sync" "testing" "time" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" "github.com/ironsh/iron-proxy/internal/certcache" + "github.com/ironsh/iron-proxy/internal/config" "github.com/ironsh/iron-proxy/internal/transform" ) @@ -32,6 +36,24 @@ func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) } +func authUser(t *testing.T, login, password string) config.ProxyAuthUser { + t.Helper() + path := filepath.Join(t.TempDir(), login) + require.NoError(t, os.WriteFile(path, []byte(password), 0o600)) + return config.ProxyAuthUser{ + Login: login, + Password: yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "type"}, + {Kind: yaml.ScalarNode, Value: "file"}, + {Kind: yaml.ScalarNode, Value: "path"}, + {Kind: yaml.ScalarNode, Value: path}, + }, + }, + } +} + func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) { t.Helper() @@ -95,6 +117,11 @@ func (r *replacerTransform) TransformResponse(_ context.Context, _ *transform.Tr func startProxyWithTransforms(t *testing.T, transforms []transform.Transformer) (*Proxy, string, string, *x509.CertPool) { t.Helper() + return startProxyWithAuth(t, transforms, config.ProxyAuth{}) +} + +func startProxyWithAuth(t *testing.T, transforms []transform.Transformer, auth config.ProxyAuth) (*Proxy, string, string, *x509.CertPool) { + t.Helper() caCert, caKey := generateTestCA(t) cache, err := certcache.NewFromCA(caCert, caKey, 100, 72*time.Hour) @@ -102,13 +129,15 @@ func startProxyWithTransforms(t *testing.T, transforms []transform.Transformer) pipeline := transform.NewPipeline(transforms, transform.BodyLimits{}, testLogger()) holder := transform.NewPipelineHolder(pipeline) - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", HTTPSAddr: "127.0.0.1:0", CertCache: cache, Pipeline: holder, + Auth: auth, Logger: testLogger(), }) + require.NoError(t, err) httpLn, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -161,6 +190,86 @@ func TestHTTPProxy(t *testing.T) { require.Equal(t, "hello from upstream", string(body)) } +func TestHTTPProxy_AuthRequired(t *testing.T) { + _, httpAddr, _, _ := startProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/test", httpAddr), nil) + require.NoError(t, err) + req.Host = "example.com" + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusProxyAuthRequired, resp.StatusCode) + require.Equal(t, proxyAuthRealm, resp.Header.Get("Proxy-Authenticate")) +} + +func TestHTTPProxy_AuthValid(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Empty(t, r.Header.Get("Proxy-Authorization")) + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + _, httpAddr, _, _ := startProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/test", httpAddr), nil) + require.NoError(t, err) + req.Host = upstream.Listener.Addr().String() + req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("ci:secret"))) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestHTTPProxy_ReloadAuth(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + p, httpAddr, _, _ := startProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + do := func(login string) int { + t.Helper() + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/test", httpAddr), nil) + require.NoError(t, err) + req.Host = upstream.Listener.Addr().String() + req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(login+":secret"))) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + return resp.StatusCode + } + + require.Equal(t, http.StatusProxyAuthRequired, do("dev")) + + require.NoError(t, p.ReloadAuth(context.Background(), config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{ + authUser(t, "ci", "secret"), + authUser(t, "dev", "secret"), + }, + })) + + require.Equal(t, http.StatusOK, do("dev")) +} + func TestHTTPProxy_PostBody(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) @@ -292,11 +401,12 @@ func TestHTTPProxy_ClientCancel(t *testing.T) { close(done) }) - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", Pipeline: transform.NewPipelineHolder(pipeline), Logger: testLogger(), }) + require.NoError(t, err) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = p.httpServer.Serve(ln) }() diff --git a/internal/proxy/sni_passthrough.go b/internal/proxy/sni_passthrough.go index 84d3118..25d971b 100644 --- a/internal/proxy/sni_passthrough.go +++ b/internal/proxy/sni_passthrough.go @@ -25,7 +25,7 @@ const ( // the transform pipeline with a host-only synthetic request, and on accept // TCP-passthroughs the connection to the upstream server. func (p *Proxy) handleSNIPassthrough(clientConn net.Conn) { - if err := p.serveSNIPassthrough(clientConn); err != nil { + if err := p.serveSNIPassthrough(clientConn, nil); err != nil { p.logger.Debug("sni passthrough error", slog.String("error", err.Error())) } } @@ -35,7 +35,7 @@ func (p *Proxy) handleSNIPassthrough(clientConn net.Conn) { // listener and the CONNECT/SOCKS5 tunnel TLS branch in sni-only mode. The // upstream port is fixed at 443 — a client-supplied CONNECT port is ignored // so an attacker cannot pivot an allowlisted hostname onto a different port. -func (p *Proxy) serveSNIPassthrough(clientConn net.Conn) error { +func (p *Proxy) serveSNIPassthrough(clientConn net.Conn, tunnelInfo *transform.TunnelInfo) error { defer clientConn.Close() sni, peeked, err := peekSNI(clientConn, sniPeekTimeout) @@ -47,8 +47,16 @@ func (p *Proxy) serveSNIPassthrough(clientConn net.Conn) error { result := &transform.PipelineResult{ Host: sni, RemoteAddr: clientConn.RemoteAddr().String(), + SourceIP: sourceIP(clientConn.RemoteAddr().String()), SNI: sni, Mode: transform.ModeSNIOnly, + Tunnel: cloneTunnelInfo(tunnelInfo), + } + if result.Tunnel != nil { + result.ProxyLogin = result.Tunnel.ProxyLogin + if result.Tunnel.SourceIP != "" { + result.SourceIP = result.Tunnel.SourceIP + } } pl, finish := p.beginPipelineRun(result) defer finish() @@ -59,11 +67,21 @@ func (p *Proxy) serveSNIPassthrough(clientConn net.Conn) error { result.Err = fmt.Errorf("client hello missing sni") return result.Err } + authSnapshot := p.auth.Load() + if authSnapshot.required && result.ProxyLogin == "" { + result.Action = transform.ActionReject + result.StatusCode = http.StatusProxyAuthRequired + result.Err = fmt.Errorf("proxy auth requires an HTTP proxy or SOCKS5 auth handshake") + return result.Err + } tctx := &transform.TransformContext{ - Logger: p.logger, - SNI: sni, - Mode: transform.ModeSNIOnly, + Logger: p.logger, + SNI: sni, + Mode: transform.ModeSNIOnly, + Tunnel: cloneTunnelInfo(tunnelInfo), + ProxyLogin: result.ProxyLogin, + SourceIP: result.SourceIP, } req := &http.Request{ diff --git a/internal/proxy/sni_passthrough_test.go b/internal/proxy/sni_passthrough_test.go index d9cd678..8955888 100644 --- a/internal/proxy/sni_passthrough_test.go +++ b/internal/proxy/sni_passthrough_test.go @@ -98,7 +98,8 @@ func buildSNIProxy(t *testing.T, allowed []string, withTunnel bool) (*Proxy, fun if withTunnel { opts.TunnelAddr = "127.0.0.1:0" } - p := New(opts) + p, err := New(opts) + require.NoError(t, err) return p, func() []transform.PipelineResult { mu.Lock() @@ -163,7 +164,7 @@ func startSNIPassthroughProxy(t *testing.T, allowed []string, upstream string) ( p, getResults := buildSNIProxy(t, allowed, false) p.sniUpstreamPort = upstreamPort - addr := startAcceptLoop(t, func(c net.Conn) { _ = p.serveSNIPassthrough(c) }) + addr := startAcceptLoop(t, func(c net.Conn) { _ = p.serveSNIPassthrough(c, nil) }) return addr, getResults } @@ -289,7 +290,7 @@ func TestSNIPassthrough_ShutdownClosesInFlight(t *testing.T) { done := make(chan struct{}) addr := startAcceptLoop(t, func(c net.Conn) { - _ = p.serveSNIPassthrough(c) + _ = p.serveSNIPassthrough(c, nil) close(done) }) diff --git a/internal/proxy/tunnel.go b/internal/proxy/tunnel.go index 8485300..976a6f1 100644 --- a/internal/proxy/tunnel.go +++ b/internal/proxy/tunnel.go @@ -92,7 +92,23 @@ func (p *Proxy) handleCONNECT(conn net.Conn, br *bufio.Reader) error { p.logger.Debug("tunnel CONNECT", slog.String("target", host)) - ok, rejectResp, tunnelInfo := p.tunnelTransformCheck(conn.RemoteAddr().String(), host, req.Header) + authSnapshot := p.auth.Load() + var auth proxyAuth + if authSnapshot.enabled() { + var ok bool + auth, ok = authSnapshot.authenticateHeader(req.Header.Get("Proxy-Authorization")) + if !ok { + resp := proxyAuthRequiredResponse(req) + p.emitTunnelAuthReject(conn.RemoteAddr().String(), host, resp.StatusCode) + if err := resp.Write(conn); err != nil { + return fmt.Errorf("write auth rejection: %w", err) + } + return nil + } + req.Header.Del("Proxy-Authorization") + } + + ok, rejectResp, tunnelInfo := p.tunnelTransformCheck(conn.RemoteAddr().String(), host, req.Header, auth) if !ok { if rejectResp == nil { rejectResp = &http.Response{ @@ -140,23 +156,52 @@ func (p *Proxy) handleSOCKS5(conn net.Conn, br *bufio.Reader) error { return fmt.Errorf("read methods: %w", err) } - // We only support no-auth (0x00) + authSnapshot := p.auth.Load() + var auth proxyAuth hasNoAuth := false + hasUserPass := false for _, m := range methods { if m == 0x00 { hasNoAuth = true - break } - } - if !hasNoAuth { - if err := p.socks5Reply(conn, 0xFF); err != nil { - return fmt.Errorf("write no-acceptable-methods: %w", err) + if m == 0x02 { + hasUserPass = true } - return nil } - // Reply: use no-auth - if _, err := conn.Write([]byte{0x05, 0x00}); err != nil { - return fmt.Errorf("write auth reply: %w", err) + if authSnapshot.enabled() { + switch { + case hasUserPass: + if _, err := conn.Write([]byte{0x05, 0x02}); err != nil { + return fmt.Errorf("write username-password method: %w", err) + } + var ok bool + auth, ok, err = p.readSOCKS5UserPass(conn, br, authSnapshot) + if err != nil { + return err + } + if !ok { + return nil + } + case !authSnapshot.required && hasNoAuth: + if _, err := conn.Write([]byte{0x05, 0x00}); err != nil { + return fmt.Errorf("write auth reply: %w", err) + } + default: + if _, err := conn.Write([]byte{0x05, 0xFF}); err != nil { + return fmt.Errorf("write no-acceptable-methods: %w", err) + } + return nil + } + } else { + if !hasNoAuth { + if _, err := conn.Write([]byte{0x05, 0xFF}); err != nil { + return fmt.Errorf("write no-acceptable-methods: %w", err) + } + return nil + } + if _, err := conn.Write([]byte{0x05, 0x00}); err != nil { + return fmt.Errorf("write auth reply: %w", err) + } } // --- Connect request --- @@ -215,7 +260,7 @@ func (p *Proxy) handleSOCKS5(conn net.Conn, br *bufio.Reader) error { p.logger.Debug("tunnel SOCKS5 CONNECT", slog.String("target", target)) - ok, _, tunnelInfo := p.tunnelTransformCheck(conn.RemoteAddr().String(), target, nil) + ok, _, tunnelInfo := p.tunnelTransformCheck(conn.RemoteAddr().String(), target, nil, auth) if !ok { if err := p.socks5Reply(conn, 0x02); err != nil { return fmt.Errorf("write connection-not-allowed: %w", err) @@ -239,12 +284,48 @@ func (p *Proxy) socks5Reply(conn net.Conn, status byte) error { return err } +func (p *Proxy) readSOCKS5UserPass(conn net.Conn, br *bufio.Reader, authSnapshot *authenticator) (proxyAuth, bool, error) { + ver, err := br.ReadByte() + if err != nil { + return proxyAuth{}, false, fmt.Errorf("read username-password version: %w", err) + } + if ver != 0x01 { + return proxyAuth{}, false, fmt.Errorf("unsupported username-password version: %d", ver) + } + ulen, err := br.ReadByte() + if err != nil { + return proxyAuth{}, false, fmt.Errorf("read username length: %w", err) + } + username := make([]byte, ulen) + if _, err := io.ReadFull(br, username); err != nil { + return proxyAuth{}, false, fmt.Errorf("read username: %w", err) + } + plen, err := br.ReadByte() + if err != nil { + return proxyAuth{}, false, fmt.Errorf("read password length: %w", err) + } + password := make([]byte, plen) + if _, err := io.ReadFull(br, password); err != nil { + return proxyAuth{}, false, fmt.Errorf("read password: %w", err) + } + if !authSnapshot.authenticateLoginPassword(string(username), string(password)) { + if _, err := conn.Write([]byte{0x01, 0x01}); err != nil { + return proxyAuth{}, false, fmt.Errorf("write username-password failure: %w", err) + } + return proxyAuth{}, false, nil + } + if _, err := conn.Write([]byte{0x01, 0x00}); err != nil { + return proxyAuth{}, false, fmt.Errorf("write username-password success: %w", err) + } + return proxyAuth{Login: string(username)}, true, nil +} + // tunnelTransformCheck runs a synthetic CONNECT request through the transform // pipeline to decide whether the tunnel should be allowed. When connectHeaders // is non-nil the headers from the original CONNECT request (e.g. // Proxy-Authorization) are forwarded to transforms so they can make // authentication and policy decisions at the tunnel level. -func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders http.Header) (bool, *http.Response, *transform.TunnelInfo) { +func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders http.Header, auth proxyAuth) (bool, *http.Response, *transform.TunnelInfo) { host, _, _ := net.SplitHostPort(target) hdr := http.Header{} @@ -270,9 +351,11 @@ func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders h } tctx := &transform.TransformContext{ - Logger: p.logger, - SNI: host, - Mode: mode, + Logger: p.logger, + SNI: host, + Mode: mode, + ProxyLogin: auth.Login, + SourceIP: sourceIP(remoteAddr), } result := &transform.PipelineResult{ @@ -280,6 +363,8 @@ func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders h Method: http.MethodConnect, Path: "", RemoteAddr: remoteAddr, + ProxyLogin: auth.Login, + SourceIP: tctx.SourceIP, SNI: host, Mode: mode, } @@ -311,10 +396,28 @@ func (p *Proxy) tunnelTransformCheck(remoteAddr, target string, connectHeaders h result.StatusCode = http.StatusOK return true, nil, &transform.TunnelInfo{ Target: target, + ProxyLogin: auth.Login, + SourceIP: tctx.SourceIP, RequestTransforms: result.RequestTransforms, } } +func (p *Proxy) emitTunnelAuthReject(remoteAddr, target string, status int) { + host, _, _ := net.SplitHostPort(target) + result := &transform.PipelineResult{ + Host: target, + Method: http.MethodConnect, + RemoteAddr: remoteAddr, + SourceIP: sourceIP(remoteAddr), + SNI: host, + Mode: transform.ModeMITM, + Action: transform.ActionReject, + StatusCode: status, + } + _, finish := p.beginPipelineRun(result) + finish() +} + // serveTunnel peeks at the client's first byte after the CONNECT/SOCKS5 // handshake to detect TLS (0x16) vs plain HTTP. TLS connections get MITM'd; // plain HTTP is served directly through handleHTTP. Anything else is rejected. @@ -347,7 +450,7 @@ func (p *Proxy) serveTunnel(clientConn net.Conn, target string, tunnelInfo *tran // port is ignored to prevent port-pivot attacks). func (p *Proxy) serveTunnelTLS(clientConn net.Conn, target string, tunnelInfo *transform.TunnelInfo) error { if p.tlsMode == config.TLSModeSNIOnly { - return p.serveSNIPassthrough(clientConn) + return p.serveSNIPassthrough(clientConn, tunnelInfo) } tlsConn := tls.Server(clientConn, &tls.Config{ @@ -391,6 +494,8 @@ func cloneTunnelInfo(info *transform.TunnelInfo) *transform.TunnelInfo { } return &transform.TunnelInfo{ Target: info.Target, + ProxyLogin: info.ProxyLogin, + SourceIP: info.SourceIP, RequestTransforms: traces, } } diff --git a/internal/proxy/tunnel_test.go b/internal/proxy/tunnel_test.go index b70c903..24827c8 100644 --- a/internal/proxy/tunnel_test.go +++ b/internal/proxy/tunnel_test.go @@ -17,11 +17,17 @@ import ( "github.com/stretchr/testify/require" "github.com/ironsh/iron-proxy/internal/certcache" + "github.com/ironsh/iron-proxy/internal/config" "github.com/ironsh/iron-proxy/internal/transform" ) func startTunnelProxy(t *testing.T, transforms []transform.Transformer) (*Proxy, string, *x509.CertPool) { t.Helper() + return startTunnelProxyWithAuth(t, transforms, config.ProxyAuth{}) +} + +func startTunnelProxyWithAuth(t *testing.T, transforms []transform.Transformer, auth config.ProxyAuth) (*Proxy, string, *x509.CertPool) { + t.Helper() caCert, caKey := generateTestCA(t) cache, err := certcache.NewFromCA(caCert, caKey, 100, 72*time.Hour) @@ -29,14 +35,16 @@ func startTunnelProxy(t *testing.T, transforms []transform.Transformer) (*Proxy, pipeline := transform.NewPipeline(transforms, transform.BodyLimits{}, testLogger()) holder := transform.NewPipelineHolder(pipeline) - p := New(Options{ + p, err := New(Options{ HTTPAddr: "127.0.0.1:0", HTTPSAddr: "127.0.0.1:0", TunnelAddr: "127.0.0.1:0", CertCache: cache, Pipeline: holder, + Auth: auth, Logger: testLogger(), }) + require.NoError(t, err) // Start tunnel listener tunnelLn, err := net.Listen("tcp", "127.0.0.1:0") @@ -189,6 +197,47 @@ func TestTunnel_CONNECT_MethodNotAllowed(t *testing.T) { require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) } +func TestTunnel_CONNECT_AuthRequired(t *testing.T) { + _, tunnelAddr, _ := startTunnelProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + conn, err := net.DialTimeout("tcp", tunnelAddr, 5*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, err = fmt.Fprintf(conn, "CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n") + require.NoError(t, err) + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusProxyAuthRequired, resp.StatusCode) + require.Equal(t, proxyAuthRealm, resp.Header.Get("Proxy-Authenticate")) +} + +func TestTunnel_CONNECT_AuthValid(t *testing.T) { + _, tunnelAddr, _ := startTunnelProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + conn, err := net.DialTimeout("tcp", tunnelAddr, 5*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, err = fmt.Fprintf(conn, "CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Basic Y2k6c2VjcmV0\r\n\r\n") + require.NoError(t, err) + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) +} + func TestTunnel_SOCKS5_HTTP(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Socks", "true") @@ -258,6 +307,76 @@ func TestTunnel_SOCKS5_HTTP(t *testing.T) { require.Equal(t, "hello from socks5", string(body)) } +func TestTunnel_SOCKS5_AuthRequired(t *testing.T) { + _, tunnelAddr, _ := startTunnelProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + conn, err := net.DialTimeout("tcp", tunnelAddr, 5*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x05, 0x01, 0x00}) + require.NoError(t, err) + + resp := make([]byte, 2) + _, err = io.ReadFull(conn, resp) + require.NoError(t, err) + require.Equal(t, []byte{0x05, 0xFF}, resp) +} + +func TestTunnel_SOCKS5_UsernamePasswordAuth(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + _, tunnelAddr, _ := startTunnelProxyWithAuth(t, nil, config.ProxyAuth{ + Required: true, + Users: []config.ProxyAuthUser{authUser(t, "ci", "secret")}, + }) + + conn, err := net.DialTimeout("tcp", tunnelAddr, 5*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x05, 0x01, 0x02}) + require.NoError(t, err) + method := make([]byte, 2) + _, err = io.ReadFull(conn, method) + require.NoError(t, err) + require.Equal(t, []byte{0x05, 0x02}, method) + + _, err = conn.Write([]byte{0x01, 0x02, 'c', 'i', 0x06, 's', 'e', 'c', 'r', 'e', 't'}) + require.NoError(t, err) + authResp := make([]byte, 2) + _, err = io.ReadFull(conn, authResp) + require.NoError(t, err) + require.Equal(t, []byte{0x01, 0x00}, authResp) + + upstreamHost, upstreamPortStr, _ := net.SplitHostPort(upstream.Listener.Addr().String()) + ip := net.ParseIP(upstreamHost).To4() + require.NotNil(t, ip) + var port uint16 + _, err = fmt.Sscanf(upstreamPortStr, "%d", &port) + require.NoError(t, err) + + connectReq := []byte{0x05, 0x01, 0x00, 0x01} + connectReq = append(connectReq, ip...) + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, port) + connectReq = append(connectReq, portBuf...) + _, err = conn.Write(connectReq) + require.NoError(t, err) + + connectResp := make([]byte, 10) + _, err = io.ReadFull(conn, connectResp) + require.NoError(t, err) + require.Equal(t, byte(0x00), connectResp[1]) +} + func TestTunnel_SOCKS5_DomainName(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/internal/transform/allowlist/allowlist.go b/internal/transform/allowlist/allowlist.go index f85a39c..19b4b79 100644 --- a/internal/transform/allowlist/allowlist.go +++ b/internal/transform/allowlist/allowlist.go @@ -79,7 +79,7 @@ func New(domains []string, cidrs []string) (*Allowlist, error) { func (a *Allowlist) Name() string { return "allowlist" } func (a *Allowlist) TransformRequest(_ context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { - if hostmatch.MatchAnyRule(a.rules, req) { + if hostmatch.MatchAnyRuleContext(a.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } if a.warn { diff --git a/internal/transform/allowlist/allowlist_test.go b/internal/transform/allowlist/allowlist_test.go index 762fb2e..cc84780 100644 --- a/internal/transform/allowlist/allowlist_test.go +++ b/internal/transform/allowlist/allowlist_test.go @@ -31,6 +31,15 @@ func resultWithMethodAndPath(t *testing.T, a *Allowlist, host, method, path stri return res } +func resultWithContext(t *testing.T, a *Allowlist, host string, tctx *transform.TransformContext) *transform.TransformResult { + t.Helper() + req := httptest.NewRequest("GET", "http://"+host+"/", nil) + req.Host = host + res, err := a.TransformRequest(context.Background(), tctx, req) + require.NoError(t, err) + return res +} + // --- Existing tests (backwards compat via New) --- func TestAllowlist_ExactDomainMatch(t *testing.T) { @@ -299,6 +308,34 @@ func TestAllowlist_RuleWithCIDR(t *testing.T) { require.Equal(t, transform.ActionReject, resultWithMethodAndPath(t, a, "internal.service", "GET", "/").Action) } +func TestAllowlist_ProxyLoginFilter(t *testing.T) { + a, err := newFromConfig(allowlistConfig{ + Rules: []hostmatch.RuleConfig{{ + Host: "api.openai.com", + ProxyLogins: []string{"ci"}, + }}, + }) + require.NoError(t, err) + + require.Equal(t, transform.ActionContinue, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{ProxyLogin: "ci"}).Action) + require.Equal(t, transform.ActionReject, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{ProxyLogin: "dev"}).Action) + require.Equal(t, transform.ActionReject, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{}).Action) +} + +func TestAllowlist_SourceCIDRFilter(t *testing.T) { + a, err := newFromConfig(allowlistConfig{ + Rules: []hostmatch.RuleConfig{{ + Host: "api.openai.com", + SourceCIDRs: []string{"10.0.0.0/8"}, + }}, + }) + require.NoError(t, err) + + require.Equal(t, transform.ActionContinue, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{SourceIP: "10.1.2.3"}).Action) + require.Equal(t, transform.ActionReject, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{SourceIP: "192.168.1.5"}).Action) + require.Equal(t, transform.ActionReject, resultWithContext(t, a, "api.openai.com", &transform.TransformContext{}).Action) +} + // --- Warn mode tests --- func TestAllowlist_WarnModeAllowsBlockedRequests(t *testing.T) { diff --git a/internal/transform/annotate/annotate.go b/internal/transform/annotate/annotate.go index 5e0d562..8ee41c5 100644 --- a/internal/transform/annotate/annotate.go +++ b/internal/transform/annotate/annotate.go @@ -80,7 +80,7 @@ func (a *Annotate) Name() string { return "annotate" } func (a *Annotate) TransformRequest(_ context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { for _, g := range a.groups { - if !hostmatch.MatchAnyRule(g.rules, req) { + if !hostmatch.MatchAnyRuleContext(g.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { continue } for _, h := range g.headers { diff --git a/internal/transform/audit.go b/internal/transform/audit.go index a8e7165..75714f5 100644 --- a/internal/transform/audit.go +++ b/internal/transform/audit.go @@ -32,6 +32,7 @@ func NewAuditLogger(logger *slog.Logger) AuditFunc { slog.String("method", result.Method), slog.String("path", result.Path), slog.String("remote_addr", result.RemoteAddr), + slog.String("source_ip", result.SourceIP), slog.String("sni", result.SNI), slog.String("mode", result.Mode.String()), slog.String("action", action), @@ -39,8 +40,17 @@ func NewAuditLogger(logger *slog.Logger) AuditFunc { slog.Float64("duration_ms", float64(result.Duration.Microseconds())/1000.0), ), } + if result.ProxyLogin != "" { + attrs = append(attrs, slog.String("proxy_login", result.ProxyLogin)) + } if result.Tunnel != nil { tunnelAttrs := []any{slog.String("target", result.Tunnel.Target)} + if result.Tunnel.ProxyLogin != "" { + tunnelAttrs = append(tunnelAttrs, slog.String("proxy_login", result.Tunnel.ProxyLogin)) + } + if result.Tunnel.SourceIP != "" { + tunnelAttrs = append(tunnelAttrs, slog.String("source_ip", result.Tunnel.SourceIP)) + } if len(result.Tunnel.RequestTransforms) > 0 { tunnelAttrs = append(tunnelAttrs, slog.Any("request_transforms", buildTraceEntries(result.Tunnel.RequestTransforms)), diff --git a/internal/transform/awsauth/awsauth.go b/internal/transform/awsauth/awsauth.go index c26b0dd..1a3f013 100644 --- a/internal/transform/awsauth/awsauth.go +++ b/internal/transform/awsauth/awsauth.go @@ -231,7 +231,7 @@ func (a *AWSAuth) TransformRequest(ctx context.Context, tctx *transform.Transfor return &transform.TransformResult{Action: transform.ActionContinue}, nil } - if !hostmatch.MatchAnyRule(a.rules, req) { + if !hostmatch.MatchAnyRuleContext(a.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/bodycapture/bodycapture.go b/internal/transform/bodycapture/bodycapture.go index 4df38b9..e248424 100644 --- a/internal/transform/bodycapture/bodycapture.go +++ b/internal/transform/bodycapture/bodycapture.go @@ -79,7 +79,7 @@ func (b *bodyCapture) Name() string { return "body_capture" } // so a misbehaving body reader can't take down the request. func (b *bodyCapture) TransformRequest(_ context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { cont := &transform.TransformResult{Action: transform.ActionContinue} - if !hostmatch.MatchAnyRule(b.rules, req) { + if !hostmatch.MatchAnyRuleContext(b.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return cont, nil } if req.Body == nil || req.Body == http.NoBody { diff --git a/internal/transform/gcpauth/gcpauth.go b/internal/transform/gcpauth/gcpauth.go index 579a71b..11bf4de 100644 --- a/internal/transform/gcpauth/gcpauth.go +++ b/internal/transform/gcpauth/gcpauth.go @@ -193,7 +193,7 @@ func (g *GCPAuth) TransformRequest(ctx context.Context, tctx *transform.Transfor }, nil } - if len(g.rules) > 0 && !hostmatch.MatchAnyRule(g.rules, req) { + if len(g.rules) > 0 && !hostmatch.MatchAnyRuleContext(g.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/grpc/grpc.go b/internal/transform/grpc/grpc.go index 676490c..6f598b6 100644 --- a/internal/transform/grpc/grpc.go +++ b/internal/transform/grpc/grpc.go @@ -130,7 +130,7 @@ func newGRPCTransform(cfg grpcConfig) (*GRPCTransform, error) { func (g *GRPCTransform) Name() string { return g.name } func (g *GRPCTransform) TransformRequest(ctx context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { - if len(g.rules) > 0 && !hostmatch.MatchAnyRule(g.rules, req) { + if len(g.rules) > 0 && !hostmatch.MatchAnyRuleContext(g.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } @@ -167,7 +167,7 @@ func (g *GRPCTransform) TransformRequest(ctx context.Context, tctx *transform.Tr } func (g *GRPCTransform) TransformResponse(ctx context.Context, tctx *transform.TransformContext, req *http.Request, resp *http.Response) (*transform.TransformResult, error) { - if len(g.rules) > 0 && !hostmatch.MatchAnyRule(g.rules, req) { + if len(g.rules) > 0 && !hostmatch.MatchAnyRuleContext(g.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/headerallowlist/headerallowlist.go b/internal/transform/headerallowlist/headerallowlist.go index 2406c79..2b3170c 100644 --- a/internal/transform/headerallowlist/headerallowlist.go +++ b/internal/transform/headerallowlist/headerallowlist.go @@ -94,7 +94,7 @@ func parseHeaderMatchers(patterns []string) ([]headerMatcher, error) { func (h *HeaderAllowlist) Name() string { return "header_allowlist" } func (h *HeaderAllowlist) TransformRequest(_ context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { - if len(h.rules) > 0 && !hostmatch.MatchAnyRule(h.rules, req) { + if len(h.rules) > 0 && !hostmatch.MatchAnyRuleContext(h.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/hmacsign/hmacsign.go b/internal/transform/hmacsign/hmacsign.go index f645875..803dd3f 100644 --- a/internal/transform/hmacsign/hmacsign.go +++ b/internal/transform/hmacsign/hmacsign.go @@ -192,7 +192,7 @@ func newFromConfig(c config, logger *slog.Logger, build sourceBuilder) (*HMACSig func (h *HMACSign) Name() string { return "hmac_sign" } func (h *HMACSign) TransformRequest(ctx context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { - if !hostmatch.MatchAnyRule(h.rules, req) { + if !hostmatch.MatchAnyRuleContext(h.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/judge/transform.go b/internal/transform/judge/transform.go index 614055d..89150de 100644 --- a/internal/transform/judge/transform.go +++ b/internal/transform/judge/transform.go @@ -170,7 +170,7 @@ func (j *Judge) TransformResponse(_ context.Context, _ *transform.TransformConte // TransformRequest runs the judge over a request. See the package doc-style // comment on the judge transform for the full control-flow description. func (j *Judge) TransformRequest(ctx context.Context, tctx *transform.TransformContext, req *http.Request) (*transform.TransformResult, error) { - if !hostmatch.MatchAnyRule(j.rules, req) { + if !hostmatch.MatchAnyRuleContext(j.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { return &transform.TransformResult{Action: transform.ActionContinue}, nil } diff --git a/internal/transform/oauth/oauth.go b/internal/transform/oauth/oauth.go index 5d7617b..efc5c66 100644 --- a/internal/transform/oauth/oauth.go +++ b/internal/transform/oauth/oauth.go @@ -394,7 +394,7 @@ func (o *OAuth) TransformRequest(ctx context.Context, tctx *transform.TransformC } // First entry whose rules host-match wins; config order is the tie-breaker. - entry := o.matchEntry(req) + entry := o.matchEntry(req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) if entry == nil { return &transform.TransformResult{Action: transform.ActionContinue}, nil } @@ -424,9 +424,9 @@ func (o *OAuth) TransformResponse(context.Context, *transform.TransformContext, } // matchEntry returns the first entry whose host rules match req, or nil. -func (o *OAuth) matchEntry(req *http.Request) *tokenEntry { +func (o *OAuth) matchEntry(req *http.Request, ctx hostmatch.MatchContext) *tokenEntry { for _, e := range o.entries { - if hostmatch.MatchAnyRule(e.rules, req) { + if hostmatch.MatchAnyRuleContext(e.rules, req, ctx) { return e } } diff --git a/internal/transform/otelaudit.go b/internal/transform/otelaudit.go index e0b7999..6ecb283 100644 --- a/internal/transform/otelaudit.go +++ b/internal/transform/otelaudit.go @@ -46,12 +46,16 @@ func NewOTELAuditFunc(provider *sdklog.LoggerProvider) AuditFunc { log.String("method", result.Method), log.String("path", result.Path), log.String("remote_addr", result.RemoteAddr), + log.String("source_ip", result.SourceIP), log.String("sni", result.SNI), log.String("mode", result.Mode.String()), log.String("action", action), log.Int("status_code", result.StatusCode), log.Float64("duration_ms", float64(result.Duration.Microseconds())/1000.0), } + if result.ProxyLogin != "" { + attrs = append(attrs, log.String("proxy_login", result.ProxyLogin)) + } if result.Action == ActionReject { for _, tr := range result.RequestTransforms { diff --git a/internal/transform/secrets/secrets.go b/internal/transform/secrets/secrets.go index 7d1ad2e..6fab88c 100644 --- a/internal/transform/secrets/secrets.go +++ b/internal/transform/secrets/secrets.go @@ -338,7 +338,7 @@ func (s *Secrets) TransformRequest(ctx context.Context, tctx *transform.Transfor var unavailable []string for _, sec := range s.secrets { - if !hostmatch.MatchAnyRule(sec.rules, req) { + if !hostmatch.MatchAnyRuleContext(sec.rules, req, hostmatch.MatchContext{ProxyLogin: tctx.ProxyLogin, SourceIP: tctx.SourceIP}) { continue } diff --git a/internal/transform/transform.go b/internal/transform/transform.go index 62505ab..92c3dad 100644 --- a/internal/transform/transform.go +++ b/internal/transform/transform.go @@ -80,6 +80,8 @@ type TransformContext struct { Logger *slog.Logger Mode Mode Tunnel *TunnelInfo + ProxyLogin string + SourceIP string // BodyCapture is the side channel a body_capture transform uses to // communicate captured request body bytes out of the pipeline. The proxy @@ -98,6 +100,10 @@ type TransformContext struct { type TunnelInfo struct { // Target is the host:port from the CONNECT request or SOCKS5 target. Target string + // ProxyLogin is the authenticated proxy login that opened the tunnel. + ProxyLogin string + // SourceIP is the client IP that opened the tunnel. + SourceIP string // RequestTransforms are the traces from the CONNECT/SOCKS5 request // pipeline, in the order the transforms ran. Inner request transforms and @@ -128,6 +134,8 @@ type PipelineResult struct { Method string Path string RemoteAddr string + ProxyLogin string + SourceIP string SNI string Mode Mode diff --git a/iron-proxy.example.yaml b/iron-proxy.example.yaml index a9f446c..c29e4c9 100644 --- a/iron-proxy.example.yaml +++ b/iron-proxy.example.yaml @@ -24,6 +24,22 @@ dns: proxy: http_listen: ":80" https_listen: ":443" + # Optional CONNECT/SOCKS5 proxy listener. + # tunnel_listen: ":8080" + # + # Optional client auth for HTTP proxy requests, CONNECT, and SOCKS5. + # Default is false; existing no-auth configs keep working. + # auth: + # required: true + # users: + # - login: "ci" + # password: + # type: env + # var: "IRON_PROXY_CI_PASSWORD" + # - login: "dev" + # password: + # type: file + # path: "/run/secrets/iron_proxy_dev_password" # Global body buffer limits for transforms. max_request_body_bytes: 1048576 # default: 1 MiB # max_response_body_bytes: 0 # default: 0 (uncapped) @@ -102,6 +118,9 @@ transforms: - host: "api.openai.com" methods: ["POST"] paths: ["/v1/*"] + # Optional: also require an authenticated proxy login and source IP. + # proxy_logins: ["ci"] + # source_cidrs: ["10.16.0.0/16"] - host: "*.anthropic.com" methods: ["POST"] paths: ["/v1/messages", "/v1/complete"]