Skip to content

Commit

Permalink
Add conn read/write to prevent concurrent write bad data
Browse files Browse the repository at this point in the history
Signed-off-by: Yilun <[email protected]>
  • Loading branch information
yilunzhang committed Aug 6, 2020
1 parent 539cb8c commit 761befc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
27 changes: 17 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type TunaSessionClient struct {
tunaExits []*tuna.TunaExit
acceptAddrs []*regexp.Regexp
sessions map[string]*ncp.Session
sessionConns map[string]map[string]net.Conn
sessionConns map[string]map[string]*Conn
sharedKeys map[string]*[sharedKeySize]byte
connCount map[string]int
isClosed bool
Expand All @@ -61,7 +61,7 @@ func NewTunaSessionClient(clientAccount *nkn.Account, m *nkn.MultiClient, wallet
acceptSession: make(chan *ncp.Session, acceptSessionBufSize),
onClose: make(chan struct{}, 0),
sessions: make(map[string]*ncp.Session),
sessionConns: make(map[string]map[string]net.Conn),
sessionConns: make(map[string]map[string]*Conn),
sharedKeys: make(map[string]*[sharedKeySize]byte),
connCount: make(map[string]int),
}
Expand Down Expand Up @@ -238,14 +238,16 @@ func (c *TunaSessionClient) listenNKN() {

func (c *TunaSessionClient) listenNet(i int) {
for {
conn, err := c.listeners[i].Accept()
netConn, err := c.listeners[i].Accept()
if err != nil {
log.Printf("Accept connection error: %v", err)
time.Sleep(time.Second)
continue
}

go func(conn net.Conn) {
conn := newConn(netConn)

go func(conn *Conn) {
defer conn.Close()

buf, err := readMessage(conn)
Expand Down Expand Up @@ -286,7 +288,7 @@ func (c *TunaSessionClient) listenNet(i int) {
return
}
c.sessions[sessionKey] = sess
c.sessionConns[sessionKey] = make(map[string]net.Conn, c.config.NumTunaListeners)
c.sessionConns[sessionKey] = make(map[string]*Conn, c.config.NumTunaListeners)
}
c.sessionConns[sessionKey][connID(i)] = conn
c.Unlock()
Expand Down Expand Up @@ -420,33 +422,38 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig

var lock sync.Mutex
var wg sync.WaitGroup
conns := make(map[string]net.Conn, len(pubAddrs.Addrs))
conns := make(map[string]*Conn, len(pubAddrs.Addrs))
dialer := &net.Dialer{}
for i := range pubAddrs.Addrs {
wg.Add(1)
go func(i int) {
defer wg.Done()
conn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", pubAddrs.Addrs[i].IP, pubAddrs.Addrs[i].Port))
netConn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", pubAddrs.Addrs[i].IP, pubAddrs.Addrs[i].Port))
if err != nil {
log.Printf("Dial error: %v", err)
return
}

conn := newConn(netConn)

err = writeMessage(conn, []byte(c.addr.String()))
if err != nil {
log.Printf("Write message error: %v", err)
conn.Close()
return
}

buf, err := c.encode(sessionID, remoteAddr)
if err != nil {
log.Printf("Encode message error: %v", err)
conn.Close()
return
}

err = writeMessage(conn, buf)
if err != nil {
log.Printf("Write message error: %v", err)
conn.Close()
return
}

Expand Down Expand Up @@ -477,7 +484,7 @@ func (c *TunaSessionClient) DialWithConfig(remoteAddr string, config *DialConfig

for i := 0; i < len(pubAddrs.Addrs); i++ {
if conn, ok := conns[connID(i)]; ok {
go func(conn net.Conn, i int) {
go func(conn *Conn, i int) {
c.Lock()
c.connCount[sessionKey]++
c.Unlock()
Expand Down Expand Up @@ -610,7 +617,7 @@ func (c *TunaSessionClient) newSession(remoteAddr string, sessionID []byte, conn
}), config)
}

func (c *TunaSessionClient) handleMsg(conn net.Conn, sess *ncp.Session, i int) error {
func (c *TunaSessionClient) handleMsg(conn *Conn, sess *ncp.Session, i int) error {
buf, err := readMessage(conn)
if err != nil {
return err
Expand All @@ -629,7 +636,7 @@ func (c *TunaSessionClient) handleMsg(conn net.Conn, sess *ncp.Session, i int) e
return nil
}

func (c *TunaSessionClient) handleConn(conn net.Conn, sess *ncp.Session, i int) {
func (c *TunaSessionClient) handleConn(conn *Conn, sess *ncp.Session, i int) {
for {
err := c.handleMsg(conn, sess, i)
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package session

import (
"net"
"sync"
)

type Conn struct {
net.Conn
ReadLock sync.Mutex
WriteLock sync.Mutex
}

func newConn(conn net.Conn) *Conn {
return &Conn{Conn: conn}
}
10 changes: 8 additions & 2 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ func decrypt(message []byte, nonce [nonceSize]byte, sharedKey *[sharedKeySize]by
return decrypted, nil
}

func writeMessage(conn net.Conn, buf []byte) error {
func writeMessage(conn *Conn, buf []byte) error {
conn.WriteLock.Lock()
defer conn.WriteLock.Unlock()

msgSizeBuf := make([]byte, 4)
binary.LittleEndian.PutUint32(msgSizeBuf, uint32(len(buf)))
_, err := conn.Write(msgSizeBuf)
Expand All @@ -108,7 +111,10 @@ func readFull(conn net.Conn, buf []byte) error {
}
}

func readMessage(conn net.Conn) ([]byte, error) {
func readMessage(conn *Conn) ([]byte, error) {
conn.ReadLock.Lock()
defer conn.ReadLock.Unlock()

msgSizeBuf := make([]byte, 4)
err := readFull(conn, msgSizeBuf)
if err != nil {
Expand Down

0 comments on commit 761befc

Please sign in to comment.