Skip to content

Commit

Permalink
Do not use url.Parse to parse the method address
Browse files Browse the repository at this point in the history
  • Loading branch information
nhatthm committed Oct 31, 2021
1 parent 6639f43 commit 060a98c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 37 deletions.
24 changes: 9 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net"
"net/url"
"reflect"
"regexp"
"strings"

"google.golang.org/grpc"
Expand All @@ -17,6 +17,8 @@ import (
grpcReflect "github.com/nhatthm/grpcmock/reflect"
)

var methodRegex = regexp.MustCompile(`/?[^/]+/[^/]+$`)

// ContextDialer is to set up the dialer.
type ContextDialer = func(context.Context, string) (net.Conn, error)

Expand Down Expand Up @@ -161,24 +163,16 @@ func prepInvoke(ctx context.Context, method string, opts ...InvokeOption) (conte
}

func parseMethod(method string) (string, string, error) {
u, err := url.Parse(method)
if err != nil {
return "", "", err
if !methodRegex.MatchString(method) {
return "", "", ErrMalformedMethod
}

method = fmt.Sprintf("/%s", strings.TrimLeft(u.Path, "/"))
addr := methodRegex.ReplaceAllString(method, "")

if method == "/" {
return "", "", ErrMissingMethod
}

addr := url.URL{
Scheme: u.Scheme,
User: u.User,
Host: u.Host,
}
method = strings.Replace(method, addr, "", 1)
method = fmt.Sprintf("/%s", strings.TrimLeft(method, "/"))

return addr.String(), method, nil
return addr, method, nil
}

func invokeOptions(ctx context.Context, opts ...InvokeOption) (context.Context, []grpc.DialOption, []grpc.CallOption) {
Expand Down
51 changes: 39 additions & 12 deletions client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,56 @@ func TestParseMethod(t *testing.T) {
expectedError string
}{
{
scenario: "method is not valid",
method: "://",
expectedError: `parse "://": missing protocol scheme`,
scenario: "method is empty",
method: "",
expectedError: "malformed method",
},
{
scenario: "method is missing",
method: "tcp://:8000/",
expectedError: `missing method`,
scenario: "method is /",
method: "/",
expectedError: "malformed method",
},
{
scenario: "full url",
method: "tcp://127.0.0.1:8000/server/GetItem",
expectedAddr: "tcp://127.0.0.1:8000",
scenario: "method is //",
method: "//",
expectedError: "malformed method",
},
{
scenario: "missing method",
method: "/server/",
expectedError: "malformed method",
},
{
scenario: "missing service",
method: "//method",
expectedError: "malformed method",
},
{
scenario: "method without leading /",
method: "server/GetItem",
expectedMethod: "/server/GetItem",
},
{
scenario: "method only",
scenario: "method with leading /",
method: "/server/GetItem",
expectedMethod: "/server/GetItem",
},
{
scenario: "relative method",
method: "server/GetItem",
scenario: "method only port",
method: ":9090/server/GetItem",
expectedAddr: ":9090",
expectedMethod: "/server/GetItem",
},
{
scenario: "method with ip and port",
method: "127.0.0.1:9090/server/GetItem",
expectedAddr: "127.0.0.1:9090",
expectedMethod: "/server/GetItem",
},
{
scenario: "method with hostname and port",
method: "localhost:9090/server/GetItem",
expectedAddr: "localhost:9090",
expectedMethod: "/server/GetItem",
},
}
Expand Down
20 changes: 10 additions & 10 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestInvokeUnary_MethodError(t *testing.T) {
t.Parallel()

err := grpcmock.InvokeUnary(context.Background(), "://", nil, nil)
expected := `coulld not parse method url: parse "://": missing protocol scheme`
expected := `coulld not parse method url: malformed method`

assert.EqualError(t, err, expected)
}
Expand All @@ -40,7 +40,7 @@ func TestInvokeUnary_DialError(t *testing.T) {
return nil, errors.New("dial error")
}

err := grpcmock.InvokeUnary(context.Background(), "NotFound", nil, nil,
err := grpcmock.InvokeUnary(context.Background(), "Service/NotFound", nil, nil,
grpcmock.WithContextDialer(dialer),
grpcmock.WithInsecure(),
)
Expand All @@ -52,7 +52,7 @@ func TestInvokeUnary_DialError(t *testing.T) {
func TestInvokeUnary_WithoutInsecure(t *testing.T) {
t.Parallel()

err := grpcmock.InvokeUnary(context.Background(), "NotFound", nil, nil)
err := grpcmock.InvokeUnary(context.Background(), "Service/NotFound", nil, nil)
expected := "grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)"

assert.EqualError(t, err, expected)
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestInvokeServerStream_DialError(t *testing.T) {
return nil, errors.New("dial error")
}

err := grpcmock.InvokeServerStream(context.Background(), "NotFound", nil, nil,
err := grpcmock.InvokeServerStream(context.Background(), "Service/NotFound", nil, nil,
grpcmock.WithContextDialer(dialer),
grpcmock.WithInsecure(),
)
Expand All @@ -143,7 +143,7 @@ func TestInvokeServerStream_DialError(t *testing.T) {
func TestInvokeServerStream_WithoutInsecure(t *testing.T) {
t.Parallel()

err := grpcmock.InvokeServerStream(context.Background(), "NotFound", nil, nil)
err := grpcmock.InvokeServerStream(context.Background(), "Service/NotFound", nil, nil)
expected := "grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)"

assert.EqualError(t, err, expected)
Expand All @@ -154,7 +154,7 @@ func TestInvokeServerStream_NoHandlerShouldBeFine(t *testing.T) {

dialer := testSrv.StartServer(t)

err := grpcmock.InvokeServerStream(context.Background(), "NotFound", nil, nil,
err := grpcmock.InvokeServerStream(context.Background(), "Service/NotFound", nil, nil,
grpcmock.WithContextDialer(dialer),
grpcmock.WithInsecure(),
)
Expand Down Expand Up @@ -257,7 +257,7 @@ func TestInvokeClientStream_DialError(t *testing.T) {
return nil, errors.New("dial error")
}

err := grpcmock.InvokeClientStream(context.Background(), "NotFound", nil, nil,
err := grpcmock.InvokeClientStream(context.Background(), "Service/NotFound", nil, nil,
grpcmock.WithContextDialer(dialer),
grpcmock.WithInsecure(),
)
Expand All @@ -269,7 +269,7 @@ func TestInvokeClientStream_DialError(t *testing.T) {
func TestInvokeClientStream_WithoutInsecure(t *testing.T) {
t.Parallel()

err := grpcmock.InvokeClientStream(context.Background(), "NotFound", nil, nil)
err := grpcmock.InvokeClientStream(context.Background(), "Service/NotFound", nil, nil)
expected := "grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)"

assert.EqualError(t, err, expected)
Expand Down Expand Up @@ -365,7 +365,7 @@ func TestInvokeBidirectionalStream_DialError(t *testing.T) {
return nil, errors.New("dial error")
}

err := grpcmock.InvokeBidirectionalStream(context.Background(), "NotFound", nil,
err := grpcmock.InvokeBidirectionalStream(context.Background(), "Service/NotFound", nil,
grpcmock.WithContextDialer(dialer),
grpcmock.WithInsecure(),
)
Expand All @@ -377,7 +377,7 @@ func TestInvokeBidirectionalStream_DialError(t *testing.T) {
func TestInvokeBidirectionalStream_WithoutInsecure(t *testing.T) {
t.Parallel()

err := grpcmock.InvokeBidirectionalStream(context.Background(), "NotFound", nil)
err := grpcmock.InvokeBidirectionalStream(context.Background(), "Service/NotFound", nil)
expected := "grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)"

assert.EqualError(t, err, expected)
Expand Down
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package grpcmock
const (
// ErrMissingMethod indicates that there is no method in the url.
ErrMissingMethod err = "missing method"
// ErrMalformedMethod indicates that the method is malformed.
ErrMalformedMethod err = "malformed method"
)

type err string
Expand Down

0 comments on commit 060a98c

Please sign in to comment.