diff --git a/grpcgcp/gcp_multiendpoint.go b/grpcgcp/gcp_multiendpoint.go index a268762..a2e6539 100644 --- a/grpcgcp/gcp_multiendpoint.go +++ b/grpcgcp/gcp_multiendpoint.go @@ -33,13 +33,6 @@ import ( pb "github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp/grpc_gcp" ) -var ( - // To be redefined in tests. - grpcDial = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - return grpc.Dial(target, opts...) - } -) - var gmeCounter uint32 type contextMEKey int @@ -124,13 +117,14 @@ type GCPMultiEndpoint struct { pools map[string]*monitoredConn opts []grpc.DialOption gcpConfig *pb.ApiConfig + dialFunc func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) log grpclog.LoggerV2 grpc.ClientConnInterface } // Make sure GcpMultiEndpoint implements grpc.ClientConnInterface. -var _ grpc.ClientConnInterface = &GCPMultiEndpoint{} +var _ grpc.ClientConnInterface = (*GCPMultiEndpoint)(nil) func (gme *GCPMultiEndpoint) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { return gme.pickConn(ctx).Invoke(ctx, method, args, reply, opts...) @@ -173,6 +167,8 @@ type GCPMultiEndpointOptions struct { MultiEndpoints map[string]*multiendpoint.MultiEndpointOptions // Name of the default MultiEndpoint. Default string + // Func to dial grpc ClientConn. + DialFunc func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) } // NewGcpMultiEndpoint creates new [GCPMultiEndpoint] -- MultiEndpoints-enabled gRPC client @@ -192,8 +188,14 @@ func NewGcpMultiEndpoint(meOpts *GCPMultiEndpointOptions, opts ...grpc.DialOptio defaultName: meOpts.Default, opts: o, gcpConfig: meOpts.GRPCgcpConfig, + dialFunc: meOpts.DialFunc, log: NewGCPLogger(compLogger, fmt.Sprintf("[GCPMultiEndpoint #%d]", atomic.AddUint32(&gmeCounter, 1))), } + if gme.dialFunc == nil { + gme.dialFunc = func(_ context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return grpc.Dial(target, opts...) + } + } if err := gme.UpdateMultiEndpoints(meOpts); err != nil { return nil, err } @@ -288,7 +290,7 @@ func (gme *GCPMultiEndpoint) UpdateMultiEndpoints(meOpts *GCPMultiEndpointOption for e := range validPools { if _, ok := gme.pools[e]; !ok { // This creates a ClientConn with the gRPC-GCP balancer managing connection pool. - conn, err := grpcDial(e, gme.opts...) + conn, err := gme.dialFunc(context.Background(), e, gme.opts...) if err != nil { return err } diff --git a/grpcgcp/test_grpc/gcp_multiendpoint_test.go b/grpcgcp/test_grpc/gcp_multiendpoint_test.go index e2208d1..b653b10 100644 --- a/grpcgcp/test_grpc/gcp_multiendpoint_test.go +++ b/grpcgcp/test_grpc/gcp_multiendpoint_test.go @@ -21,6 +21,7 @@ package test_grpc import ( "context" "net" + "sync/atomic" "testing" "time" @@ -701,3 +702,63 @@ func TestGcpMultiEndpointInstantShutdown(t *testing.T) { // Closing GcpMultiEndpoint immediately should not cause panic. conn.Close() } + +func TestGcpMultiEndpointDialFunc(t *testing.T) { + + lEndpoint, fEndpoint := "localhost:50051", "127.0.0.3:50051" + + defaultME, followerME := "default", "follower" + + apiCfg := &configpb.ApiConfig{ + ChannelPool: &configpb.ChannelPoolConfig{ + MinSize: 3, + MaxSize: 3, + }, + } + + dialUsedFor := make(map[string]*atomic.Int32) + dialUsedFor[lEndpoint] = &atomic.Int32{} + dialUsedFor[fEndpoint] = &atomic.Int32{} + + conn, err := grpcgcp.NewGcpMultiEndpoint( + &grpcgcp.GCPMultiEndpointOptions{ + GRPCgcpConfig: apiCfg, + MultiEndpoints: map[string]*multiendpoint.MultiEndpointOptions{ + defaultME: { + Endpoints: []string{lEndpoint, fEndpoint}, + }, + followerME: { + Endpoints: []string{fEndpoint, lEndpoint}, + }, + }, + Default: defaultME, + DialFunc: func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) { + dialUsedFor[target].Add(1) + return grpc.DialContext(ctx, target, dopts...) + }, + }, + grpc.WithInsecure(), + ) + + if err != nil { + t.Fatalf("NewMultiEndpointConn returns unexpected error: %v", err) + } + + defer conn.Close() + c := pb.NewGreeterClient(conn) + tc := &testingClient{ + c: c, + t: t, + } + + // Make a call to make sure GCPMultiEndpoint is up and running. + tc.SayHelloWorks(context.Background(), lEndpoint) + + if got, want := dialUsedFor[lEndpoint].Load(), int32(1); got != want { + t.Fatalf("provided dial function was called for %q endpoint %v times, want %v times", lEndpoint, got, want) + } + + if got, want := dialUsedFor[fEndpoint].Load(), int32(1); got != want { + t.Fatalf("provided dial function was called for %q endpoint %v times, want %v times", fEndpoint, got, want) + } +}