diff --git a/kmipclient/client.go b/kmipclient/client.go index c0fc5b0..a351c2d 100644 --- a/kmipclient/client.go +++ b/kmipclient/client.go @@ -40,6 +40,7 @@ package kmipclient import ( + "bytes" "context" "crypto/tls" "crypto/x509" @@ -47,9 +48,12 @@ import ( "fmt" "io" "net" + "net/http" + "net/url" "os" "slices" "sync" + "time" "github.com/ovh/kmip-go" "github.com/ovh/kmip-go/payloads" @@ -60,6 +64,7 @@ var supportedVersions = []kmip.ProtocolVersion{kmip.V1_4, kmip.V1_3, kmip.V1_2, type opts struct { middlewares []Middleware + HttpMiddlewares []HttpMiddleware supportedVersions []kmip.ProtocolVersion enforceVersion *kmip.ProtocolVersion rootCAs [][]byte @@ -68,6 +73,9 @@ type opts struct { tlsCfg *tls.Config tlsCiphers []uint16 dialer func(context.Context, string) (net.Conn, error) + enabledHttp bool + contentType string + httpURI string //TODO: Add KMIP Authentication / Credentials //TODO: Overwrite default/preferred/supported key formats for register } @@ -137,6 +145,19 @@ func WithMiddlewares(middlewares ...Middleware) Option { } } +// WithHttpMiddlewares returns an Option that appends the provided HttpMiddleware(s) to the client's HTTP middleware chain. +// This allows customization of the client's HTTP behavior by injecting additional processing steps. +// +// Usage: +// +// client.New(WithHttpMiddlewares(mw1, mw2, ...)) +func WithHttpMiddlewares(middlewares ...HttpMiddleware) Option { + return func(o *opts) error { + o.HttpMiddlewares = append(o.HttpMiddlewares, middlewares...) + return nil + } +} + // WithKmipVersions returns an Option that sets the supported KMIP protocol versions for the client. // It appends the provided versions to the existing list, sorts them in descending order, // and removes any duplicate versions. This allows the client to negotiate the highest mutually @@ -353,6 +374,30 @@ func WithDialerUnsafe(dialer func(ctx context.Context, addr string) (net.Conn, e } } +// WithEnabledHttp returns an Option that enables HTTP support for the client. +func WithEnabledHttp() Option { + return func(o *opts) error { + o.enabledHttp = true + return nil + } +} + +// WithContentType returns an Option that sets the content type for the client. +func WithContentType(contentType string) Option { + return func(o *opts) error { + o.contentType = contentType + return nil + } +} + +// WithHttpURI returns an Option that sets the HTTP URI for the client. +func WithHttpURI(uri string) Option { + return func(o *opts) error { + o.httpURI = uri + return nil + } +} + // Client represents a KMIP client that manages a connection to a KMIP server, // handles protocol version negotiation, and supports middleware for request/response // processing. It provides thread-safe access to the underlying connection and @@ -365,6 +410,10 @@ type Client struct { dialer func(context.Context) (*conn, error) middlewares []Middleware addr string + httpClient *http.Client + HttpMiddlewares []HttpMiddleware + contentType string + httpURI string } // Dial establishes a connection to the KMIP server at the specified address using the provided options. @@ -399,6 +448,25 @@ func DialContext(ctx context.Context, addr string, options ...Option) (*Client, opts.supportedVersions = append(opts.supportedVersions, supportedVersions...) } + c := &Client{ + lock: new(sync.Mutex), + supportedVersions: opts.supportedVersions, + version: opts.enforceVersion, + middlewares: opts.middlewares, + HttpMiddlewares: opts.HttpMiddlewares, + addr: addr, + contentType: opts.contentType, + httpURI: opts.httpURI, + } + + if c.httpURI == "" { + c.httpURI = "/kmip" + } + + if c.contentType == "" { + c.contentType = "application/octet-stream" + } + netDial := opts.dialer if netDial == nil { tlsCfg, err := opts.tlsConfig() @@ -413,27 +481,37 @@ func DialContext(ctx context.Context, addr string, options ...Option) (*Client, } } - dialer := func(ctx context.Context) (*conn, error) { - conn, err := netDial(ctx, addr) + if opts.enabledHttp { + c.httpClient = &http.Client{ + Transport: &http.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := netDial(ctx, addr) + if err != nil { + return nil, err + } + return conn, nil + }, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + } else { + dialer := func(ctx context.Context) (*conn, error) { + conn, err := netDial(ctx, addr) + if err != nil { + return nil, err + } + return newConn(conn), nil + } + + stream, err := dialer(ctx) if err != nil { return nil, err } - return newConn(conn), nil - } - - stream, err := dialer(ctx) - if err != nil { - return nil, err - } - - c := &Client{ - lock: new(sync.Mutex), - conn: stream, - dialer: dialer, - supportedVersions: opts.supportedVersions, - version: opts.enforceVersion, - middlewares: opts.middlewares, - addr: addr, + c.conn = stream + c.dialer = dialer } // Negotiate protocol version @@ -457,20 +535,31 @@ func (c *Client) Clone() (*Client, error) { // // Cloning a closed client is valid and will create a new connected client. func (c *Client) CloneCtx(ctx context.Context) (*Client, error) { - stream, err := c.dialer(ctx) - if err != nil { - return nil, err - } version := *c.version - return &Client{ + clone := &Client{ lock: new(sync.Mutex), version: &version, supportedVersions: slices.Clone(c.supportedVersions), dialer: c.dialer, middlewares: slices.Clone(c.middlewares), - conn: stream, + HttpMiddlewares: slices.Clone(c.HttpMiddlewares), addr: c.addr, - }, nil + contentType: c.contentType, + httpURI: c.httpURI, + } + if c.conn != nil { + stream, err := c.dialer(ctx) + if err != nil { + return nil, err + } + clone.conn = stream + } + if c.httpClient != nil { + httpClient := *c.httpClient + clone.httpClient = &httpClient + } + + return clone, nil } // Version returns the KMIP protocol version used by the client. @@ -486,7 +575,10 @@ func (c *Client) Addr() string { // Close terminates the client's connection and releases any associated resources. // It returns an error if the connection could not be closed. func (c *Client) Close() error { - return c.conn.Close() + if c.conn != nil { + return c.conn.Close() + } + return nil } func (c *Client) reconnect(ctx context.Context) error { @@ -509,6 +601,62 @@ func (c *Client) reconnect(ctx context.Context) error { // io.EOF and io.ErrClosedPipe, attempting to reconnect and resend the request up to three times before failing. // Returns the response message on success, or an error if the operation ultimately fails. func (c *Client) doRountrip(ctx context.Context, msg *kmip.RequestMessage) (*kmip.ResponseMessage, error) { + if c.httpClient != nil { + url := url.URL{ + Scheme: "https", + Host: c.addr, + Path: c.httpURI, + } + + var unmarshaller func(data []byte, ptr any) error + var marshaller func(data any) []byte + switch c.contentType { + case "text/xml": + unmarshaller = ttlv.UnmarshalXML + marshaller = ttlv.MarshalXML + case "application/json": + unmarshaller = ttlv.UnmarshalJSON + marshaller = ttlv.MarshalJSON + case "application/octet-stream": + unmarshaller = ttlv.UnmarshalTTLV + marshaller = ttlv.MarshalTTLV + default: + return nil, fmt.Errorf("Unsupported Content-Type header: %s", c.contentType) + } + + data := marshaller(msg) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", c.contentType) + req.Header.Set("Cache-Control", "no-cache") + + i := 0 + if i < len(c.HttpMiddlewares) { + mdl := c.HttpMiddlewares[i] + err = mdl(req) + if err != nil { + return nil, err + } + } + + httpResp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + + data, err = io.ReadAll(httpResp.Body) + if err != nil { + return nil, err + } + + var resp kmip.ResponseMessage + err = unmarshaller(data, &resp) + return &resp, err + } + c.lock.Lock() defer c.lock.Unlock() if c.conn == nil { diff --git a/kmipclient/client_test.go b/kmipclient/client_test.go index a786867..119944b 100644 --- a/kmipclient/client_test.go +++ b/kmipclient/client_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/http" "os" "sync" "testing" @@ -395,3 +396,33 @@ func TestWithDialerUnsafe(t *testing.T) { assert.True(t, called) assert.ErrorIs(t, err, retErr) } + +func TestWithHttpMiddlewares(t *testing.T) { + router := kmipserver.NewBatchExecutor() + router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) { + return &payloads.DiscoverVersionsResponsePayload{ + ProtocolVersion: []kmip.ProtocolVersion{ + kmip.V1_3, kmip.V1_2, + }, + }, nil + })) + addr, ca := kmiptest.NewHttpServer(t, router) + client, err := kmipclient.Dial( + addr, + kmipclient.WithRootCAPem([]byte(ca)), + kmipclient.WithEnabledHttp(), + kmipclient.WithHttpMiddlewares( + func(req *http.Request) error { + req.Header.Set("X-My-Header", "myvalue") + return nil + }, + ), + kmipclient.WithMiddlewares( + kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML), + ), + kmipclient.WithKmipVersions(kmip.V1_2, kmip.V1_3), + ) + require.NoError(t, err) + require.NotNil(t, client) + require.EqualValues(t, client.Version(), kmip.V1_3) +} diff --git a/kmipclient/middlewares.go b/kmipclient/middlewares.go index 319ddf0..5c22ee6 100644 --- a/kmipclient/middlewares.go +++ b/kmipclient/middlewares.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "time" "github.com/ovh/kmip-go" @@ -21,6 +22,9 @@ type Next func(context.Context, *kmip.RequestMessage) (*kmip.ResponseMessage, er // cross-cutting concerns such as logging, authentication, or error handling. type Middleware func(next Next, ctx context.Context, msg *kmip.RequestMessage) (*kmip.ResponseMessage, error) +// HttpMiddleware defines a function type that wraps the processing of an HTTP request. +type HttpMiddleware func(req *http.Request) error + // DebugMiddleware returns a Middleware that logs the KMIP request and response messages // to the specified io.Writer. The messages are marshaled using the provided marshal // function, or ttlv.MarshalXML if marshal is nil. The middleware also logs the duration diff --git a/kmiptest/clientserver.go b/kmiptest/clientserver.go index 0b78d87..6311651 100644 --- a/kmiptest/clientserver.go +++ b/kmiptest/clientserver.go @@ -11,6 +11,7 @@ package kmiptest import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -21,6 +22,7 @@ import ( "fmt" "math/big" "net" + "net/http" "os" "time" @@ -97,6 +99,61 @@ func NewServer(t TestingT, hdl kmipserver.RequestHandler) (addr, ca string) { return list.Addr().String(), string(pemCA) } +// NewHttpServer starts a new in-memory KMIP HTTP server for testing purposes using the provided +// kmipserver.RequestHandler. It generates a self-signed ECDSA certificate for the server, +// listens on a random local port, and starts serving requests in a separate goroutine. +// The server is automatically shut down when the test completes. The function returns +// the server's address and the PEM-encoded CA certificate as strings. +// +// Parameters: +// - t: A TestingT instance (typically *testing.T or *testing.B) used for test assertions and cleanup. +// - hdl: The kmipserver.RequestHandler to handle incoming requests. +// +// Returns: +// - addr: The address the server is listening on (e.g., "127.0.0.1:port"). +// - ca: The PEM-encoded CA certificate used by the server. +func NewHttpServer(t TestingT, hdl kmipserver.RequestHandler) (addr, ca string) { + caTpl := x509.Certificate{ + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + SerialNumber: big.NewInt(2), + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + NotAfter: time.Now().AddDate(1, 0, 0), + } + + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + cert, err := x509.CreateCertificate(rand.Reader, &caTpl, &caTpl, k.Public(), k) + require.NoError(t, err) + + list, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{{Certificate: [][]byte{cert}, PrivateKey: k}}, + MinVersion: tls.VersionTLS12, + }) + require.NoError(t, err) + + mux := http.NewServeMux() + mux.Handle("/kmip", kmipserver.NewHTTPHandler(hdl)) + + srv := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + go func() { + if err := srv.Serve(list); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("server error: %w", err) + } + }() + t.Cleanup(func() { + if err := srv.Shutdown(context.Background()); err != nil { + t.Errorf("server failed to shutdown: %w", err) + } + }) + + pemCA := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}) + return list.Addr().String(), string(pemCA) +} + // NewClientAndServer creates and returns a new KMIP client connected to a test server. // // It starts a new KMIP server using the provided request handler `hdl` and establishes a client @@ -127,3 +184,36 @@ func NewClientAndServer(t TestingT, hdl kmipserver.RequestHandler) *kmipclient.C }) return client } + +// NewHttpClientAndServer creates and returns a new KMIP HTTP client connected to a test server. +// +// It starts a new KMIP HTTP server using the provided request handler `hdl` and establishes a client +// connection to it. The client is configured with the server's CA certificate, a correlation value +// middleware, a testing middleware, and a debug middleware that outputs to stderr in XML format. +// +// The function registers a cleanup function to close the client connection when the test completes. +// +// Parameters: +// - t: A testing interface used for assertions and cleanup registration. +// - hdl: The KMIP server request handler. +// +// Returns: +// - A pointer to the initialized KMIP HTTP client. +// +// The function will fail the test if the client cannot be created. +func NewHttpClientAndServer(t TestingT, hdl kmipserver.RequestHandler) *kmipclient.Client { + addr, ca := NewHttpServer(t, hdl) + client, err := kmipclient.Dial(addr, kmipclient.WithRootCAPem([]byte(ca)), + kmipclient.WithEnabledHttp(), + kmipclient.WithMiddlewares( + kmipclient.CorrelationValueMiddleware(newRequestId), + TestingMiddleware(t), + kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML), + )) + require.NoError(t, err) + require.NotNil(t, client) + t.Cleanup(func() { + _ = client.Close() + }) + return client +} diff --git a/kmiptest/clientserver_test.go b/kmiptest/clientserver_test.go index 1fb9caf..3dcf1f1 100644 --- a/kmiptest/clientserver_test.go +++ b/kmiptest/clientserver_test.go @@ -16,3 +16,10 @@ func TestClientServer(t *testing.T) { require.NoError(t, err) require.NotNil(t, resp) } + +func TestHttpClientServer(t *testing.T) { + client := NewHttpClientAndServer(t, kmipserver.NewBatchExecutor()) + resp, err := client.Request(context.Background(), &payloads.DiscoverVersionsRequestPayload{}) + require.NoError(t, err) + require.NotNil(t, resp) +}