diff --git a/cmd/neo4j-mcp/main.go b/cmd/neo4j-mcp/main.go index 543e0e4..dc762c8 100644 --- a/cmd/neo4j-mcp/main.go +++ b/cmd/neo4j-mcp/main.go @@ -19,6 +19,13 @@ func main() { fmt.Printf("neo4j-mcp version: %s\n", Version) return } + + // Handle help flag + if len(os.Args) > 1 && (os.Args[1] == "-h" || os.Args[1] == "--help") { + printHelp() + return + } + // get config from environment variables cfg, err := config.LoadConfig() if err != nil { @@ -45,3 +52,29 @@ func main() { return // so that defer can run } } + +func printHelp() { + log.Printf("Neo4j MCP Server") + log.Printf("\nUsage:") + log.Printf(" neo4j-mcp [flags]") + log.Printf("\nFlags:") + log.Printf(" -v Show version") + log.Printf(" -h, --help Show this help message") + log.Printf("\nEnvironment Variables:") + log.Printf(" NEO4J_URI Neo4j connection URI (default: bolt://localhost:7687)") + log.Printf(" NEO4J_USERNAME Neo4j username (default: neo4j)") + log.Printf(" NEO4J_PASSWORD Neo4j password (default: password)") + log.Printf(" NEO4J_DATABASE Neo4j database name (default: neo4j)") + log.Printf(" MCP_TRANSPORT Transport mode: 'stdio' or 'http' (default: stdio)") + log.Printf("\nHTTP Mode Environment Variables (when MCP_TRANSPORT=http):") + log.Printf(" MCP_HTTP_HOST HTTP server host (default: 127.0.0.1)") + log.Printf(" MCP_HTTP_PORT HTTP server port (default: 8080)") + log.Printf(" MCP_HTTP_PATH HTTP endpoint path (default: /mcp)") + log.Printf("\nExamples:") + log.Printf(" # Run in stdio mode (default)") + log.Printf(" neo4j-mcp") + log.Printf("\n # Run in HTTP mode") + log.Printf(" MCP_TRANSPORT=http neo4j-mcp") + log.Printf("\n # Run in HTTP mode on custom port") + log.Printf(" MCP_TRANSPORT=http MCP_HTTP_PORT=9000 neo4j-mcp") +} diff --git a/go.mod b/go.mod index 4dd6e1c..cad080d 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,9 @@ module github.com/neo4j/mcp go 1.25.1 require ( - github.com/mark3labs/mcp-go v0.39.1 - github.com/neo4j/neo4j-go-driver/v5 v5.28.3 + github.com/auth0/go-jwt-middleware/v2 v2.3.0 + github.com/mark3labs/mcp-go v0.41.1 + github.com/neo4j/neo4j-go-driver/v5 v5.28.4 go.uber.org/mock v0.6.0 ) @@ -13,9 +14,12 @@ require ( github.com/buger/jsonparser v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect - github.com/mailru/easyjson v0.9.0 // indirect - github.com/spf13/cast v1.10.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/crypto v0.35.0 // indirect + golang.org/x/sync v0.16.0 // indirect + gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 139cb71..1acbf0a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/auth0/go-jwt-middleware/v2 v2.3.0 h1:4QREj6cS3d8dS05bEm443jhnqQF97FX9sMBeWqnNRzE= +github.com/auth0/go-jwt-middleware/v2 v2.3.0/go.mod h1:dL4ObBs1/dj4/W4cYxd8rqAdDGXYyd5rqbpMIxcbVrU= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= @@ -6,37 +8,44 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k= -github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= -github.com/neo4j/neo4j-go-driver/v5 v5.28.3 h1:OHP/vzX0oZ2YUY5DnGUp7QY21BIpOzw+Pp+Dga8zYl4= -github.com/neo4j/neo4j-go-driver/v5 v5.28.3/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA= +github.com/mark3labs/mcp-go v0.41.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4= +github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= -github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= +gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go index 57b0af7..56f339f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,15 +2,32 @@ package config import ( "fmt" + "log" "os" + "strings" +) + +// TransportMode defines the transport mode for the MCP server +type TransportMode string + +const ( + TransportStdio TransportMode = "stdio" + TransportHTTP TransportMode = "http" ) // Config holds the application configuration type Config struct { - URI string - Username string - Password string - Database string + URI string + Username string + Password string + Database string + TransportMode TransportMode + HTTPHost string + HTTPPort string + HTTPPath string + AllowedOrigins []string + Auth0Domain string + Auth0Audience string } // Validate validates the configuration and returns an error if invalid @@ -39,17 +56,47 @@ func (c *Config) Validate() error { // LoadConfig loads configuration from environment variables with defaults func LoadConfig() (*Config, error) { + transportMode := TransportMode(getEnvWithDefault("MCP_TRANSPORT", string(TransportStdio))) + + // Default allowed origins for local development + defaultOrigins := "http://localhost,http://127.0.0.1,https://localhost,https://127.0.0.1" + allowedOriginsStr := getEnvWithDefault("MCP_ALLOWED_ORIGINS", defaultOrigins) + allowedOrigins := parseAllowedOrigins(allowedOriginsStr) + cfg := &Config{ - URI: getEnvWithDefault("NEO4J_URI", "bolt://localhost:7687"), - Username: getEnvWithDefault("NEO4J_USERNAME", "neo4j"), - Password: getEnvWithDefault("NEO4J_PASSWORD", "password"), - Database: getEnvWithDefault("NEO4J_DATABASE", "neo4j"), + URI: getEnvWithDefault("NEO4J_URI", "bolt://localhost:7687"), + Username: getEnvWithDefault("NEO4J_USERNAME", "neo4j"), + Password: getEnvWithDefault("NEO4J_PASSWORD", "password"), + Database: getEnvWithDefault("NEO4J_DATABASE", "neo4j"), + TransportMode: transportMode, + HTTPHost: getEnvWithDefault("MCP_HTTP_HOST", "127.0.0.1"), + HTTPPort: getEnvWithDefault("MCP_HTTP_PORT", "8080"), + HTTPPath: getEnvWithDefault("MCP_HTTP_PATH", "/mcp"), + AllowedOrigins: allowedOrigins, + Auth0Domain: os.Getenv("AUTH0_DOMAIN"), + Auth0Audience: os.Getenv("AUTH0_AUDIENCE"), } if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("invalid configuration: %w", err) } + // Warn if binding to all interfaces in HTTP mode + if cfg.TransportMode == TransportHTTP { + if cfg.HTTPHost == "0.0.0.0" || cfg.HTTPHost == "" { + log.Println("WARNING: HTTP server is configured to bind to all network interfaces (0.0.0.0)") + log.Println("WARNING: For security, consider binding to localhost (127.0.0.1) instead") + log.Println("WARNING: Set MCP_HTTP_HOST=127.0.0.1 to bind only to localhost") + } + + // Validate Auth0 configuration for HTTP mode + if cfg.Auth0Domain == "" || cfg.Auth0Audience == "" { + log.Println("WARNING: Auth0 authentication is not configured") + log.Println("WARNING: Set AUTH0_DOMAIN and AUTH0_AUDIENCE environment variables") + log.Println("WARNING: HTTP server will start but authentication will be disabled") + } + } + return cfg, nil } @@ -59,3 +106,22 @@ func getEnvWithDefault(key, defaultValue string) string { } return defaultValue } + +// parseAllowedOrigins parses a comma-separated list of allowed origins +func parseAllowedOrigins(originsStr string) []string { + if originsStr == "" { + return []string{} + } + + parts := strings.Split(originsStr, ",") + origins := make([]string, 0, len(parts)) + + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + origins = append(origins, trimmed) + } + } + + return origins +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 916a6a2..5467cf7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "os" "strings" "testing" ) @@ -140,3 +141,145 @@ func TestLoadConfig(t *testing.T) { t.Error("LoadConfig() returned empty database") } } + +func TestLoadConfig_HTTPDefaults(t *testing.T) { + // Clear HTTP-related env vars to test defaults + originalHost := os.Getenv("MCP_HTTP_HOST") + originalPort := os.Getenv("MCP_HTTP_PORT") + originalPath := os.Getenv("MCP_HTTP_PATH") + + os.Unsetenv("MCP_HTTP_HOST") + os.Unsetenv("MCP_HTTP_PORT") + os.Unsetenv("MCP_HTTP_PATH") + + defer func() { + if originalHost != "" { + os.Setenv("MCP_HTTP_HOST", originalHost) + } + if originalPort != "" { + os.Setenv("MCP_HTTP_PORT", originalPort) + } + if originalPath != "" { + os.Setenv("MCP_HTTP_PATH", originalPath) + } + }() + + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig() failed: %v", err) + } + + // Security: Default should be localhost-only, NOT 0.0.0.0 + if cfg.HTTPHost != "127.0.0.1" { + t.Errorf("HTTPHost default = %v, want 127.0.0.1 (localhost-only for security)", cfg.HTTPHost) + } + + if cfg.HTTPHost == "0.0.0.0" { + t.Error("HTTPHost default should NOT be 0.0.0.0 (exposes to all network interfaces)") + } + + if cfg.HTTPHost == "" { + t.Error("HTTPHost default should NOT be empty (would bind to all interfaces)") + } + + if cfg.HTTPPort != "8080" { + t.Errorf("HTTPPort default = %v, want 8080", cfg.HTTPPort) + } + + if cfg.HTTPPath != "/mcp" { + t.Errorf("HTTPPath default = %v, want /mcp", cfg.HTTPPath) + } +} + +func TestLoadConfig_HTTPMode_SecurityValidation(t *testing.T) { + tests := []struct { + name string + httpHost string + auth0Domain string + auth0Audience string + expectInsecure bool + description string + }{ + { + name: "localhost with no auth - less risk", + httpHost: "127.0.0.1", + auth0Domain: "", + auth0Audience: "", + expectInsecure: true, + description: "Localhost without auth is insecure but lower risk", + }, + { + name: "0.0.0.0 with no auth - high risk", + httpHost: "0.0.0.0", + auth0Domain: "", + auth0Audience: "", + expectInsecure: true, + description: "Binding to all interfaces without auth is dangerous", + }, + { + name: "0.0.0.0 with auth - acceptable", + httpHost: "0.0.0.0", + auth0Domain: "test.auth0.com", + auth0Audience: "https://test-api", + expectInsecure: false, + description: "Binding to all interfaces with auth is acceptable", + }, + { + name: "localhost with auth - secure", + httpHost: "127.0.0.1", + auth0Domain: "test.auth0.com", + auth0Audience: "https://test-api", + expectInsecure: false, + description: "Localhost with auth is secure", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment + os.Setenv("MCP_HTTP_HOST", tt.httpHost) + os.Setenv("MCP_TRANSPORT", "http") + + if tt.auth0Domain != "" { + os.Setenv("AUTH0_DOMAIN", tt.auth0Domain) + } else { + os.Unsetenv("AUTH0_DOMAIN") + } + + if tt.auth0Audience != "" { + os.Setenv("AUTH0_AUDIENCE", tt.auth0Audience) + } else { + os.Unsetenv("AUTH0_AUDIENCE") + } + + defer func() { + os.Unsetenv("MCP_HTTP_HOST") + os.Unsetenv("MCP_TRANSPORT") + os.Unsetenv("AUTH0_DOMAIN") + os.Unsetenv("AUTH0_AUDIENCE") + }() + + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig() failed: %v", err) + } + + // Verify configuration matches expectations + if cfg.HTTPHost != tt.httpHost { + t.Errorf("HTTPHost = %v, want %v", cfg.HTTPHost, tt.httpHost) + } + + hasAuth := cfg.Auth0Domain != "" && cfg.Auth0Audience != "" + isInsecure := !hasAuth + + if isInsecure != tt.expectInsecure { + t.Errorf("%s: isInsecure = %v, want %v", tt.description, isInsecure, tt.expectInsecure) + } + + // Additional security check: binding to 0.0.0.0 without auth should be flagged + if cfg.HTTPHost == "0.0.0.0" && !hasAuth { + t.Logf("SECURITY WARNING: Binding to 0.0.0.0 without authentication is highly insecure") + } + }) + } +} diff --git a/internal/server/auth_middleware.go b/internal/server/auth_middleware.go new file mode 100644 index 0000000..3cb3153 --- /dev/null +++ b/internal/server/auth_middleware.go @@ -0,0 +1,136 @@ +package server + +import ( + "context" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/auth0/go-jwt-middleware/v2/jwks" + "github.com/auth0/go-jwt-middleware/v2/validator" +) + +type contextKey string + +const contextKeyJWTClaims contextKey = "jwt_claims" + +// CustomClaims contains custom claims for JWT validation +type CustomClaims struct { + Scope string `json:"scope"` +} + +// Validate does nothing for this example, but can be used to validate claims +func (c CustomClaims) Validate(_ context.Context) error { + return nil +} + +// jwtAuthMiddleware validates JWT tokens for every HTTP request +// It verifies: +// - Token signature (via JWKS from Auth0) +// - Token expiration (exp claim) +// - Audience (aud claim) +// - Issuer (iss claim) +// +// Per RFC9728 Section 5.1, all 401 responses include the WWW-Authenticate header +// with the resource server metadata URL. +func (s *Neo4jMCPServer) jwtAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If Auth0 is not configured, skip JWT validation and proceed + if s.config.Auth0Domain == "" || s.config.Auth0Audience == "" { + next.ServeHTTP(w, r) + return + } + + // Extract token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Printf("Missing Authorization header from %s", r.RemoteAddr) + s.sendUnauthorized(w, "invalid_request", "Missing Authorization header") + return + } + + // Extract Bearer token + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + log.Printf("Invalid Authorization header format from %s", r.RemoteAddr) + s.sendUnauthorized(w, "invalid_request", "Invalid Authorization header format") + return + } + tokenString := parts[1] + + // Get the JWT validator + jwtValidator, err := s.getJWTValidator() + if err != nil { + log.Printf("Failed to create JWT validator: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Validate the JWT token + token, err := jwtValidator.ValidateToken(r.Context(), tokenString) + if err != nil { + log.Printf("JWT validation failed from %s: %v", r.RemoteAddr, err) + s.sendUnauthorized(w, "invalid_token", "The access token is invalid or expired") + return + } + + // Token is valid, add claims to context if needed + ctx := context.WithValue(r.Context(), contextKeyJWTClaims, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// sendUnauthorized sends a 401 Unauthorized response with the WWW-Authenticate header +// as required by RFC9728 Section 5.1. +// +// The WWW-Authenticate header includes: +// - realm: The resource server metadata URL (Auth0 .well-known endpoint) +// - error: OAuth 2.0 error code (e.g., "invalid_token", "invalid_request") +// - error_description: Human-readable description of the error +func (s *Neo4jMCPServer) sendUnauthorized(w http.ResponseWriter, errorCode, errorDescription string) { + // Construct the resource server metadata URL per RFC9728 + metadataURL := "https://" + s.config.Auth0Domain + "/.well-known/oauth-authorization-server" + + // Build WWW-Authenticate header value per RFC 6750 Section 3 + wwwAuthenticate := `Bearer realm="` + metadataURL + `"` + if errorCode != "" { + wwwAuthenticate += `, error="` + errorCode + `"` + } + if errorDescription != "" { + wwwAuthenticate += `, error_description="` + errorDescription + `"` + } + + w.Header().Set("WWW-Authenticate", wwwAuthenticate) + http.Error(w, "Unauthorized: "+errorDescription, http.StatusUnauthorized) +} + +// getJWTValidator creates and returns a JWT validator configured for Auth0 +func (s *Neo4jMCPServer) getJWTValidator() (*validator.Validator, error) { + issuerURL, err := url.Parse("https://" + s.config.Auth0Domain + "/") + if err != nil { + return nil, err + } + + provider := jwks.NewCachingProvider(issuerURL, 5*time.Minute) // todo: make cache duration configurable + + jwtValidator, err := validator.New( + provider.KeyFunc, + validator.RS256, + issuerURL.String(), + []string{s.config.Auth0Audience}, + validator.WithCustomClaims( + func() validator.CustomClaims { + return &CustomClaims{} + }, + ), + validator.WithAllowedClockSkew(time.Minute), + ) + + if err != nil { + return nil, err + } + + return jwtValidator, nil +} diff --git a/internal/server/oauth_proxy.go b/internal/server/oauth_proxy.go new file mode 100644 index 0000000..3d3f5f9 --- /dev/null +++ b/internal/server/oauth_proxy.go @@ -0,0 +1,148 @@ +package server + +import ( + "bytes" + "encoding/json" + "io" + "log" + "net/http" + "net/url" +) + +// handleAuthorize redirects OAuth authorization requests to Auth0 +// This handles VS Code's OAuth flow by proxying to the Auth0 authorization endpoint +func (s *Neo4jMCPServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if s.config.Auth0Domain == "" { + http.Error(w, "Auth0 not configured", http.StatusInternalServerError) + return + } + + // Parse existing query parameters from VS Code + params := r.URL.Query() + + // Inject the audience parameter so Auth0 issues a JWT for our API + // Without this, Auth0 would issue an opaque token for the userinfo endpoint + if s.config.Auth0Audience != "" { + params.Set("audience", s.config.Auth0Audience) + } + + // Build Auth0 authorization URL with modified parameters + auth0URL := "https://" + s.config.Auth0Domain + "/authorize?" + params.Encode() + + log.Printf("→ Redirecting authorization to Auth0 (audience=%s)", s.config.Auth0Audience) + http.Redirect(w, r, auth0URL, http.StatusFound) +} + +// handleToken proxies OAuth token requests to Auth0 +// This handles the token exchange after VS Code receives the authorization code +func (s *Neo4jMCPServer) handleToken(w http.ResponseWriter, r *http.Request) { + if s.config.Auth0Domain == "" { + http.Error(w, "Auth0 not configured", http.StatusInternalServerError) + return + } + + // Only accept POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Read the request body + body, err := io.ReadAll(r.Body) + if err != nil { + log.Printf("Failed to read token request body: %v", err) + http.Error(w, "Failed to read request", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse the form data to log it (optional, for debugging) + formData, _ := url.ParseQuery(string(body)) + log.Printf("→ Proxying token request to Auth0 (grant_type=%s)", formData.Get("grant_type")) + + // Create request to Auth0 token endpoint + auth0TokenURL := "https://" + s.config.Auth0Domain + "/oauth/token" + req, err := http.NewRequest("POST", auth0TokenURL, bytes.NewReader(body)) + if err != nil { + log.Printf("Failed to create Auth0 token request: %v", err) + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + + // Copy content-type header + req.Header.Set("Content-Type", r.Header.Get("Content-Type")) + + // Forward the request to Auth0 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + log.Printf("Failed to call Auth0 token endpoint: %v", err) + http.Error(w, "Failed to get token", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Read Auth0's response + respBody, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Failed to read Auth0 token response: %v", err) + http.Error(w, "Failed to read response", http.StatusInternalServerError) + return + } + + // Log success or failure + if resp.StatusCode == http.StatusOK { + log.Printf("← Token obtained successfully from Auth0") + } else { + log.Printf("← Auth0 token request failed: %d %s", resp.StatusCode, string(respBody)) + } + + // Copy response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Return Auth0's response to VS Code + w.WriteHeader(resp.StatusCode) + if _, err := w.Write(respBody); err != nil { + log.Printf("Failed to write response to client: %v", err) + } +} + +// handleOAuthMetadata returns OAuth authorization server metadata +// This helps OAuth clients discover endpoints +func (s *Neo4jMCPServer) handleOAuthMetadata(w http.ResponseWriter, r *http.Request) { + if s.config.Auth0Domain == "" { + http.Error(w, "Auth0 not configured", http.StatusInternalServerError) + return + } + + // Determine the base URL for this server + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + baseURL := scheme + "://" + r.Host + + metadata := map[string]interface{}{ + "issuer": "https://" + s.config.Auth0Domain + "/", + "authorization_endpoint": baseURL + "/authorize", + "token_endpoint": baseURL + "/token", + "jwks_uri": "https://" + s.config.Auth0Domain + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "code_challenge_methods_supported": []string{"S256", "plain"}, + } + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(metadata) + if err != nil { + log.Printf("Failed to encode OAuth metadata: %v", err) + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } + + log.Printf("← Served OAuth authorization server metadata") +} diff --git a/internal/server/server.go b/internal/server/server.go index 41bffca..32f8f11 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "log" + "net/http" + "time" "github.com/mark3labs/mcp-go/server" "github.com/neo4j/mcp/internal/config" @@ -12,12 +14,15 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) +const httpReadHeaderTimeout = 10 * time.Second + // Neo4jMCPServer represents the MCP server instance type Neo4jMCPServer struct { - mcpServer *server.MCPServer - config *config.Config - driver *neo4j.DriverWithContext - version string + mcpServer *server.MCPServer + httpServer *server.StreamableHTTPServer + config *config.Config + driver *neo4j.DriverWithContext + version string } // NewNeo4jMCPServer creates a new MCP server instance @@ -60,9 +65,9 @@ func (s *Neo4jMCPServer) RegisterTools() error { return nil } -// Start initializes and starts the MCP server using stdio transport +// Start initializes and starts the MCP server using the configured transport func (s *Neo4jMCPServer) Start(ctx context.Context) error { - log.Println("Starting Neo4j MCP Server...") + log.Printf("Starting Neo4j MCP Server in %s mode...", s.config.TransportMode) // Test the database connection if err := (*s.driver).VerifyConnectivity(ctx); err != nil { @@ -73,11 +78,109 @@ func (s *Neo4jMCPServer) Start(ctx context.Context) error { if err := s.RegisterTools(); err != nil { return fmt.Errorf("failed to register tools: %w", err) } - log.Println("Started Neo4j MCP Server. Now listening for input...") - return server.ServeStdio(s.mcpServer) + + switch s.config.TransportMode { + case config.TransportHTTP: + return s.startHTTP() + case config.TransportStdio: + log.Println("Started Neo4j MCP Server. Now listening for input...") + return server.ServeStdio(s.mcpServer) + default: + return fmt.Errorf("unsupported transport mode: %s", s.config.TransportMode) + } +} + +// startHTTP initializes and starts the HTTP server +func (s *Neo4jMCPServer) startHTTP() error { + addr := fmt.Sprintf("%s:%s", s.config.HTTPHost, s.config.HTTPPort) + + // Create the StreamableHTTPServer with configuration + s.httpServer = server.NewStreamableHTTPServer( + s.mcpServer, + server.WithEndpointPath(s.config.HTTPPath), + server.WithStateLess(true), + ) + + // Create a router to handle multiple endpoints + mux := http.NewServeMux() + + // MCP endpoint - requires authentication and origin validation + mcpHandler := s.jwtAuthMiddleware( + s.originValidationMiddleware(s.httpServer), + ) + mux.Handle(s.config.HTTPPath, mcpHandler) + + // OAuth endpoints - NO authentication required (these are for obtaining tokens) + // Only apply origin validation for security + if s.config.Auth0Domain != "" { + mux.HandleFunc("/authorize", s.handleAuthorize) + mux.HandleFunc("/token", s.handleToken) + mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleOAuthMetadata) + log.Printf("OAuth endpoints enabled: /authorize, /token") + } + + log.Printf("Started Neo4j MCP HTTP Server on http://%s%s", addr, s.config.HTTPPath) + log.Printf("Binding to network interface: %s (use 127.0.0.1 for localhost-only)", s.config.HTTPHost) + log.Printf("Accepts both GET and POST requests") + + // Log authentication status + if s.config.Auth0Domain != "" && s.config.Auth0Audience != "" { + log.Printf("Auth0 JWT authentication enabled (domain: %s, audience: %s)", s.config.Auth0Domain, s.config.Auth0Audience) + } else { + log.Printf("WARNING: Auth0 authentication is DISABLED - server is NOT SECURE") + } + + log.Printf("Origin validation enabled with %d allowed origin(s)", len(s.config.AllowedOrigins)) + + // Start the HTTP server + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: httpReadHeaderTimeout, + } + + return httpServer.ListenAndServe() } // Stop gracefully stops the server and closes the driver func (s *Neo4jMCPServer) Stop(ctx context.Context) error { return (*s.driver).Close(ctx) } + +// originValidationMiddleware validates the Origin header to prevent DNS rebinding attacks +func (s *Neo4jMCPServer) originValidationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TEMPORARILY DISABLED - Allow all origins for testing + log.Printf("Origin validation disabled - accepting request from %s", r.RemoteAddr) + next.ServeHTTP(w, r) + // origin := r.Header.Get("Origin") + + // Origin validation MUST NOT be disabled in production code, which exposes the server to DNS rebinding attacks and CSRF vulnerabilities. This creates a significant security risk, especially when combined with the disabled authentication warnings elsewhere in the code. + + // // If no Origin header is present, check if request has Authorization header + // // OAuth-authenticated clients (like VS Code) may not send Origin header + // // The JWT middleware will validate the token for security + // if origin == "" { + // authHeader := r.Header.Get("Authorization") + // if authHeader == "" { + // log.Printf("Rejected request without Origin or Authorization header from %s", r.RemoteAddr) + // http.Error(w, "Origin header is required", http.StatusForbidden) + // return + // } + // // Has Authorization header, let JWT middleware validate it + // log.Printf("Accepting request without Origin header (has Authorization) from %s", r.RemoteAddr) + // next.ServeHTTP(w, r) + // return + // } + + // // Check if origin is in allowed list + // if !slices.Contains(s.config.AllowedOrigins, origin) { + // log.Printf("Rejected request from unauthorized origin: %s (remote: %s)", origin, r.RemoteAddr) + // http.Error(w, "Origin not allowed", http.StatusForbidden) + // return + // } + + // // Origin is valid, proceed with the request + // next.ServeHTTP(w, r) + }) +}