Skip to content

Commit

Permalink
Merge pull request #66 from GoogleCloudPlatform/dialfunction
Browse files Browse the repository at this point in the history
feat: allow providing custom dial function to GCPMultiEndpoint
  • Loading branch information
nimf authored Jan 22, 2024
2 parents 2271df0 + 043a235 commit aba71ce
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 9 deletions.
20 changes: 11 additions & 9 deletions grpcgcp/gcp_multiendpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ import (
pb "github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp/grpc_gcp"
)

var (
// To be redefined in tests.
grpcDial = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return grpc.Dial(target, opts...)
}
)

var gmeCounter uint32

type contextMEKey int
Expand Down Expand Up @@ -124,13 +117,14 @@ type GCPMultiEndpoint struct {
pools map[string]*monitoredConn
opts []grpc.DialOption
gcpConfig *pb.ApiConfig
dialFunc func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error)
log grpclog.LoggerV2

grpc.ClientConnInterface
}

// Make sure GcpMultiEndpoint implements grpc.ClientConnInterface.
var _ grpc.ClientConnInterface = &GCPMultiEndpoint{}
var _ grpc.ClientConnInterface = (*GCPMultiEndpoint)(nil)

func (gme *GCPMultiEndpoint) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
return gme.pickConn(ctx).Invoke(ctx, method, args, reply, opts...)
Expand Down Expand Up @@ -173,6 +167,8 @@ type GCPMultiEndpointOptions struct {
MultiEndpoints map[string]*multiendpoint.MultiEndpointOptions
// Name of the default MultiEndpoint.
Default string
// Func to dial grpc ClientConn.
DialFunc func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error)
}

// NewGcpMultiEndpoint creates new [GCPMultiEndpoint] -- MultiEndpoints-enabled gRPC client
Expand All @@ -192,8 +188,14 @@ func NewGcpMultiEndpoint(meOpts *GCPMultiEndpointOptions, opts ...grpc.DialOptio
defaultName: meOpts.Default,
opts: o,
gcpConfig: meOpts.GRPCgcpConfig,
dialFunc: meOpts.DialFunc,
log: NewGCPLogger(compLogger, fmt.Sprintf("[GCPMultiEndpoint #%d]", atomic.AddUint32(&gmeCounter, 1))),
}
if gme.dialFunc == nil {
gme.dialFunc = func(_ context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return grpc.Dial(target, opts...)
}
}
if err := gme.UpdateMultiEndpoints(meOpts); err != nil {
return nil, err
}
Expand Down Expand Up @@ -288,7 +290,7 @@ func (gme *GCPMultiEndpoint) UpdateMultiEndpoints(meOpts *GCPMultiEndpointOption
for e := range validPools {
if _, ok := gme.pools[e]; !ok {
// This creates a ClientConn with the gRPC-GCP balancer managing connection pool.
conn, err := grpcDial(e, gme.opts...)
conn, err := gme.dialFunc(context.Background(), e, gme.opts...)
if err != nil {
return err
}
Expand Down
61 changes: 61 additions & 0 deletions grpcgcp/test_grpc/gcp_multiendpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package test_grpc
import (
"context"
"net"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -701,3 +702,63 @@ func TestGcpMultiEndpointInstantShutdown(t *testing.T) {
// Closing GcpMultiEndpoint immediately should not cause panic.
conn.Close()
}

func TestGcpMultiEndpointDialFunc(t *testing.T) {

lEndpoint, fEndpoint := "localhost:50051", "127.0.0.3:50051"

defaultME, followerME := "default", "follower"

apiCfg := &configpb.ApiConfig{
ChannelPool: &configpb.ChannelPoolConfig{
MinSize: 3,
MaxSize: 3,
},
}

dialUsedFor := make(map[string]*atomic.Int32)
dialUsedFor[lEndpoint] = &atomic.Int32{}
dialUsedFor[fEndpoint] = &atomic.Int32{}

conn, err := grpcgcp.NewGcpMultiEndpoint(
&grpcgcp.GCPMultiEndpointOptions{
GRPCgcpConfig: apiCfg,
MultiEndpoints: map[string]*multiendpoint.MultiEndpointOptions{
defaultME: {
Endpoints: []string{lEndpoint, fEndpoint},
},
followerME: {
Endpoints: []string{fEndpoint, lEndpoint},
},
},
Default: defaultME,
DialFunc: func(ctx context.Context, target string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
dialUsedFor[target].Add(1)
return grpc.DialContext(ctx, target, dopts...)
},
},
grpc.WithInsecure(),
)

if err != nil {
t.Fatalf("NewMultiEndpointConn returns unexpected error: %v", err)
}

defer conn.Close()
c := pb.NewGreeterClient(conn)
tc := &testingClient{
c: c,
t: t,
}

// Make a call to make sure GCPMultiEndpoint is up and running.
tc.SayHelloWorks(context.Background(), lEndpoint)

if got, want := dialUsedFor[lEndpoint].Load(), int32(1); got != want {
t.Fatalf("provided dial function was called for %q endpoint %v times, want %v times", lEndpoint, got, want)
}

if got, want := dialUsedFor[fEndpoint].Load(), int32(1); got != want {
t.Fatalf("provided dial function was called for %q endpoint %v times, want %v times", fEndpoint, got, want)
}
}

0 comments on commit aba71ce

Please sign in to comment.