diff --git a/agent/agent.go b/agent/agent.go index e65cc384..5a3c1d6e 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -26,11 +26,13 @@ import ( e "errors" "fmt" "net" + "os" "strings" "sync" "sync/atomic" "time" + "github.com/topfreegames/pitaya/v2/config" "github.com/topfreegames/pitaya/v2/conn/codec" "github.com/topfreegames/pitaya/v2/conn/message" "github.com/topfreegames/pitaya/v2/conn/packet" @@ -76,6 +78,7 @@ type ( decoder codec.PacketDecoder // binary decoder encoder codec.PacketEncoder // binary encoder heartbeatTimeout time.Duration + writeTimeout time.Duration lastAt int64 // last heartbeat unix time stamp messageEncoder message.Encoder messagesBufferSize int // size of the pending messages buffer @@ -130,6 +133,7 @@ type ( decoder codec.PacketDecoder // binary decoder encoder codec.PacketEncoder // binary encoder heartbeatTimeout time.Duration + writeTimeout time.Duration messageEncoder message.Encoder messagesBufferSize int // size of the pending messages buffer metricsReporters []metrics.Reporter @@ -144,6 +148,7 @@ func NewAgentFactory( encoder codec.PacketEncoder, serializer serialize.Serializer, heartbeatTimeout time.Duration, + writeTimeout time.Duration, messageEncoder message.Encoder, messagesBufferSize int, sessionPool session.SessionPool, @@ -154,6 +159,7 @@ func NewAgentFactory( decoder: decoder, encoder: encoder, heartbeatTimeout: heartbeatTimeout, + writeTimeout: writeTimeout, messageEncoder: messageEncoder, messagesBufferSize: messagesBufferSize, sessionPool: sessionPool, @@ -164,7 +170,7 @@ func NewAgentFactory( // CreateAgent returns a new agent func (f *agentFactoryImpl) CreateAgent(conn net.Conn) Agent { - return newAgent(conn, f.decoder, f.encoder, f.serializer, f.heartbeatTimeout, f.messagesBufferSize, f.appDieChan, f.messageEncoder, f.metricsReporters, f.sessionPool) + return newAgent(conn, f.decoder, f.encoder, f.serializer, f.heartbeatTimeout, f.writeTimeout, f.messagesBufferSize, f.appDieChan, f.messageEncoder, f.metricsReporters, f.sessionPool) } // NewAgent create new agent instance @@ -174,6 +180,7 @@ func newAgent( packetEncoder codec.PacketEncoder, serializer serialize.Serializer, heartbeatTime time.Duration, + writeTimeout time.Duration, messagesBufferSize int, dieChan chan bool, messageEncoder message.Encoder, @@ -188,6 +195,10 @@ func newAgent( herdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName) }) + if writeTimeout <= 0 { + writeTimeout = config.DefaultWriteTimeout + } + a := &agentImpl{ appDieChan: dieChan, chDie: make(chan struct{}), @@ -199,6 +210,7 @@ func newAgent( decoder: packetDecoder, encoder: packetEncoder, heartbeatTimeout: heartbeatTime, + writeTimeout: writeTimeout, lastAt: time.Now().Unix(), serializer: serializer, state: constants.StatusStart, @@ -503,18 +515,20 @@ func (a *agentImpl) write() { ctx, err, data := pWrite.ctx, pWrite.err, pWrite.data writeErr := a.writeToConnection(ctx, data) - if writeErr != nil { - err = errors.NewError(writeErr, errors.ErrClosedRequest) - - logger.Log.Errorf("Failed to write in conn: %s (ctx=%v), agent will close", writeErr.Error(), ctx) - } - tracing.FinishSpan(ctx, nil) - metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, err) - // close agent if low-level conn broke during write if writeErr != nil { - return + if e.Is(writeErr, os.ErrDeadlineExceeded) { + // Log the timeout error but continue processing + logger.Log.Warnf("Context deadline exceeded for write in conn %s: %s (ctx=%v)", writeErr.Error(), ctx) + metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, err) + } else { + err = errors.NewError(writeErr, errors.ErrClosedRequest) + logger.Log.Errorf("Failed to write in conn: %s (ctx=%v), agent will close", writeErr.Error(), ctx) + metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, err) + // close agent if low-level conn broke during write + return + } } case <-a.chStopWrite: return @@ -524,17 +538,15 @@ func (a *agentImpl) write() { func (a *agentImpl) writeToConnection(ctx context.Context, data []byte) error { span := createConnectionSpan(ctx, a.conn, "conn write") - - _, writeErr := a.conn.Write(data) - defer span.Finish() + a.conn.SetWriteDeadline(time.Now().Add(a.writeTimeout)) + _, writeErr := a.conn.Write(data) if writeErr != nil { tracing.LogError(span, writeErr.Error()) return writeErr } - - return nil + return writeErr } func createConnectionSpan(ctx context.Context, conn net.Conn, op string) opentracing.Span { @@ -549,7 +561,7 @@ func createConnectionSpan(ctx context.Context, conn net.Conn, op string) opentra tags := opentracing.Tags{ "span.kind": "connection", - "addr": remoteAddress, + "addr": remoteAddress, } var parent opentracing.SpanContext diff --git a/agent/agent_test.go b/agent/agent_test.go index 27f8ded0..fc0a4b27 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -81,6 +81,7 @@ func TestNewAgent(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second mockConn := mocks.NewMockPlayerConn(ctrl) @@ -99,7 +100,7 @@ func TestNewAgent(t *testing.T) { sessionPool := session.NewSessionPool() mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) assert.IsType(t, make(chan struct{}), ag.chDie) assert.IsType(t, make(chan pendingWrite), ag.chSend) @@ -110,7 +111,7 @@ func TestNewAgent(t *testing.T) { assert.Equal(t, mockConn, ag.conn) assert.Equal(t, mockDecoder, ag.decoder) assert.Equal(t, mockEncoder, ag.encoder) - assert.Equal(t, hbTime, ag.heartbeatTimeout) + assert.Equal(t, hbTime, writeTimeout, ag.heartbeatTimeout) assert.InDelta(t, time.Now().Unix(), ag.lastAt, 1) assert.Equal(t, mockSerializer, ag.serializer) assert.Equal(t, mockMetricsReporters, ag.metricsReporters) @@ -120,7 +121,7 @@ func TestNewAgent(t *testing.T) { // second call should no call hdb encode mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) - ag = newAgent(nil, nil, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag = newAgent(nil, nil, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) } @@ -133,6 +134,7 @@ func TestKick(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Nil()).Do( @@ -145,7 +147,7 @@ func TestKick(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, nil, sessionPool) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, nil, sessionPool) c := context.Background() err := ag.Kick(c) assert.NoError(t, err) @@ -171,12 +173,13 @@ func TestAgentSend(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockConn := mocks.NewMockPlayerConn(ctrl) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) if table.err != nil { @@ -263,6 +266,7 @@ func TestAgentSendSerializeErr(t *testing.T) { var wg sync.WaitGroup wg.Add(1) mockConn.EXPECT().RemoteAddr().Times(2).Return(&mockAddr{}) + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) mockConn.EXPECT().Write(expectedPacket).Do(func(b []byte) { wg.Done() }) @@ -285,7 +289,7 @@ func TestAgentPushFailsIfClosedAgent(t *testing.T) { messageEncoder := message.NewMessagesEncoder(false) sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 10, nil, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 10, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.state = constants.StatusClosed err := ag.Push("", nil) @@ -312,6 +316,7 @@ func TestAgentPushStruct(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockConn := mocks.NewMockPlayerConn(ctrl) @@ -319,7 +324,7 @@ func TestAgentPushStruct(t *testing.T) { mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) expectedBytes := []byte("hello") @@ -371,6 +376,7 @@ func TestAgentPush(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockConn := mocks.NewMockPlayerConn(ctrl) @@ -378,7 +384,7 @@ func TestAgentPush(t *testing.T) { mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) expectedBytes := []byte("hello") @@ -418,6 +424,7 @@ func TestAgentPushFullChannel(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockConn := mocks.NewMockPlayerConn(ctrl) @@ -425,7 +432,7 @@ func TestAgentPushFullChannel(t *testing.T) { mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 0, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 0, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockMetricsReporter.EXPECT().ReportGauge(metrics.ChannelCapacity, gomock.Any(), float64(0)) @@ -461,7 +468,7 @@ func TestAgentResponseMIDFailsIfClosedAgent(t *testing.T) { mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 10, nil, mockMessageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 10, nil, mockMessageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.state = constants.StatusClosed @@ -498,6 +505,7 @@ func TestAgentResponseMID(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockConn := mocks.NewMockPlayerConn(ctrl) @@ -505,7 +513,7 @@ func TestAgentResponseMID(t *testing.T) { mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) ctx := getCtxWithRequestKeys() @@ -553,6 +561,7 @@ func TestAgentResponseMIDFullChannel(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockConn := mocks.NewMockPlayerConn(ctrl) @@ -561,7 +570,7 @@ func TestAgentResponseMIDFullChannel(t *testing.T) { mockSerializer.EXPECT().GetName() mockEncoder.EXPECT().Encode(packet.Type(packet.Data), gomock.Any()) sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 0, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 0, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockMetricsReporters[0].(*metricsmocks.MockReporter).EXPECT().ReportGauge(metrics.ChannelCapacity, gomock.Any(), float64(0)) go func() { @@ -582,7 +591,7 @@ func TestAgentCloseFailsIfAlreadyClosed(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 10, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 10, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.state = constants.StatusClosed err := ag.Close() @@ -601,7 +610,7 @@ func TestAgentClose(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) expected := false @@ -649,7 +658,7 @@ func TestAgentRemoteAddr(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool) assert.NotNil(t, ag) expected := &mockAddr{} @@ -670,7 +679,7 @@ func TestAgentString(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockConn.EXPECT().RemoteAddr().Return(&mockAddr{}) @@ -702,7 +711,7 @@ func TestAgentGetStatus(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.state = table.status @@ -724,7 +733,7 @@ func TestAgentSetLastAt(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.lastAt = 0 @@ -753,7 +762,7 @@ func TestAgentSetStatus(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.SetStatus(table.status) @@ -771,7 +780,7 @@ func TestOnSessionClosed(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) ss := sessionPool.NewSession(nil, true) @@ -793,7 +802,7 @@ func TestOnSessionClosedRecoversIfPanic(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) ss := sessionPool.NewSession(nil, true) @@ -831,7 +840,7 @@ func TestAgentSendHandshakeResponse(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, mockMessageEncoder, nil, sessionPool) assert.NotNil(t, ag) mockConn.EXPECT().Write(hrd).Return(0, table.err) @@ -884,7 +893,7 @@ func TestAnswerWithError(t *testing.T) { messageEncoder := message.NewMessagesEncoder(false) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockSerializer.EXPECT().Marshal(gomock.Any()).Return(nil, row.getPayloadErr).AnyTimes() @@ -967,7 +976,7 @@ func TestAgentAnswerWithError(t *testing.T) { messageEncoder := message.NewMessagesEncoder(false) sessionPool := session.NewSessionPool() - ag := newAgent(nil, nil, encoder, row.serializer, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(nil, nil, encoder, row.serializer, time.Second, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.AnswerWithError(nil, uint(rand.Int()), row.answeredErr) @@ -989,7 +998,7 @@ func TestAgentHeartbeat(t *testing.T) { mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, time.Second, 1, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockConn.EXPECT().RemoteAddr().MaxTimes(1) @@ -1024,7 +1033,7 @@ func TestAgentHeartbeatExitsIfConnError(t *testing.T) { mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, time.Second, 1, nil, mockMessageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) mockConn.EXPECT().RemoteAddr().MaxTimes(1) @@ -1064,7 +1073,7 @@ func TestAgentHeartbeatExitsOnStopHeartbeat(t *testing.T) { mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) go func() { @@ -1103,6 +1112,7 @@ func TestAgentWriteChSend(t *testing.T) { var wg sync.WaitGroup wg.Add(1) mockConn.EXPECT().RemoteAddr().Times(2).Return(&mockAddr{}) + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) mockConn.EXPECT().Write(expectedPacket).Do(func(b []byte) { time.Sleep(10 * time.Millisecond) wg.Done() @@ -1123,7 +1133,7 @@ func TestAgentHandle(t *testing.T) { messageEncoder := message.NewMessagesEncoder(false) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) expectedBytes := []byte("bla") @@ -1142,6 +1152,7 @@ func TestAgentHandle(t *testing.T) { } }() + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(3) mockConn.EXPECT().Write(expectedBytes).Return(0, nil).Do(func(d []byte) { wg.Done() }) @@ -1168,6 +1179,7 @@ func TestNatsRPCServerReportMetrics(t *testing.T) { mockDecoder := codecmocks.NewMockPacketDecoder(ctrl) dieChan := make(chan bool) hbTime := time.Second + writeTimeout := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} @@ -1175,7 +1187,7 @@ func TestNatsRPCServerReportMetrics(t *testing.T) { mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) mockSerializer.EXPECT().GetName() sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, writeTimeout, 10, dieChan, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) assert.NotNil(t, ag) ag.messagesBufferSize = 0 @@ -1233,7 +1245,7 @@ func TestAgentWriteChSendWriteError(t *testing.T) { mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} sessionPool := session.NewSessionPool() - ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, 0, nil, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) + ag := newAgent(mockConn, nil, mockEncoder, mockSerializer, time.Second, time.Second, 0, nil, messageEncoder, mockMetricsReporters, sessionPool).(*agentImpl) ctx := getCtxWithRequestKeys() @@ -1257,6 +1269,7 @@ func TestAgentWriteChSendWriteError(t *testing.T) { mockConn.EXPECT().Close().Do(func() { wg.Done() }) + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) mockConn.EXPECT().Write(expectedPacket).Do(func(b []byte) { wg.Done() }).Return(0, writeError) diff --git a/builder.go b/builder.go index dc58f5f3..a989f46a 100644 --- a/builder.go +++ b/builder.go @@ -222,6 +222,7 @@ func (builder *Builder) Build() Pitaya { builder.PacketEncoder, builder.Serializer, builder.Config.Heartbeat.Interval, + builder.Config.Buffer.Agent.WriteTimeout, builder.MessageEncoder, builder.Config.Buffer.Agent.Messages, builder.SessionPool, diff --git a/config/config.go b/config/config.go index af1772d0..6fb66aeb 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,8 @@ import ( "github.com/topfreegames/pitaya/v2/metrics/models" ) +const DefaultWriteTimeout = 10 * time.Second + // PitayaConfig provides all the configuration for a pitaya app type PitayaConfig struct { SerializerType uint16 `mapstructure:"serializertype"` @@ -25,7 +27,8 @@ type PitayaConfig struct { } `mapstructure:"handler"` Buffer struct { Agent struct { - Messages int `mapstructure:"messages"` + Messages int `mapstructure:"messages"` + WriteTimeout time.Duration `mapstructure:"conntimeout"` } `mapstructure:"agent"` Handler struct { LocalProcess int `mapstructure:"localprocess"` @@ -91,7 +94,8 @@ func NewDefaultPitayaConfig() *PitayaConfig { }, Buffer: struct { Agent struct { - Messages int `mapstructure:"messages"` + Messages int `mapstructure:"messages"` + WriteTimeout time.Duration `mapstructure:"conntimeout"` } `mapstructure:"agent"` Handler struct { LocalProcess int `mapstructure:"localprocess"` @@ -99,9 +103,11 @@ func NewDefaultPitayaConfig() *PitayaConfig { } `mapstructure:"handler"` }{ Agent: struct { - Messages int `mapstructure:"messages"` + Messages int `mapstructure:"messages"` + WriteTimeout time.Duration `mapstructure:"conntimeout"` }{ - Messages: 100, + Messages: 100, + WriteTimeout: DefaultWriteTimeout, }, Handler: struct { LocalProcess int `mapstructure:"localprocess"` diff --git a/docs/configuration.rst b/docs/configuration.rst index 6ff57a52..dc72d3ee 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -223,6 +223,10 @@ Connection - 30s - time.Time - Keepalive heartbeat interval for the client connection + * - pitaya.buffer.agent.writetimeout + - 10s + - time.Duration + - Timeout for agent to send packets * - pitaya.conn.ratelimiting.interval - 1s - time.Duration