Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 174 additions & 26 deletions kmipclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@
package kmipclient

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"slices"
"sync"
"time"

"github.com/ovh/kmip-go"
"github.com/ovh/kmip-go/payloads"
Expand All @@ -60,6 +64,7 @@ var supportedVersions = []kmip.ProtocolVersion{kmip.V1_4, kmip.V1_3, kmip.V1_2,

type opts struct {
middlewares []Middleware
HttpMiddlewares []HttpMiddleware
supportedVersions []kmip.ProtocolVersion
enforceVersion *kmip.ProtocolVersion
rootCAs [][]byte
Expand All @@ -68,6 +73,9 @@ type opts struct {
tlsCfg *tls.Config
tlsCiphers []uint16
dialer func(context.Context, string) (net.Conn, error)
enabledHttp bool
contentType string
httpURI string
//TODO: Add KMIP Authentication / Credentials
//TODO: Overwrite default/preferred/supported key formats for register
}
Expand Down Expand Up @@ -137,6 +145,19 @@ func WithMiddlewares(middlewares ...Middleware) Option {
}
}

// WithHttpMiddlewares returns an Option that appends the provided HttpMiddleware(s) to the client's HTTP middleware chain.
// This allows customization of the client's HTTP behavior by injecting additional processing steps.
//
// Usage:
//
// client.New(WithHttpMiddlewares(mw1, mw2, ...))
func WithHttpMiddlewares(middlewares ...HttpMiddleware) Option {
return func(o *opts) error {
o.HttpMiddlewares = append(o.HttpMiddlewares, middlewares...)
return nil
}
}

// WithKmipVersions returns an Option that sets the supported KMIP protocol versions for the client.
// It appends the provided versions to the existing list, sorts them in descending order,
// and removes any duplicate versions. This allows the client to negotiate the highest mutually
Expand Down Expand Up @@ -353,6 +374,30 @@ func WithDialerUnsafe(dialer func(ctx context.Context, addr string) (net.Conn, e
}
}

// WithEnabledHttp returns an Option that enables HTTP support for the client.
func WithEnabledHttp() Option {
return func(o *opts) error {
o.enabledHttp = true
return nil
}
}

// WithContentType returns an Option that sets the content type for the client.
func WithContentType(contentType string) Option {
return func(o *opts) error {
o.contentType = contentType
return nil
}
}

// WithHttpURI returns an Option that sets the HTTP URI for the client.
func WithHttpURI(uri string) Option {
return func(o *opts) error {
o.httpURI = uri
return nil
}
}

// Client represents a KMIP client that manages a connection to a KMIP server,
// handles protocol version negotiation, and supports middleware for request/response
// processing. It provides thread-safe access to the underlying connection and
Expand All @@ -365,6 +410,10 @@ type Client struct {
dialer func(context.Context) (*conn, error)
middlewares []Middleware
addr string
httpClient *http.Client
HttpMiddlewares []HttpMiddleware
contentType string
httpURI string
}

// Dial establishes a connection to the KMIP server at the specified address using the provided options.
Expand Down Expand Up @@ -399,6 +448,25 @@ func DialContext(ctx context.Context, addr string, options ...Option) (*Client,
opts.supportedVersions = append(opts.supportedVersions, supportedVersions...)
}

c := &Client{
lock: new(sync.Mutex),
supportedVersions: opts.supportedVersions,
version: opts.enforceVersion,
middlewares: opts.middlewares,
HttpMiddlewares: opts.HttpMiddlewares,
addr: addr,
contentType: opts.contentType,
httpURI: opts.httpURI,
}

if c.httpURI == "" {
c.httpURI = "/kmip"
}

if c.contentType == "" {
c.contentType = "application/octet-stream"
}

netDial := opts.dialer
if netDial == nil {
tlsCfg, err := opts.tlsConfig()
Expand All @@ -413,27 +481,37 @@ func DialContext(ctx context.Context, addr string, options ...Option) (*Client,
}
}

dialer := func(ctx context.Context) (*conn, error) {
conn, err := netDial(ctx, addr)
if opts.enabledHttp {
c.httpClient = &http.Client{
Transport: &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := netDial(ctx, addr)
if err != nil {
return nil, err
}
return conn, nil
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
} else {
dialer := func(ctx context.Context) (*conn, error) {
conn, err := netDial(ctx, addr)
if err != nil {
return nil, err
}
return newConn(conn), nil
}

stream, err := dialer(ctx)
if err != nil {
return nil, err
}
return newConn(conn), nil
}

stream, err := dialer(ctx)
if err != nil {
return nil, err
}

c := &Client{
lock: new(sync.Mutex),
conn: stream,
dialer: dialer,
supportedVersions: opts.supportedVersions,
version: opts.enforceVersion,
middlewares: opts.middlewares,
addr: addr,
c.conn = stream
c.dialer = dialer
}

// Negotiate protocol version
Expand All @@ -457,20 +535,31 @@ func (c *Client) Clone() (*Client, error) {
//
// Cloning a closed client is valid and will create a new connected client.
func (c *Client) CloneCtx(ctx context.Context) (*Client, error) {
stream, err := c.dialer(ctx)
if err != nil {
return nil, err
}
version := *c.version
return &Client{
clone := &Client{
lock: new(sync.Mutex),
version: &version,
supportedVersions: slices.Clone(c.supportedVersions),
dialer: c.dialer,
middlewares: slices.Clone(c.middlewares),
conn: stream,
HttpMiddlewares: slices.Clone(c.HttpMiddlewares),
addr: c.addr,
}, nil
contentType: c.contentType,
httpURI: c.httpURI,
}
if c.conn != nil {
stream, err := c.dialer(ctx)
if err != nil {
return nil, err
}
clone.conn = stream
}
if c.httpClient != nil {
httpClient := *c.httpClient
clone.httpClient = &httpClient
}

return clone, nil
}

// Version returns the KMIP protocol version used by the client.
Expand All @@ -486,7 +575,10 @@ func (c *Client) Addr() string {
// Close terminates the client's connection and releases any associated resources.
// It returns an error if the connection could not be closed.
func (c *Client) Close() error {
return c.conn.Close()
if c.conn != nil {
return c.conn.Close()
}
return nil
}

func (c *Client) reconnect(ctx context.Context) error {
Expand All @@ -509,6 +601,62 @@ func (c *Client) reconnect(ctx context.Context) error {
// io.EOF and io.ErrClosedPipe, attempting to reconnect and resend the request up to three times before failing.
// Returns the response message on success, or an error if the operation ultimately fails.
func (c *Client) doRountrip(ctx context.Context, msg *kmip.RequestMessage) (*kmip.ResponseMessage, error) {
if c.httpClient != nil {
url := url.URL{
Scheme: "https",
Host: c.addr,
Path: c.httpURI,
}

var unmarshaller func(data []byte, ptr any) error
var marshaller func(data any) []byte
switch c.contentType {
case "text/xml":
unmarshaller = ttlv.UnmarshalXML
marshaller = ttlv.MarshalXML
case "application/json":
unmarshaller = ttlv.UnmarshalJSON
marshaller = ttlv.MarshalJSON
case "application/octet-stream":
unmarshaller = ttlv.UnmarshalTTLV
marshaller = ttlv.MarshalTTLV
default:
return nil, fmt.Errorf("Unsupported Content-Type header: %s", c.contentType)
}

data := marshaller(msg)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", c.contentType)
req.Header.Set("Cache-Control", "no-cache")

i := 0
if i < len(c.HttpMiddlewares) {
mdl := c.HttpMiddlewares[i]
err = mdl(req)
if err != nil {
return nil, err
}
}

httpResp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer httpResp.Body.Close()

data, err = io.ReadAll(httpResp.Body)
if err != nil {
return nil, err
}

var resp kmip.ResponseMessage
err = unmarshaller(data, &resp)
return &resp, err
}

c.lock.Lock()
defer c.lock.Unlock()
if c.conn == nil {
Expand Down
31 changes: 31 additions & 0 deletions kmipclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"net/http"
"os"
"sync"
"testing"
Expand Down Expand Up @@ -395,3 +396,33 @@ func TestWithDialerUnsafe(t *testing.T) {
assert.True(t, called)
assert.ErrorIs(t, err, retErr)
}

func TestWithHttpMiddlewares(t *testing.T) {
router := kmipserver.NewBatchExecutor()
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
return &payloads.DiscoverVersionsResponsePayload{
ProtocolVersion: []kmip.ProtocolVersion{
kmip.V1_3, kmip.V1_2,
},
}, nil
}))
addr, ca := kmiptest.NewHttpServer(t, router)
client, err := kmipclient.Dial(
addr,
kmipclient.WithRootCAPem([]byte(ca)),
kmipclient.WithEnabledHttp(),
kmipclient.WithHttpMiddlewares(
func(req *http.Request) error {
req.Header.Set("X-My-Header", "myvalue")
return nil
},
),
kmipclient.WithMiddlewares(
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
),
kmipclient.WithKmipVersions(kmip.V1_2, kmip.V1_3),
)
require.NoError(t, err)
require.NotNil(t, client)
require.EqualValues(t, client.Version(), kmip.V1_3)
}
4 changes: 4 additions & 0 deletions kmipclient/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"time"

"github.com/ovh/kmip-go"
Expand All @@ -21,6 +22,9 @@ type Next func(context.Context, *kmip.RequestMessage) (*kmip.ResponseMessage, er
// cross-cutting concerns such as logging, authentication, or error handling.
type Middleware func(next Next, ctx context.Context, msg *kmip.RequestMessage) (*kmip.ResponseMessage, error)

// HttpMiddleware defines a function type that wraps the processing of an HTTP request.
type HttpMiddleware func(req *http.Request) error

// DebugMiddleware returns a Middleware that logs the KMIP request and response messages
// to the specified io.Writer. The messages are marshaled using the provided marshal
// function, or ttlv.MarshalXML if marshal is nil. The middleware also logs the duration
Expand Down
Loading