From 761befc7715fdc88ec1e9d50c7b79d2f7a6cd457 Mon Sep 17 00:00:00 2001 From: Yilun Date: Wed, 5 Aug 2020 18:11:30 -0700 Subject: [PATCH] Add conn read/write to prevent concurrent write bad data Signed-off-by: Yilun --- client.go | 27 +++++++++++++++++---------- conn.go | 16 ++++++++++++++++ message.go | 10 ++++++++-- 3 files changed, 41 insertions(+), 12 deletions(-) create mode 100644 conn.go diff --git a/client.go b/client.go index 9fc1aa1..c5972d7 100644 --- a/client.go +++ b/client.go @@ -40,7 +40,7 @@ type TunaSessionClient struct { tunaExits []*tuna.TunaExit acceptAddrs []*regexp.Regexp sessions map[string]*ncp.Session - sessionConns map[string]map[string]net.Conn + sessionConns map[string]map[string]*Conn sharedKeys map[string]*[sharedKeySize]byte connCount map[string]int isClosed bool @@ -61,7 +61,7 @@ func NewTunaSessionClient(clientAccount *nkn.Account, m *nkn.MultiClient, wallet acceptSession: make(chan *ncp.Session, acceptSessionBufSize), onClose: make(chan struct{}, 0), sessions: make(map[string]*ncp.Session), - sessionConns: make(map[string]map[string]net.Conn), + sessionConns: make(map[string]map[string]*Conn), sharedKeys: make(map[string]*[sharedKeySize]byte), connCount: make(map[string]int), } @@ -238,14 +238,16 @@ func (c *TunaSessionClient) listenNKN() { func (c *TunaSessionClient) listenNet(i int) { for { - conn, err := c.listeners[i].Accept() + netConn, err := c.listeners[i].Accept() if err != nil { log.Printf("Accept connection error: %v", err) time.Sleep(time.Second) continue } - go func(conn net.Conn) { + conn := newConn(netConn) + + go func(conn *Conn) { defer conn.Close() buf, err := readMessage(conn) @@ -286,7 +288,7 @@ func (c *TunaSessionClient) listenNet(i int) { return } c.sessions[sessionKey] = sess - c.sessionConns[sessionKey] = make(map[string]net.Conn, c.config.NumTunaListeners) + c.sessionConns[sessionKey] = make(map[string]*Conn, c.config.NumTunaListeners) } c.sessionConns[sessionKey][connID(i)] = conn c.Unlock() @@ -420,33 +422,38 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig var lock sync.Mutex var wg sync.WaitGroup - conns := make(map[string]net.Conn, len(pubAddrs.Addrs)) + conns := make(map[string]*Conn, len(pubAddrs.Addrs)) dialer := &net.Dialer{} for i := range pubAddrs.Addrs { wg.Add(1) go func(i int) { defer wg.Done() - conn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", pubAddrs.Addrs[i].IP, pubAddrs.Addrs[i].Port)) + netConn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", pubAddrs.Addrs[i].IP, pubAddrs.Addrs[i].Port)) if err != nil { log.Printf("Dial error: %v", err) return } + conn := newConn(netConn) + err = writeMessage(conn, []byte(c.addr.String())) if err != nil { log.Printf("Write message error: %v", err) + conn.Close() return } buf, err := c.encode(sessionID, remoteAddr) if err != nil { log.Printf("Encode message error: %v", err) + conn.Close() return } err = writeMessage(conn, buf) if err != nil { log.Printf("Write message error: %v", err) + conn.Close() return } @@ -477,7 +484,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig for i := 0; i < len(pubAddrs.Addrs); i++ { if conn, ok := conns[connID(i)]; ok { - go func(conn net.Conn, i int) { + go func(conn *Conn, i int) { c.Lock() c.connCount[sessionKey]++ c.Unlock() @@ -610,7 +617,7 @@ func (c *TunaSessionClient) newSession(remoteAddr string, sessionID []byte, conn }), config) } -func (c *TunaSessionClient) handleMsg(conn net.Conn, sess *ncp.Session, i int) error { +func (c *TunaSessionClient) handleMsg(conn *Conn, sess *ncp.Session, i int) error { buf, err := readMessage(conn) if err != nil { return err @@ -629,7 +636,7 @@ func (c *TunaSessionClient) handleMsg(conn net.Conn, sess *ncp.Session, i int) e return nil } -func (c *TunaSessionClient) handleConn(conn net.Conn, sess *ncp.Session, i int) { +func (c *TunaSessionClient) handleConn(conn *Conn, sess *ncp.Session, i int) { for { err := c.handleMsg(conn, sess, i) if err != nil { diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..0ca395c --- /dev/null +++ b/conn.go @@ -0,0 +1,16 @@ +package session + +import ( + "net" + "sync" +) + +type Conn struct { + net.Conn + ReadLock sync.Mutex + WriteLock sync.Mutex +} + +func newConn(conn net.Conn) *Conn { + return &Conn{Conn: conn} +} diff --git a/message.go b/message.go index a425ca0..3b22622 100644 --- a/message.go +++ b/message.go @@ -82,7 +82,10 @@ func decrypt(message []byte, nonce [nonceSize]byte, sharedKey *[sharedKeySize]by return decrypted, nil } -func writeMessage(conn net.Conn, buf []byte) error { +func writeMessage(conn *Conn, buf []byte) error { + conn.WriteLock.Lock() + defer conn.WriteLock.Unlock() + msgSizeBuf := make([]byte, 4) binary.LittleEndian.PutUint32(msgSizeBuf, uint32(len(buf))) _, err := conn.Write(msgSizeBuf) @@ -108,7 +111,10 @@ func readFull(conn net.Conn, buf []byte) error { } } -func readMessage(conn net.Conn) ([]byte, error) { +func readMessage(conn *Conn) ([]byte, error) { + conn.ReadLock.Lock() + defer conn.ReadLock.Unlock() + msgSizeBuf := make([]byte, 4) err := readFull(conn, msgSizeBuf) if err != nil {