Skip to content

Commit

Permalink
Merge pull request #483 from fosterseth/cherrypick_release_1.1_into_d…
Browse files Browse the repository at this point in the history
…evel

Cherrypick release 1.1 into devel
  • Loading branch information
shanemcd authored Nov 16, 2021
2 parents 6d59160 + 955bcaf commit bfea897
Show file tree
Hide file tree
Showing 20 changed files with 292 additions and 272 deletions.
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,15 @@ else
TAGPARAM=--tags $(TAGS)
endif

DEBUG ?=
ifeq ($(DEBUG),1)
DEBUGFLAGS=-gcflags=all="-N -l"
else
DEBUGFLAGS=
endif

receptor: $(shell find pkg -type f -name '*.go') ./cmd/receptor-cl/receptor.go
CGO_ENABLED=0 go build -o receptor -ldflags "-X 'github.com/ansible/receptor/internal/version.Version=$(APPVER)'" $(TAGPARAM) ./cmd/receptor-cl
CGO_ENABLED=0 go build -o receptor $(DEBUGFLAGS) -ldflags "-X 'github.com/ansible/receptor/internal/version.Version=$(APPVER)'" $(TAGPARAM) ./cmd/receptor-cl

lint:
@golint cmd/... pkg/... example/...
Expand Down
3 changes: 2 additions & 1 deletion pkg/controlsvc/connect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controlsvc

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -73,7 +74,7 @@ func (t *connectCommandType) InitFromJSON(config map[string]interface{}) (Contro
return c, nil
}

func (c *connectCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
func (c *connectCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
tlscfg, err := nc.GetClientTLSConfig(c.tlsConfigName, c.targetNode, "receptor")
if err != nil {
return nil, err
Expand Down
25 changes: 16 additions & 9 deletions pkg/controlsvc/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ func (s *Server) AddControlFunc(name string, cType ControlCommandType) error {

// RunControlSession runs the server protocol on the given connection.
func (s *Server) RunControlSession(conn net.Conn) {
logger.Info("Client connected to control service\n")
logger.Info("Client connected to control service %s\n", conn.RemoteAddr().String())
defer func() {
logger.Info("Client disconnected from control service\n")
err := conn.Close()
if err != nil {
logger.Error("Error closing connection: %s\n", err)
logger.Info("Client disconnected from control service %s\n", conn.RemoteAddr().String())
if conn != nil {
err := conn.Close()
if err != nil {
logger.Error("Error closing connection: %s\n", err)
}
}
}()
_, err := conn.Write([]byte(fmt.Sprintf("Receptor Control, node %s\n", s.nc.NodeID())))
Expand Down Expand Up @@ -224,7 +226,9 @@ func (s *Server) RunControlSession(conn net.Conn) {
cc, err = ct.InitFromJSON(jsonData)
}
if err == nil {
cfr, err = cc.ControlFunc(s.nc, cfo)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cfr, err = cc.ControlFunc(ctx, s.nc, cfo)
}
if err != nil {
logger.Error(err.Error())
Expand Down Expand Up @@ -331,34 +335,37 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.
if ctx.Err() != nil {
return
}
if err != nil {
if strings.HasSuffix(err.Error(), "normal close") {
continue
}
}
if err != nil {
logger.Error("Error accepting connection: %s. Closing listener.\n", err)
_ = listener.Close()

return
}
go func() {
defer conn.Close()
tlsConn, ok := conn.(*tls.Conn)
if ok {
// Explicitly run server TLS handshake so we can deal with timeout and errors here
err = conn.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
logger.Error("Error setting timeout: %s. Closing socket.\n", err)
_ = conn.Close()

return
}
err = tlsConn.Handshake()
if err != nil {
logger.Error("TLS handshake error: %s. Closing socket.\n", err)
_ = conn.Close()

return
}
err = conn.SetDeadline(time.Time{})
if err != nil {
logger.Error("Error clearing timeout: %s. Closing socket.\n", err)
_ = conn.Close()

return
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/controlsvc/interfaces.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controlsvc

import (
"context"
"io"
"net"

Expand All @@ -15,7 +16,7 @@ type ControlCommandType interface {

// ControlCommand is an instance of a command that is being run from the control service.
type ControlCommand interface {
ControlFunc(*netceptor.Netceptor, ControlFuncOperations) (map[string]interface{}, error)
ControlFunc(context.Context, *netceptor.Netceptor, ControlFuncOperations) (map[string]interface{}, error)
}

// ControlFuncOperations provides callbacks for control services to take actions.
Expand Down
11 changes: 5 additions & 6 deletions pkg/controlsvc/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,16 @@ func ping(nc *netceptor.Netceptor, target string, hopsToLive byte) (time.Duratio
_ = pc.Close()
}()
pc.SetHopsToLive(hopsToLive)
unrCh := pc.SubscribeUnreachable()
doneChan := make(chan struct{})
unrCh := pc.SubscribeUnreachable(doneChan)
defer close(doneChan)
type errorResult struct {
err error
fromNode string
}
errorChan := make(chan errorResult)
go func() {
select {
case <-ctx.Done():
return
case msg := <-unrCh:
for msg := range unrCh {
errorChan <- errorResult{
err: fmt.Errorf(msg.Problem),
fromNode: msg.ReceivedFromNode,
Expand Down Expand Up @@ -111,7 +110,7 @@ func ping(nc *netceptor.Netceptor, target string, hopsToLive byte) (time.Duratio
}
}

func (c *pingCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
func (c *pingCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
pingTime, pingRemote, err := ping(nc, c.target, nc.MaxForwardingHops())
cfr := make(map[string]interface{})
if err == nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/controlsvc/reload.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controlsvc

import (
"context"
"fmt"
"io/ioutil"
"strings"
Expand Down Expand Up @@ -157,7 +158,7 @@ func handleError(err error, errorcode int) (map[string]interface{}, error) {
return cfr, nil
}

func (c *reloadCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
func (c *reloadCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
// Reload command stops all backends, and re-runs the ParseAndRun() on the
// initial config file
logger.Debug("Reloading")
Expand Down
3 changes: 2 additions & 1 deletion pkg/controlsvc/status.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controlsvc

import (
"context"
"fmt"

"github.com/ansible/receptor/internal/version"
Expand Down Expand Up @@ -46,7 +47,7 @@ func (t *statusCommandType) InitFromJSON(config map[string]interface{}) (Control
return c, nil
}

func (c *statusCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
func (c *statusCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
status := nc.Status()
statusGetters := make(map[string]func() interface{})
statusGetters["Version"] = func() interface{} { return version.Version }
Expand Down
3 changes: 2 additions & 1 deletion pkg/controlsvc/traceroute.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controlsvc

import (
"context"
"fmt"
"strconv"

Expand Down Expand Up @@ -41,7 +42,7 @@ func (t *tracerouteCommandType) InitFromJSON(config map[string]interface{}) (Con
return c, nil
}

func (c *tracerouteCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
func (c *tracerouteCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) {
cfr := make(map[string]interface{})
for i := 0; i <= int(nc.MaxForwardingHops()); i++ {
thisResult := make(map[string]interface{})
Expand Down
40 changes: 25 additions & 15 deletions pkg/netceptor/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"sync"
"time"

"github.com/ansible/receptor/pkg/utils"
"github.com/ansible/receptor/pkg/logger"
"github.com/lucas-clemente/quic-go"
)

Expand Down Expand Up @@ -197,7 +197,7 @@ func (li *Listener) acceptLoop() {
return
}
doneChan := make(chan struct{}, 1)
cctx, ccancel := utils.ContextWithCancelWithErr(li.s.context)
cctx, ccancel := context.WithCancel(li.s.context)
conn := &Conn{
s: li.s,
pc: li.pc,
Expand Down Expand Up @@ -296,7 +296,7 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string
_ = pc.Close()
})
}
cctx, ccancel := utils.ContextWithCancelWithErr(ctx)
cctx, ccancel := context.WithCancel(ctx)
go func() {
select {
case <-okChan:
Expand Down Expand Up @@ -370,18 +370,18 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string

// monitorUnreachable receives unreachable messages from the underlying PacketConn, and ends the connection
// if the remote service has gone away.
func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, cancel utils.CancelWithErrFunc) {
msgCh := pc.SubscribeUnreachable()
for {
select {
case <-pc.context.Done():
return
case <-doneChan:
return
case msg := <-msgCh:
if msg.Problem == ProblemServiceUnknown && msg.ToNode == remoteAddr.node && msg.ToService == remoteAddr.service {
cancel(fmt.Errorf("remote service unreachable"))
}
func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, cancel context.CancelFunc) {
msgCh := pc.SubscribeUnreachable(doneChan)
if msgCh == nil {
cancel()

return
}
// read from channel until closed
for msg := range msgCh {
if msg.Problem == ProblemServiceUnknown && msg.ToNode == remoteAddr.node && msg.ToService == remoteAddr.service {
logger.Error("remote service unreachable")
cancel()
}
}
}
Expand Down Expand Up @@ -410,6 +410,16 @@ func (c *Conn) Close() error {
return c.qs.Close()
}

func (c *Conn) CloseConnection() error {
c.pc.cancel()
c.doneOnce.Do(func() {
close(c.doneChan)
})
logger.Debug("closing connection from service %s to %s", c.pc.localService, c.RemoteAddr().String())

return c.qc.CloseWithError(0, "normal close")
}

// LocalAddr returns the local address of this connection.
func (c *Conn) LocalAddr() net.Addr {
return c.qc.LocalAddr()
Expand Down
17 changes: 13 additions & 4 deletions pkg/netceptor/netceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,13 @@ func (s *Netceptor) sendInitialConnectMessage(ci *connInfo, initDoneChan chan bo
return
}
logger.Debug("Sending initial connection message\n")
ci.WriteChan <- ri
select {
case ci.WriteChan <- ri:
case <-ci.Context.Done():
return
case <-initDoneChan:
return
}
count++
if count > 10 {
logger.Warning("Giving up on connection initialization\n")
Expand All @@ -1643,15 +1649,18 @@ func (s *Netceptor) sendInitialConnectMessage(ci *connInfo, initDoneChan chan bo
}
}

func (s *Netceptor) sendRejectMessage(writeChan chan []byte) {
func (s *Netceptor) sendRejectMessage(ci *connInfo) {
rejMsg, err := s.translateStructToNetwork(MsgTypeReject, make([]string, 0))
if err != nil {
writeChan <- rejMsg
select {
case <-ci.Context.Done():
case ci.WriteChan <- rejMsg:
}
}
}

func (s *Netceptor) sendAndLogConnectionRejection(remoteNodeID string, ci *connInfo, reason string) error {
s.sendRejectMessage(ci.WriteChan)
s.sendRejectMessage(ci)

return fmt.Errorf("rejected connection with node %s because %s", remoteNodeID, reason)
}
Expand Down
34 changes: 18 additions & 16 deletions pkg/netceptor/netceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,18 +556,19 @@ func TestFirewalling(t *testing.T) {
}

// Subscribe for unreachable messages
unreach2chan := pc2.SubscribeUnreachable()
doneChan := make(chan struct{})
unreach2chan := pc2.SubscribeUnreachable(doneChan)

// Save received unreachable messages to a variable
var lastUnreachMsg *UnreachableNotification
go func() {
for {
select {
case <-timeout.Done():
return
case unreach := <-unreach2chan:
lastUnreachMsg = &unreach
}
<-timeout.Done()
close(doneChan)
}()
go func() {
for unreach := range unreach2chan {
unreach := unreach
lastUnreachMsg = &unreach
}
}()

Expand Down Expand Up @@ -715,18 +716,19 @@ func TestAllowedPeers(t *testing.T) {
}

// Subscribe for unreachable messages
unreach2chan := pc2.SubscribeUnreachable()
doneChan := make(chan struct{})
unreach2chan := pc2.SubscribeUnreachable(doneChan)

// Save received unreachable messages to a variable
var lastUnreachMsg *UnreachableNotification
go func() {
for {
select {
case <-timeout.Done():
return
case unreach := <-unreach2chan:
lastUnreachMsg = &unreach
}
<-timeout.Done()
close(doneChan)
}()
go func() {
for unreach := range unreach2chan {
unreach := unreach
lastUnreachMsg = &unreach
}
}()

Expand Down
Loading

0 comments on commit bfea897

Please sign in to comment.