diff --git a/v3/go.mod b/v3/go.mod index 1fdb6d5..042097f 100644 --- a/v3/go.mod +++ b/v3/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 - github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa + github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 github.com/google/uuid v1.6.0 github.com/jcmturner/gokrb5/v8 v8.4.4 diff --git a/v3/go.sum b/v3/go.sum index 9ffad25..ab53440 100644 --- a/v3/go.sum +++ b/v3/go.sum @@ -2,6 +2,8 @@ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/v3/gssapi/sspi.go b/v3/gssapi/sspi.go index fa5f025..09ce97c 100644 --- a/v3/gssapi/sspi.go +++ b/v3/gssapi/sspi.go @@ -5,6 +5,8 @@ package gssapi import ( "bytes" + "crypto" + "crypto/x509" "encoding/binary" "fmt" @@ -15,8 +17,9 @@ import ( // SSPIClient implements ldap.GSSAPIClient interface. // Depends on secur32.dll. type SSPIClient struct { - creds *sspi.Credentials - ctx *kerberos.ClientContext + creds *sspi.Credentials + ctx *kerberos.ClientContext + channelBindings []byte } // NewSSPIClient returns a client with credentials of the current user. @@ -49,6 +52,26 @@ func NewSSPIClientWithUserCredentials(domain, username, password string) (*SSPIC }, nil } +// NewSSPIClientWithChannelBinding creates an RFC 5929 compliant client. +func NewSSPIClientWithChannelBinding(cert *x509.Certificate) (*SSPIClient, error) { + creds, err := kerberos.AcquireCurrentUserCredentials() + if err != nil { + return nil, err + } + + certHash := calculateCertificateHash(cert) + if certHash == nil { + return nil, fmt.Errorf("failed to calculate certificate hash") + } + + tlsChannelBinding := append([]byte("tls-server-end-point:"), certHash...) + + return &SSPIClient{ + creds: creds, + channelBindings: createChannelBindingsStructure(tlsChannelBinding), + }, nil +} + // Close deletes any established secure context and closes the client. func (c *SSPIClient) Close() error { err1 := c.DeleteSecContext() @@ -82,7 +105,18 @@ func (c *SSPIClient) InitSecContextWithOptions(target string, token []byte, APOp switch token { case nil: - ctx, completed, output, err := kerberos.NewClientContextWithFlags(c.creds, target, sspiFlags) + // Use channel bindings if available, otherwise fall back to the standard method. + var ctx *kerberos.ClientContext + var completed bool + var output []byte + var err error + + if len(c.channelBindings) > 0 { + ctx, completed, output, err = kerberos.NewClientContextWithChannelBindings(c.creds, target, sspiFlags, c.channelBindings) + } else { + ctx, completed, output, err = kerberos.NewClientContextWithFlags(c.creds, target, sspiFlags) + } + if err != nil { return nil, false, err } @@ -90,7 +124,6 @@ func (c *SSPIClient) InitSecContextWithOptions(target string, token []byte, APOp return output, !completed, nil default: - completed, output, err := c.ctx.Update(token) if err != nil { return nil, false, err @@ -99,7 +132,6 @@ func (c *SSPIClient) InitSecContextWithOptions(target string, token []byte, APOp return nil, false, fmt.Errorf("error verifying flags: %v", err) } return output, !completed, nil - } } @@ -196,3 +228,61 @@ func handshakePayload(secLayer byte, maxSize uint32, authzid []byte) []byte { return payload } + +// createChannelBindingsStructure creates a Windows SEC_CHANNEL_BINDINGS structure. +// This is the format that Windows SSPI expects for channel binding tokens. +// https://learn.microsoft.com/en-us/windows/win32/api/sspi/ns-sspi-sec_channel_bindings +func createChannelBindingsStructure(applicationData []byte) []byte { + const headerSize = 32 // 8 DWORDs * 4 bytes each + appDataLen := uint32(len(applicationData)) + appDataOffset := uint32(headerSize) + + buf := make([]byte, headerSize+len(applicationData)) + + // All initiator and acceptor fields are 0 for TLS channel binding. + binary.LittleEndian.PutUint32(buf[24:], appDataLen) // cbApplicationDataLength + binary.LittleEndian.PutUint32(buf[28:], appDataOffset) // dwApplicationDataOffset + + copy(buf[headerSize:], applicationData) + + return buf +} + +// calculateCertificateHash implements RFC 5929 certificate hash calculation. +// https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1 +func calculateCertificateHash(cert *x509.Certificate) []byte { + var hashFunc crypto.Hash + + switch cert.SignatureAlgorithm { + case x509.SHA256WithRSA, + x509.SHA256WithRSAPSS, + x509.ECDSAWithSHA256, + x509.DSAWithSHA256: + + hashFunc = crypto.SHA256 + case x509.SHA384WithRSA, + x509.SHA384WithRSAPSS, + x509.ECDSAWithSHA384: + + hashFunc = crypto.SHA384 + case x509.SHA512WithRSA, + x509.SHA512WithRSAPSS, + x509.ECDSAWithSHA512: + + hashFunc = crypto.SHA512 + case x509.MD5WithRSA, + x509.SHA1WithRSA, + x509.ECDSAWithSHA1, + x509.DSAWithSHA1: + + hashFunc = crypto.SHA256 + default: + return nil + } + + hasher := hashFunc.New() + + // Important to hash cert in DER format. + hasher.Write(cert.Raw) + return hasher.Sum(nil) +} diff --git a/v3/gssapi/sspi_test.go b/v3/gssapi/sspi_test.go new file mode 100644 index 0000000..175d064 --- /dev/null +++ b/v3/gssapi/sspi_test.go @@ -0,0 +1,108 @@ +//go:build windows +// +build windows + +package gssapi + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "strings" + "testing" + "time" +) + +// createTestCertificate creates a test certificate with the specified signature algorithm. +func createTestCertificate(sigAlg x509.SignatureAlgorithm) (*x509.Certificate, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SignatureAlgorithm: sigAlg, + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Company"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, err + } + + return cert, nil +} + +func TestNewSSPIClientWithChannelBinding(t *testing.T) { + tests := []struct { + name string + sigAlg x509.SignatureAlgorithm + }{ + { + name: x509.SHA256WithRSA.String(), + sigAlg: x509.SHA256WithRSA, + }, + { + name: x509.SHA384WithRSA.String(), + sigAlg: x509.SHA384WithRSA, + }, + { + name: x509.SHA512WithRSA.String(), + sigAlg: x509.SHA512WithRSA, + }, + { + name: x509.SHA1WithRSA.String() + " (should fallback to SHA256)", + sigAlg: x509.SHA1WithRSA, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cert, err := createTestCertificate(tt.sigAlg) + if err != nil { + t.Fatalf("Failed to create test certificate: %v", err) + } + + client, err := NewSSPIClientWithChannelBinding(cert) + t.Cleanup(func() { + client.Close() + }) + + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + if client == nil { + t.Error("Expected client but got nil") + } + if len(client.channelBindings) == 0 { + t.Error("Expected channel bindings to be set") + } + + applicationData := client.channelBindings[32:] + expectedPrefix := "tls-server-end-point:" + if !strings.HasPrefix(string(applicationData), expectedPrefix) { + t.Errorf("Expected application data to start with %q, got %q", expectedPrefix, string(applicationData)) + } + }) + } +}