diff --git a/multistream.go b/multistream.go index f61b24e..a34a7d1 100644 --- a/multistream.go +++ b/multistream.go @@ -8,7 +8,9 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" + "strings" "sync" ) @@ -23,20 +25,16 @@ const ProtocolID = "/multistream/1.0.0" // handle a protocol/stream. type HandlerFunc func(protocol string, rwc io.ReadWriteCloser) error -// Handler is a wrapper to HandlerFunc which attaches a name (protocol) and a -// match function which can optionally be used to select a handler by other -// means than the name. -type Handler struct { - MatchFunc func(string) bool - Handle HandlerFunc - AddName string +type handler struct { + prefix, exclusive bool + handler HandlerFunc } // MultistreamMuxer is a muxer for multistream. Depending on the stream // protocol tag it will select the right handler and hand the stream off to it. type MultistreamMuxer struct { - handlerlock sync.Mutex - handlers []Handler + handlerlock sync.RWMutex + handlers map[string]*handler } // NewMultistreamMuxer creates a muxer. @@ -107,56 +105,134 @@ func Ls(rw io.ReadWriter) ([]string, error) { return out, nil } -func fulltextMatch(s string) func(string) bool { - return func(a string) bool { - return a == s +type Options struct { + Prefix, Override, Exclusive bool +} + +func (opts *Options) Apply(options ...Option) error { + for _, opt := range options { + if err := opt(opts); err != nil { + return err + } } + return nil } -// AddHandler attaches a new protocol handler to the muxer. -func (msm *MultistreamMuxer) AddHandler(protocol string, handler HandlerFunc) { - msm.AddHandlerWithFunc(protocol, fulltextMatch(protocol), handler) +// Option is a stream handler option. +type Option func(*Options) error + +// Prefix configures the protocol handler to handle all protocols prefixed +// with the specified protocol name. +// +// Note: This only works for paths. That is, `/a/b` is a protocol-prefix of +// `/a/b/c` but not `/a/bad`. +// +// Defaults to false. +func Prefix(isPrefix bool) Option { + return func(opts *Options) error { + opts.Prefix = isPrefix + return nil + } } -// AddHandlerWithFunc attaches a new protocol handler to the muxer with a match. -// If the match function returns true for a given protocol tag, the protocol -// will be selected even if the handler name and protocol tags are different. -func (msm *MultistreamMuxer) AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc) { - msm.handlerlock.Lock() - msm.removeHandler(protocol) - msm.handlers = append(msm.handlers, Handler{ - MatchFunc: match, - Handle: handler, - AddName: protocol, - }) - msm.handlerlock.Unlock() +// Exclusive configures the protocol handler as the *exclusive* handler for the +// specified protocol and all sub-protocols. +// +// Defaults to false. +func Exclusive(exclusive bool) Option { + return func(opts *Options) error { + opts.Exclusive = exclusive + return nil + } } -// RemoveHandler removes the handler with the given name from the muxer. -func (msm *MultistreamMuxer) RemoveHandler(protocol string) { +// Override configures the protocol handler to *override* any existing protocol +// handlers. +// +// If the Exclusive option is passed, any sub-protocol handlers will be +// removed before registering this protocol handler. +// +// Protocol registration will fail if there exists a protocol that: +// * Is a strict prefix of this protocol. +// * Has the Exclusive option set. +// Not doing so would either (a) leave some protocols unhandled or (b) break the expected behavior of Exclusive +// +// Defaults to false. +func Override(override bool) Option { + return func(opts *Options) error { + opts.Override = override + return nil + } +} + +// AddHandler attaches a new protocol handler to the muxer, overriding any +// existing handlers. +func (msm *MultistreamMuxer) AddHandler(protocol string, handlerFunc HandlerFunc, options ...Option) error { + var opts Options + if err := opts.Apply(options...); err != nil { + return err + } msm.handlerlock.Lock() defer msm.handlerlock.Unlock() - msm.removeHandler(protocol) -} + if msm.handlers == nil { + msm.handlers = make(map[string]*handler, 1) + } -func (msm *MultistreamMuxer) removeHandler(protocol string) { - for i, h := range msm.handlers { - if h.AddName == protocol { - msm.handlers = append(msm.handlers[:i], msm.handlers[i+1:]...) - return + if _, ok := msm.handlers[protocol]; ok { + // Handle exact match case. + if !opts.Override { + // If we haven't specified override, bail. + return fmt.Errorf("protocol %q already registered", protocol) + } + delete(msm.handlers, protocol) + + // If we have successfully overridden the protocol, we *know* + // there can't be any exclusive prefixes registered. + } else if currentName, currentHandler := msm.findHandlerLocked(protocol); currentHandler != nil && currentHandler.exclusive { + // If there *is* an exclusive strict-prefix registered, we can't + // register this protocol (even if we've specified override), + // bail. + return fmt.Errorf("when registering protocol %q, found conflicting exclusive protocol %q", protocol, currentName) + } + + // If we're registering this protocol exclusively, check for any + // already-registered sub-protocols. + if opts.Exclusive { + prefix := protocol + "/" + for existing := range msm.handlers { + if !strings.HasPrefix(existing, prefix) { + continue + } + if !opts.Override { + return fmt.Errorf("when registering exclusive protocol %q, found conflicting protocol %q", protocol, existing) + } + delete(msm.handlers, existing) } } + msm.handlers[protocol] = &handler{ + handler: handlerFunc, + prefix: opts.Prefix, + exclusive: opts.Exclusive, + } + return nil +} + +// RemoveHandler removes the handler with the given name from the muxer. +func (msm *MultistreamMuxer) RemoveHandler(protocol string) { + msm.handlerlock.Lock() + defer msm.handlerlock.Unlock() + delete(msm.handlers, protocol) } // Protocols returns the list of handler-names added to this this muxer. func (msm *MultistreamMuxer) Protocols() []string { - var out []string - msm.handlerlock.Lock() - for _, h := range msm.handlers { - out = append(out, h.AddName) + msm.handlerlock.RLock() + out := make([]string, 0, len(msm.handlers)) + for k := range msm.handlers { + out = append(out, k) } - msm.handlerlock.Unlock() + msm.handlerlock.RUnlock() return out } @@ -164,17 +240,36 @@ func (msm *MultistreamMuxer) Protocols() []string { // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") -func (msm *MultistreamMuxer) findHandler(proto string) *Handler { +// findHandler tries to find a handler for the given protocol +func (msm *MultistreamMuxer) findHandler(proto string) (name string, h *handler) { msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + return msm.findHandlerLocked(proto) +} - for _, h := range msm.handlers { - if h.MatchFunc(proto) { - return &h +// findHandlerLocked is a version of findHandler that expects the lock to already have been taken. +func (msm *MultistreamMuxer) findHandlerLocked(proto string) (name string, h *handler) { + handler, ok := msm.handlers[proto] + if ok { + return name, handler + } + for { + idx := strings.LastIndexByte(proto, '/') + if idx < 0 { + break } + proto = proto[:idx] + handler, ok := msm.handlers[proto] + if !ok { + continue + } + // Found a handler but it doesn't handle sub-protocols, bailing. + if !handler.prefix { + break + } + return name, handler } - - return nil + return "", nil } // NegotiateLazy performs protocol selection and returns @@ -240,7 +335,7 @@ loop: return nil, "", nil, err } default: - h := msm.findHandler(tok) + _, h := msm.findHandler(tok) if h == nil { select { case pval <- "na": @@ -260,7 +355,7 @@ loop: } // hand off processing to the sub-protocol handler - return lzc, tok, h.Handle, nil + return lzc, tok, h.handler, nil } } } @@ -299,7 +394,7 @@ loop: return "", nil, err } default: - h := msm.findHandler(tok) + _, h := msm.findHandler(tok) if h == nil { err := delimWriteBuffered(rwc, []byte("na")) if err != nil { @@ -314,7 +409,7 @@ loop: } // hand off processing to the sub-protocol handler - return tok, h.Handle, nil + return tok, h.handler, nil } } @@ -324,20 +419,21 @@ loop: // supported protocols to the given Writer. func (msm *MultistreamMuxer) Ls(w io.Writer) error { buf := new(bytes.Buffer) - msm.handlerlock.Lock() + msm.handlerlock.RLock() err := writeUvarint(buf, uint64(len(msm.handlers))) if err != nil { + msm.handlerlock.RUnlock() return err } - for _, h := range msm.handlers { - err := delimWrite(buf, []byte(h.AddName)) + for k := range msm.handlers { + err := delimWrite(buf, []byte(k)) if err != nil { - msm.handlerlock.Unlock() + msm.handlerlock.RUnlock() return err } } - msm.handlerlock.Unlock() + msm.handlerlock.RUnlock() ll := make([]byte, 16) nw := binary.PutUvarint(ll, uint64(buf.Len())) diff --git a/multistream_test.go b/multistream_test.go index 5076209..15437e3 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -442,6 +442,79 @@ func TestHandleFunc(t *testing.T) { verifyPipe(t, a, b) } +func TestAddHandlerOptions(t *testing.T) { + mux := NewMultistreamMuxer() + handlerCalled := make(chan string, 1) + setHandler := func(protocol string, opts ...Option) { + if err := mux.AddHandler(protocol, func(p string, rwc io.ReadWriteCloser) error { + handlerCalled <- protocol + return nil + }, opts...); err != nil { + t.Fatal(err) + } + } + setHandler("/foo", Prefix(true)) + setHandler("/foo/bar") + setHandler("/a/b/c/d") + + if err := mux.AddHandler("/a/b/c", nil, Exclusive(true)); err == nil { + t.Error("expected error registering exclusive, conflicting, protocol") + } + setHandler("/a/b/c", Exclusive(true), Override(true), Prefix(true)) + + if err := mux.AddHandler("/a/b/c/d", nil, Exclusive(true), Override(true)); err == nil { + t.Error("expected error registering exclusive, conflicting, sub-protocol") + } + + for proto, expected := range map[string]string{ + "/foo": "/foo", + "/foo/baz": "/foo", + "/foo/barbar": "/foo", + "/foo/bar": "/foo/bar", + "/a/b/c/d": "/a/b/c", + } { + a, b := newPipe(t) + go func() { + err := SelectProtoOrFail(proto, a) + if err != nil { + t.Error(err) + } + }() + err := mux.Handle(b) + if err != nil { + t.Error(err) + continue + } + verifyPipe(t, a, b) + actual := <-handlerCalled + if actual != expected { + t.Errorf("wrong handler called: expected %s, got %s", expected, actual) + } + a.Close() + b.Close() + } + a, b := newPipe(t) + go func() { + err := SelectProtoOrFail("/foo/bar/baz", a) + a.Close() + if err == nil { + t.Error("protocol selection should have failed") + } + }() + err := mux.Handle(b) + b.Close() + if err == nil { + t.Error("protocol selection should have failed") + } + + if mux.AddHandler("/foo", nil) == nil { + t.Error("should have failed to override handler") + } + if err := mux.AddHandler("/foo", nil, Override(true)); err != nil { + t.Error(err) + } +} + func TestAddHandlerOverride(t *testing.T) { a, b := newPipe(t) @@ -453,7 +526,7 @@ func TestAddHandlerOverride(t *testing.T) { mux.AddHandler("/foo", func(p string, rwc io.ReadWriteCloser) error { return nil - }) + }, Override(true)) go func() { err := SelectProtoOrFail("/foo", a)