Skip to content

Commit

Permalink
Add write timeout and use nkn-sdk-go dial config type
Browse files Browse the repository at this point in the history
Signed-off-by: Yilun <[email protected]>
  • Loading branch information
yilunzhang committed Feb 25, 2022
1 parent 2fda026 commit d86138d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 37 deletions.
12 changes: 7 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ func (c *TunaSessionClient) DialSession(remoteAddr string) (*ncp.Session, error)
return c.DialWithConfig(remoteAddr, nil)
}

func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig) (*ncp.Session, error) {
config, err := MergeDialConfig(c.config.SessionConfig, config)
func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *nkn.DialConfig) (*ncp.Session, error) {
config, err := nkn.MergeDialConfig(c.config.SessionConfig, config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -534,7 +534,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig

conn := newConn(netConn)

err = writeMessage(conn, []byte(c.addr.String()))
err = writeMessage(conn, []byte(c.addr.String()), time.Duration(config.DialTimeout)*time.Millisecond)
if err != nil {
log.Printf("Write message error: %v", err)
conn.Close()
Expand All @@ -558,7 +558,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig
return
}

err = writeMessage(conn, buf)
err = writeMessage(conn, buf, time.Duration(config.DialTimeout)*time.Millisecond)
if err != nil {
log.Printf("Write message error: %v", err)
conn.Close()
Expand Down Expand Up @@ -700,9 +700,10 @@ func (c *TunaSessionClient) newSession(remoteAddr string, sessionID []byte, conn
if err != nil {
return err
}
err = writeMessage(conn, buf)
err = writeMessage(conn, buf, writeTimeout)
if err != nil {
log.Println("Write message error:", err)
conn.Close()
return ncp.ErrConnClosed
}
return nil
Expand Down Expand Up @@ -769,6 +770,7 @@ func (c *TunaSessionClient) handleConn(conn *Conn, sessKey string, i int) {
default:
}
log.Printf("handle msg error: %v", err)
return
}
}
}
Expand Down
28 changes: 0 additions & 28 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,6 @@ func DefaultSessionConfig() *ncp.Config {
return &sessionConf
}

type DialConfig struct {
DialTimeout int32 //in millisecond
SessionConfig *ncp.Config
}

var defaultDialConfig = DialConfig{
DialTimeout: 0,
SessionConfig: nil,
}

func DefaultDialConfig(baseSessionConfig *ncp.Config) *DialConfig {
dialConf := defaultDialConfig
sessionConfig := *baseSessionConfig
dialConf.SessionConfig = &sessionConfig
return &dialConf
}

func MergedConfig(conf *Config) (*Config, error) {
merged := DefaultConfig()
if conf != nil {
Expand All @@ -84,14 +67,3 @@ func MergedConfig(conf *Config) (*Config, error) {
}
return merged, nil
}

func MergeDialConfig(baseSessionConfig *ncp.Config, conf *DialConfig) (*DialConfig, error) {
merged := DefaultDialConfig(baseSessionConfig)
if conf != nil {
err := mergo.Merge(merged, conf, mergo.WithOverride)
if err != nil {
return nil, err
}
}
return merged, nil
}
4 changes: 2 additions & 2 deletions examples/throughput/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func main() {
log.Println("Seed:", hex.EncodeToString(account.Seed()))

clientConfig := &nkn.ClientConfig{ConnectRetries: 1}
dialConfig := &ts.DialConfig{DialTimeout: 5000}
dialConfig := &nkn.DialConfig{DialTimeout: 5000}
config := &ts.Config{
NumTunaListeners: *numTunaListeners,
TunaServiceName: *tunaServiceName,
Expand Down Expand Up @@ -195,7 +195,7 @@ func main() {
}

<-m.OnConnect.C
time.Sleep(time.Second)
time.Sleep(5 * time.Second)

c, err := ts.NewTunaSessionClient(account, m, wallet, config)
if err != nil {
Expand Down
18 changes: 16 additions & 2 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"time"

"github.com/nknorg/nkn/v2/crypto/ed25519"
"golang.org/x/crypto/nacl/box"
Expand Down Expand Up @@ -87,19 +88,32 @@ func decrypt(message []byte, nonce [nonceSize]byte, sharedKey *[sharedKeySize]by
return decrypted, nil
}

func writeMessage(conn *Conn, buf []byte) error {
func writeMessage(conn *Conn, buf []byte, writeTimeout time.Duration) error {
conn.WriteLock.Lock()
defer conn.WriteLock.Unlock()

msgSizeBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(msgSizeBuf, uint32(len(buf)))

if writeTimeout > 0 {
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
}

_, err := conn.Write(msgSizeBuf)
if err != nil {
return err
}

_, err = conn.Write(buf)
return err
if err != nil {
return err
}

if writeTimeout > 0 {
conn.SetWriteDeadline(zeroTime)
}

return nil
}

func readFull(conn net.Conn, buf []byte) error {
Expand Down

0 comments on commit d86138d

Please sign in to comment.