From d86138d7f3aa1c6608b87a496698ecaebd9b1dc9 Mon Sep 17 00:00:00 2001 From: Yilun Date: Fri, 25 Feb 2022 01:10:55 -0800 Subject: [PATCH] Add write timeout and use nkn-sdk-go dial config type Signed-off-by: Yilun --- client.go | 12 +++++++----- config.go | 28 ---------------------------- examples/throughput/main.go | 4 ++-- message.go | 18 ++++++++++++++++-- 4 files changed, 25 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index 0b61cab..71439c6 100644 --- a/client.go +++ b/client.go @@ -477,8 +477,8 @@ func (c *TunaSessionClient) DialSession(remoteAddr string) (*ncp.Session, error) return c.DialWithConfig(remoteAddr, nil) } -func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig) (*ncp.Session, error) { - config, err := MergeDialConfig(c.config.SessionConfig, config) +func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *nkn.DialConfig) (*ncp.Session, error) { + config, err := nkn.MergeDialConfig(c.config.SessionConfig, config) if err != nil { return nil, err } @@ -534,7 +534,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig conn := newConn(netConn) - err = writeMessage(conn, []byte(c.addr.String())) + err = writeMessage(conn, []byte(c.addr.String()), time.Duration(config.DialTimeout)*time.Millisecond) if err != nil { log.Printf("Write message error: %v", err) conn.Close() @@ -558,7 +558,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig return } - err = writeMessage(conn, buf) + err = writeMessage(conn, buf, time.Duration(config.DialTimeout)*time.Millisecond) if err != nil { log.Printf("Write message error: %v", err) conn.Close() @@ -700,9 +700,10 @@ func (c *TunaSessionClient) newSession(remoteAddr string, sessionID []byte, conn if err != nil { return err } - err = writeMessage(conn, buf) + err = writeMessage(conn, buf, writeTimeout) if err != nil { log.Println("Write message error:", err) + conn.Close() return ncp.ErrConnClosed } return nil @@ -769,6 +770,7 @@ func (c *TunaSessionClient) handleConn(conn *Conn, sessKey string, i int) { default: } log.Printf("handle msg error: %v", err) + return } } } diff --git a/config.go b/config.go index 742bbcd..82e7d0f 100644 --- a/config.go +++ b/config.go @@ -57,23 +57,6 @@ func DefaultSessionConfig() *ncp.Config { return &sessionConf } -type DialConfig struct { - DialTimeout int32 //in millisecond - SessionConfig *ncp.Config -} - -var defaultDialConfig = DialConfig{ - DialTimeout: 0, - SessionConfig: nil, -} - -func DefaultDialConfig(baseSessionConfig *ncp.Config) *DialConfig { - dialConf := defaultDialConfig - sessionConfig := *baseSessionConfig - dialConf.SessionConfig = &sessionConfig - return &dialConf -} - func MergedConfig(conf *Config) (*Config, error) { merged := DefaultConfig() if conf != nil { @@ -84,14 +67,3 @@ func MergedConfig(conf *Config) (*Config, error) { } return merged, nil } - -func MergeDialConfig(baseSessionConfig *ncp.Config, conf *DialConfig) (*DialConfig, error) { - merged := DefaultDialConfig(baseSessionConfig) - if conf != nil { - err := mergo.Merge(merged, conf, mergo.WithOverride) - if err != nil { - return nil, err - } - } - return merged, nil -} diff --git a/examples/throughput/main.go b/examples/throughput/main.go index a30ca3c..5752f00 100644 --- a/examples/throughput/main.go +++ b/examples/throughput/main.go @@ -138,7 +138,7 @@ func main() { log.Println("Seed:", hex.EncodeToString(account.Seed())) clientConfig := &nkn.ClientConfig{ConnectRetries: 1} - dialConfig := &ts.DialConfig{DialTimeout: 5000} + dialConfig := &nkn.DialConfig{DialTimeout: 5000} config := &ts.Config{ NumTunaListeners: *numTunaListeners, TunaServiceName: *tunaServiceName, @@ -195,7 +195,7 @@ func main() { } <-m.OnConnect.C - time.Sleep(time.Second) + time.Sleep(5 * time.Second) c, err := ts.NewTunaSessionClient(account, m, wallet, config) if err != nil { diff --git a/message.go b/message.go index b182214..e1effed 100644 --- a/message.go +++ b/message.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/nknorg/nkn/v2/crypto/ed25519" "golang.org/x/crypto/nacl/box" @@ -87,19 +88,32 @@ func decrypt(message []byte, nonce [nonceSize]byte, sharedKey *[sharedKeySize]by return decrypted, nil } -func writeMessage(conn *Conn, buf []byte) error { +func writeMessage(conn *Conn, buf []byte, writeTimeout time.Duration) error { conn.WriteLock.Lock() defer conn.WriteLock.Unlock() msgSizeBuf := make([]byte, 4) binary.LittleEndian.PutUint32(msgSizeBuf, uint32(len(buf))) + + if writeTimeout > 0 { + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + } + _, err := conn.Write(msgSizeBuf) if err != nil { return err } _, err = conn.Write(buf) - return err + if err != nil { + return err + } + + if writeTimeout > 0 { + conn.SetWriteDeadline(zeroTime) + } + + return nil } func readFull(conn net.Conn, buf []byte) error {