diff --git a/cmd/discovery/main_test.go b/cmd/discovery/main_test.go index 1cf83609..6397df64 100644 --- a/cmd/discovery/main_test.go +++ b/cmd/discovery/main_test.go @@ -7,6 +7,7 @@ package main_test import ( + "context" "errors" "fmt" "io/ioutil" @@ -58,7 +59,7 @@ var _ = Describe("Main", func() { }) Context("all required parameters are specified", func() { AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) It("succeeds", func() { @@ -77,7 +78,7 @@ var _ = Describe("Main", func() { Context("one of the required parameters is missing", func() { Context("when no frontendURL is defined", func() { AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) It("returns an error", func() { diff --git a/cmd/ephemeral/main_test.go b/cmd/ephemeral/main_test.go index 6030723b..402473b1 100644 --- a/cmd/ephemeral/main_test.go +++ b/cmd/ephemeral/main_test.go @@ -7,6 +7,7 @@ package main_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -43,7 +44,7 @@ var _ = Describe("Main", func() { path = fmt.Sprintf("/tmp/test-%d", random) }) AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) Context("when it succeeds", func() { diff --git a/pkg/ephemeral/fake_spdz_test.go b/pkg/ephemeral/fake_spdz_test.go index 83144640..6763ae4c 100644 --- a/pkg/ephemeral/fake_spdz_test.go +++ b/pkg/ephemeral/fake_spdz_test.go @@ -7,6 +7,7 @@ package ephemeral import ( + "context" "errors" "github.com/carbynestack/ephemeral/pkg/discovery/fsm" pb "github.com/carbynestack/ephemeral/pkg/discovery/transport/proto" @@ -93,14 +94,14 @@ func (f *FakePlayer) PublishEvent(name, topic string, event *pb.Event) { type FakeExecutor struct { } -func (f *FakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (f *FakeExecutor) CallCMD(theContext context.Context, cmd []string, dir string) ([]byte, []byte, error) { return []byte{}, []byte{}, nil } type BrokenFakeExecutor struct { } -func (f *BrokenFakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (f *BrokenFakeExecutor) CallCMD(theContext context.Context, cmd []string, dir string) ([]byte, []byte, error) { return []byte{}, []byte{}, errors.New("some error") } diff --git a/pkg/ephemeral/io/carrier.go b/pkg/ephemeral/io/carrier.go index 7c8ef18d..d6ea0bdd 100644 --- a/pkg/ephemeral/io/carrier.go +++ b/pkg/ephemeral/io/carrier.go @@ -8,8 +8,11 @@ package io import ( "context" + "encoding/binary" "errors" + "fmt" "github.com/carbynestack/ephemeral/pkg/amphora" + "io" "io/ioutil" "net" ) @@ -21,7 +24,7 @@ type Result struct { // AbstractCarrier is the carriers interface. type AbstractCarrier interface { - Connect(context.Context, string, string) error + Connect(int32, context.Context, string, string) error Close() error Send([]amphora.SecretShare) error Read(ResponseConverter, bool) (*Result, error) @@ -29,10 +32,11 @@ type AbstractCarrier interface { // Carrier is a TCP client for TCP sockets. type Carrier struct { - Dialer func(ctx context.Context, addr, port string) (net.Conn, error) - Conn net.Conn - Packer Packer - connected bool + Dialer func(ctx context.Context, addr, port string) (net.Conn, error) + TlsConnector func(conn net.Conn, playerID int32) (net.Conn, error) + Conn net.Conn + Packer Packer + connected bool } // Config contains TCP connection properties of Carrier. @@ -42,16 +46,55 @@ type Config struct { } // Connect establishes a TCP connection to a socket on a given host and port. -func (c *Carrier) Connect(ctx context.Context, host, port string) error { +func (c *Carrier) Connect(playerID int32, ctx context.Context, host string, port string) error { conn, err := c.Dialer(ctx, host, port) if err != nil { return err } - c.Conn = conn + + _, err = conn.Write(c.buildHeader(playerID)) + if err != nil { + return err + } + + c.Conn, err = c.TlsConnector(conn, playerID) + if err != nil { + return err + } + + if playerID == 0 { + err = c.readSpec() + if err != nil { + return err + } + } + c.connected = true return nil } +func (c Carrier) readSpec() error { + const size = 4 + + readBytes := make([]byte, size) + _, err := io.LimitReader(c.Conn, size).Read(readBytes) + if err != nil { + return err + } + + sizeOfHeader := binary.LittleEndian.Uint32(readBytes) + + readBytes = make([]byte, sizeOfHeader) + _, err = io.LimitReader(c.Conn, int64(sizeOfHeader)).Read(readBytes) + if err != nil { + return err + } + + //ToDo, compare read PRIME with prime number from config? + + return nil +} + // Close closes the underlying TCP connection. func (c *Carrier) Close() error { if c.connected { @@ -68,16 +111,31 @@ func (c *Carrier) Send(secret []amphora.SecretShare) error { shares = append(shares, secret[i].Data) } err := c.Packer.Marshal(shares, &input) + if err != nil { return err } _, err = c.Conn.Write(input) + if err != nil { return err } return nil } +// Returns a new Slice with the header appended +// The header consists of the clientId as string: +// - 1 Long (4 Byte) that contains the length of the string in bytes +// - Then come X Bytes for the String +func (c *Carrier) buildHeader(playerId int32) []byte { + playerIdString := []byte(fmt.Sprintf("%d", playerId)) + + lengthOfString := make([]byte, 4) + binary.LittleEndian.PutUint32(lengthOfString, uint32(len(playerIdString))) + + return append(lengthOfString, playerIdString...) +} + // Read reads the response from the TCP connection and unmarshals it. func (c *Carrier) Read(conv ResponseConverter, bulkObjects bool) (*Result, error) { resp := []byte{} diff --git a/pkg/ephemeral/io/carrier_test.go b/pkg/ephemeral/io/carrier_test.go index 5ed31b52..53f26aff 100644 --- a/pkg/ephemeral/io/carrier_test.go +++ b/pkg/ephemeral/io/carrier_test.go @@ -9,17 +9,18 @@ package io_test import ( "context" "fmt" - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/carbynestack/ephemeral/pkg/amphora" . "github.com/carbynestack/ephemeral/pkg/ephemeral/io" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "net" + "sync" ) var _ = Describe("Carrier", func() { var ctx = context.TODO() + var playerId = int32(1) // PlayerID 1, since PlayerID==0 contains another check when connecting + It("connects to a socket", func() { var connected bool conn := FakeNetConnection{} @@ -27,10 +28,14 @@ var _ = Describe("Carrier", func() { connected = true return &conn, nil } + fakeTlsConnector := func(connection net.Conn, playerID int32) (net.Conn, error) { + return connection, nil + } carrier := Carrier{ - Dialer: fakeDialer, + Dialer: fakeDialer, + TlsConnector: fakeTlsConnector, } - err := carrier.Connect(context.TODO(), "", "") + err := carrier.Connect(playerId, context.TODO(), "", "") Expect(connected).To(BeTrue()) Expect(err).NotTo(HaveOccurred()) }) @@ -39,10 +44,14 @@ var _ = Describe("Carrier", func() { fakeDialer := func(ctx context.Context, addr, port string) (net.Conn, error) { return &conn, nil } + fakeTlsConnector := func(connection net.Conn, playerID int32) (net.Conn, error) { + return connection, nil + } carrier := Carrier{ - Dialer: fakeDialer, + Dialer: fakeDialer, + TlsConnector: fakeTlsConnector, } - err := carrier.Connect(context.TODO(), "", "") + err := carrier.Connect(playerId, context.TODO(), "", "") Expect(err).NotTo(HaveOccurred()) err = carrier.Close() Expect(err).NotTo(HaveOccurred()) @@ -50,20 +59,26 @@ var _ = Describe("Carrier", func() { }) var ( - secret []amphora.SecretShare - output []byte - client, server net.Conn - dialer func(ctx context.Context, addr, port string) (net.Conn, error) + secret []amphora.SecretShare + output []byte + connectionOutput []byte //Will contain (length 4 byte, playerId 1 byte) + client, server net.Conn + dialer func(ctx context.Context, addr, port string) (net.Conn, error) + fakeTlsConnector func(conn net.Conn, playerID int32) (net.Conn, error) ) BeforeEach(func() { secret = []amphora.SecretShare{ amphora.SecretShare{}, } output = make([]byte, 1) + connectionOutput = make([]byte, 5) client, server = net.Pipe() dialer = func(ctx context.Context, addr, port string) (net.Conn, error) { return client, nil } + fakeTlsConnector = func(connection net.Conn, playerID int32) (net.Conn, error) { + return connection, nil + } }) Context("when sending secret shares through the carrier", func() { It("sends an amphora secret to the socket", func() { @@ -72,23 +87,28 @@ var _ = Describe("Carrier", func() { MarshalResponse: serverResponse, } carrier := Carrier{ - Dialer: dialer, - Packer: packer, + Dialer: dialer, + Packer: packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") go server.Read(output) err := carrier.Send(secret) carrier.Close() Expect(err).NotTo(HaveOccurred()) Expect(output[0]).To(Equal(byte(1))) + Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerId)[0]})) }) It("returns an error when it fails to marshal the object", func() { packer := &FakeBrokenPacker{} carrier := Carrier{ - Dialer: dialer, - Packer: packer, + Dialer: dialer, + Packer: packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") go server.Read(output) err := carrier.Send(secret) carrier.Close() @@ -100,10 +120,12 @@ var _ = Describe("Carrier", func() { MarshalResponse: serverResponse, } carrier := Carrier{ - Dialer: dialer, - Packer: packer, + Dialer: dialer, + Packer: packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") // Closing the connection to trigger a failure due to writing into the closed socket. server.Close() err := carrier.Send(secret) @@ -120,10 +142,12 @@ var _ = Describe("Carrier", func() { UnmarshalResponse: []string{packerResponse}, } carrier := Carrier{ - Dialer: dialer, - Packer: &packer, + Dialer: dialer, + Packer: &packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") go func() { server.Write(serverResponse) server.Close() @@ -140,10 +164,12 @@ var _ = Describe("Carrier", func() { UnmarshalResponse: []string{packerResponse}, } carrier := Carrier{ - Dialer: dialer, - Packer: &packer, + Dialer: dialer, + Packer: &packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") server.Close() anyConverter := &PlaintextConverter{} _, err := carrier.Read(anyConverter, false) @@ -153,10 +179,12 @@ var _ = Describe("Carrier", func() { serverResponse := []byte{byte(1)} packer := &FakeBrokenPacker{} carrier := Carrier{ - Dialer: dialer, - Packer: packer, + Dialer: dialer, + Packer: packer, + TlsConnector: fakeTlsConnector, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(playerId, ctx, "", "") go func() { server.Write(serverResponse) server.Close() @@ -166,4 +194,44 @@ var _ = Describe("Carrier", func() { Expect(err).To(HaveOccurred()) }) }) + + Context("when connecting as Player0", func() { + playerId := int32(0) + It("will receive and handle the server's fileHeader", func() { + // Arrange + // ToDo: Better Response for real-life scenario? + serverResponse := []byte{1, 0, 0, 0, 1} // 4 byte length + header, in this case "1". In real case Descriptor + Prime + packer := &FakeBrokenPacker{} + carrier := Carrier{ + Dialer: dialer, + Packer: packer, + TlsConnector: fakeTlsConnector, + } + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + + go server.Read(connectionOutput) + + // Act + var errConnecting error + go func() { + errConnecting = carrier.Connect(playerId, ctx, "", "") + waitGroup.Done() + }() + + numberOfBytesWritten, errWrite := server.Write(serverResponse) + errClose := server.Close() + + // Make sure we wait until the Connect and Write are done + waitGroup.Wait() + + // Assert + Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerId)[0]})) + Expect(errConnecting).NotTo(HaveOccurred()) + Expect(errWrite).NotTo(HaveOccurred()) + Expect(numberOfBytesWritten).To(Equal(len(serverResponse))) + Expect(errClose).NotTo(HaveOccurred()) + }) + }) }) diff --git a/pkg/ephemeral/io/feeder.go b/pkg/ephemeral/io/feeder.go index d9082416..a3a2efbc 100644 --- a/pkg/ephemeral/io/feeder.go +++ b/pkg/ephemeral/io/feeder.go @@ -35,6 +35,7 @@ func NewAmphoraFeeder(l *zap.SugaredLogger, conf *SPDZEngineTypedConfig) *Amphor Packer: &SPDZPacker{ MaxBulkSize: conf.MaxBulkSize, }, + TlsConnector: network.NewTlsConnector(), } return &AmphoraFeeder{ logger: l, @@ -118,7 +119,7 @@ func (f *AmphoraFeeder) feedAndRead(params []string, port string, ctx *CtxConfig default: return nil, fmt.Errorf("no output config is given, either %s, %s or %s must be defined", PlainText, SecretShare, AmphoraSecret) } - err := f.carrier.Connect(ctx.Context, "localhost", port) + err := f.carrier.Connect(ctx.Spdz.PlayerID, ctx.Context, "localhost", port) defer f.carrier.Close() if err != nil { return nil, err diff --git a/pkg/ephemeral/io/feeder_test.go b/pkg/ephemeral/io/feeder_test.go index e40e7865..83325b6e 100644 --- a/pkg/ephemeral/io/feeder_test.go +++ b/pkg/ephemeral/io/feeder_test.go @@ -211,7 +211,7 @@ type FakeCarrier struct { isBulk bool } -func (f *FakeCarrier) Connect(context.Context, string, string) error { +func (f *FakeCarrier) Connect(int32, context.Context, string, string) error { return nil } @@ -232,7 +232,7 @@ type BrokenConnectFakeCarrier struct { isBulk bool } -func (f *BrokenConnectFakeCarrier) Connect(context.Context, string, string) error { +func (f *BrokenConnectFakeCarrier) Connect(int32, context.Context, string, string) error { return errors.New("carrier connect error") } @@ -253,7 +253,7 @@ type BrokenSendFakeCarrier struct { isBulk bool } -func (f *BrokenSendFakeCarrier) Connect(context.Context, string, string) error { +func (f *BrokenSendFakeCarrier) Connect(int32, context.Context, string, string) error { return nil } diff --git a/pkg/ephemeral/network/tls_connector.go b/pkg/ephemeral/network/tls_connector.go new file mode 100644 index 00000000..7c8cd7a8 --- /dev/null +++ b/pkg/ephemeral/network/tls_connector.go @@ -0,0 +1,46 @@ +package network + +import ( + "crypto/tls" + "fmt" + "net" +) + +func NewTlsConnector() func(conn net.Conn, playerID int32) (net.Conn, error) { + return NewTlsConnectorWithPath("Player-Data") +} + +func NewTlsConnectorWithPath(folderPath string) func(conn net.Conn, playerID int32) (net.Conn, error) { + + return func(conn net.Conn, playerID int32) (net.Conn, error) { + tlsConfig, err := getTlsConfig(playerID, folderPath) + if err != nil { + return nil, err + } + + tlsClient := tls.Client(conn, tlsConfig) + err = tlsClient.Handshake() + if err != nil { + return nil, err + } + + return net.Conn(tlsClient), nil + } +} + +func getTlsConfig(playerID int32, folder string) (*tls.Config, error) { + + certFile := fmt.Sprintf("%s/C%d.pem", folder, playerID) + keyFile := fmt.Sprintf("%s/C%d.key", folder, playerID) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + } + + return tlsConfig, nil +} diff --git a/pkg/ephemeral/network/tls_connector_test.go b/pkg/ephemeral/network/tls_connector_test.go new file mode 100644 index 00000000..bddc0cc0 --- /dev/null +++ b/pkg/ephemeral/network/tls_connector_test.go @@ -0,0 +1,232 @@ +package network + +import ( + "crypto/tls" + "fmt" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "io/ioutil" + "net" + "os" +) + +const ( + keyFileClient = `-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCvx2eeVDXG5R+l +GlslnYNHlJgmmkLeXn5MT18qTbq3MCpB6o4rd8I2a1D/uFUht13Ourj7zilKz/5W +jcTnoVG7fiCLcBj3tXvCL5ymOGxxmQeN5siJcefpB8kcSB4RkrON9y6HCpZSIOMv +vfSVMrVMrQj/rjqsO2/Vv1A+4nETJm3GqKfwSikhgNsVcqiHYGkg0d1/3zP8CTAQ ++lp92LijeJAMCyNyHm/A+Wya3g8heRbm6lPtZWUcPOfyn3FGQ+Pu9MbBrQcPbXPW +0sjtGoBweNLYYyns3yViSp7gyOnZWaAwnQtA1T7PPNGkOYp5ehI3gA4bhhCWxbkN +ZVy0qajhAgMBAAECggEAJdsJ4706/6SklggBDS7I8Qd9ZQLf18f95y1Iz3GB/qWu +1BdRmublupaOESR/oQ0+dKEd6YzSs7vriHRrrX6+fWSCWcVAe0hoaL+cOuf34tcU +G2lSUtdnHHaCx0Z4w0wWw0IykP6ktPdENinwnJkZFnRFddrt493BDgVvoLtfosHO +Q+CcX6SmjfS3i0GSsDbI1sBAtH9vP+cCJeXWYtVcPRX9zoX3oYY9zBxuuiarcZku +3mcx22WFi4t30o2jCFwshhjY3W5mxZ3icCZ/mO/BS8FOYk4+BJUQtlxhDvJSjg/u +jCmmFi6WwtceKEhSL6IyiRFLzec60ITlR9U9YB/UqQKBgQDl6sMr/++hzQvOv58c +zoOfBKejHao7Bx9MkFLtQ4KXf4Ypc2uZh/XenziBb+tKRJ5mSXV8NLHs/zrdxPeY +ps0AYkWl9xVR1hKYlnQ75DCbs6zkIEKbKZ1xq5X1TfAmIyHmUcttD5BvQLAeQyG3 ++iNo2yFUgg6BywS4E6biL40zkwKBgQDDuF3FW2K5Ms5ntw/o/d55scinx05C74D6 +Oy+HesRs6bg77R07fr9Xqgnawqpn2Jk9TRFL5yVJTEHcXH9xMzHgNQ128SGNnDtC +T5/jfalj92hjdmt/gwdGK6PN+IDgb3h3vMnQZszK4zhXP78nte1QGUx2W7TZ7ZrP +C+iulm3iOwKBgQDbkkQqNRYpM6VfIWlXHXJd3xgpkx8LmFWvzPUlWh/RhxwdYfkU +et+4Z96S3suZ9cZAcU8d+0UgzO7u9DhxNHr7Lt7NDRbzPLottyHyQI6bZBBtHNH/ +VNLjx7ZCutfp1At/5gWcdgy98s0/WWVOSjie3wcJqdso4TX0hfAOetMiuQKBgDri +C+wla1U2kNypObMqNbW9JBY+IzCGJ/KgvdLvv4rY4iG9W68bmeuA78gOCwCFLM1B +k3OXjiM4OxRWC819zoKa03s2XpbhKv7vP7ZMhxrZQ2GxLfRF8nlNBdIg8n0TbFXx +yXHWi8R6iefN+O+0jzoq8lMlkgqCrrGd7pogDd0jAoGBALK43xm6ZIx5f6Ko94Vk +quXurZhmfbwiU52hBOdej6T+w2axs+mne83/HpcnWNtsmQDPN7vsfnKH/Ny4dG87 +G0iQcIEfW6OCGn1N6mr9ch7+2ihszOlKomOBxLurzw3Y7b3z0k9i1+NXeVY9agwF +U5QapxH75EeTq2YKGRjcN100 +-----END PRIVATE KEY----- +` + pemFileClient = `-----BEGIN CERTIFICATE----- +MIIC+zCCAeOgAwIBAgIUWOrYZliAZd4NDKJBNkYsOqSCj5owDQYJKoZIhvcNAQEL +BQAwDTELMAkGA1UEAwwCQzAwHhcNMjIwMTI2MTI0NjQ5WhcNMjIwMjI1MTI0NjQ5 +WjANMQswCQYDVQQDDAJDMDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AK/HZ55UNcblH6UaWyWdg0eUmCaaQt5efkxPXypNurcwKkHqjit3wjZrUP+4VSG3 +Xc66uPvOKUrP/laNxOehUbt+IItwGPe1e8IvnKY4bHGZB43myIlx5+kHyRxIHhGS +s433LocKllIg4y+99JUytUytCP+uOqw7b9W/UD7icRMmbcaop/BKKSGA2xVyqIdg +aSDR3X/fM/wJMBD6Wn3YuKN4kAwLI3Ieb8D5bJreDyF5FubqU+1lZRw85/KfcUZD +4+70xsGtBw9tc9bSyO0agHB40thjKezfJWJKnuDI6dlZoDCdC0DVPs880aQ5inl6 +EjeADhuGEJbFuQ1lXLSpqOECAwEAAaNTMFEwHQYDVR0OBBYEFCXac7qi2TG+j/CQ +fVyvM6W3JfONMB8GA1UdIwQYMBaAFCXac7qi2TG+j/CQfVyvM6W3JfONMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAE0xk3rMO3xmpq1mwWGGQ/B2 +J9Xlqf5qwr63MNz6aIcKrlyk2+OLLaDm8RrF7wFNQ+uvMKKg6bLF7jW7MAX9WMO7 +giiT5ySjxddDT0cbSA3HcG3Ria9P6c02VZVt057M1FzXweR/FiJA1Tocn43lXrBT +n2sAiRtO4sxbfhUdIJI1Vh7UUhyAJLe3lVcG/AMMmPG/IedguhMbdalm5/gEaIIc +LjHyQLPWzHQTiUvj+AjpTmCN+3ZbBS/8r4g7XJ7/zvawXxi1Lk9fvSGWGkQLwHJ0 +DupEw8GWmc9H0cyY93qtEqKLQPvEDDdvhPoENcf/P6/BD1Z8lMmSMvZ+s6M7VfQ= +-----END CERTIFICATE----- +` + + keyFileServer = `-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCwdCqmEjiMVBPK +m31IG+I+xwgX+EnEpnrBnlOa0WhFzaMrwqXpijgMA+dLNR8a1zpyhglWBsqm8dpN +7tEV19piizOmxZtZee7h1Hdso/+4U106NqzX5HKwuqZVSOjVN29SFKq0sNricIX1 +HabE5LYyBQJtzMzxAZwclb+e7uGBfHJDsOk3hOhs3bkJyV3eRa0uHH2Bu4CPH6L9 +bcisFCmHiykZeZaY+BpRkS0c5+h7umLrKSGUe37/vf9UY9niLDUNolHePS/iQnmb +Hv1l/mDl2LvNy5OSCSvOE6L0GMCUDnYmYf6F999LLdQgC7gcZCp3rujZ4MYUsqR2 +Nqp46LdbAgMBAAECggEBAJ6ViM8AiTn1RmRNImdwSAHLtwZz6ziFtsXUmacGlQRH +MGLf6WTfCEgkKfd5op7o2Gqc9D8Qk4k+y8hG3jsXZ/owyRcVee0MnRjxbvOA4Q60 +PZFYGjdd5YXX+i2j/T3DOJU4ZcNHPzFLl9kX8Q37z5Nc1TYBXh8sJzW5kCIy5xEL +XAKNcwGTZF1ml3jkWkFl3LukS3DP8fF1qDvD987YGuc9oVliYW1F0oKL9VGyS7nB +BtQWslFdP8MbPXG1hjkFydCBiE4teqrFen6hvLdIQk7XJ88Q9UmBoOPJr6+gHuDf +vk33nVGpBVQ1UHFPnDzZyQKtlDBVEUJ8XhqzEkm4asECgYEA5sTgxJt/nJCL9Lh1 +61jFbVD21SVFEv7IWIV6YjBxzJhzGVJa6ZhrOnRTrkTAkraJ1wUd9FyIdEsL/Nvy +/z8hOAXbty1zXdpOo/BV0J6zRwJ0Cj8WVTeCUr5KGgw/pzbQdltJ+1J8jnAGbZjN +Ri/QUdryqZTQz3rD8sDVDFvLojMCgYEAw78HQ+y/gL5Z/IJww0lUYHjqcm1G5taY +3Ht6qRvkqdCmW8qC2wpKFKl9lCJfo+H1jidjhM5RTPFSlCxiWtxLAamMvfv0f3d6 +q5gPjcjak275bnmU1e0blkLEdeXQljRXH+oDmur95udzh0DrdTDJ/lqbf3uui8Uc +VApAcSbR/jkCgYAWUT/zg55Jw+jlF9m/kuw08DmOz3Xoql8xwGbfjBPVV4D6F+7W +3HiyRIG7Psbo6WJXOxV0hmZj6MYWBCdx6+cIhfiDtI+Nqgkk7Z8+97oaye/y9brx +LtcZrXF5J2oYf8KVT6rN9WI6XDci7j4b5Y/d+rCxGcU/6317wo5YDaCZ5QKBgET3 +4qRxHwxKhUQt5XM5PAx9rgVBMXEV/Wf57b71v/yBMow27yIkHvPmwANYlSAV9kHu +6OabFxQoFvN0K/ddlOPyDE/IHV5oB4W8HwbS1QiLWkEtf15cm5K21afAoFy79lKd +TkXgNDOOKytlmVCCLzl6TT1+o4JFofSOZCQ6DFUpAoGAAQdaCjX5UCeWb5e/Vbiu +SQL1RKIHkgm6gj1UjlQ981r6y+hVkBygtIr/eW0wSkFAkUrOdefHNVOQW18ESF06 +YqBL4gD7aEij9kGd0PrievimgcYYaBHOcO1RouQOURTMmWqjIPu1fyWDv+rFk+S5 +2uCuYndpzOgCiEhjDGCuSug= +-----END PRIVATE KEY-----` + + pemFileServer = `-----BEGIN CERTIFICATE----- +MIIC+zCCAeOgAwIBAgIUNI9WRun2Y+ICmpzjYRpVcJ/BBE4wDQYJKoZIhvcNAQEL +BQAwDTELMAkGA1UEAwwCUDAwHhcNMjIwMTI2MTI0NjQ5WhcNMjIwMjI1MTI0NjQ5 +WjANMQswCQYDVQQDDAJQMDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALB0KqYSOIxUE8qbfUgb4j7HCBf4ScSmesGeU5rRaEXNoyvCpemKOAwD50s1HxrX +OnKGCVYGyqbx2k3u0RXX2mKLM6bFm1l57uHUd2yj/7hTXTo2rNfkcrC6plVI6NU3 +b1IUqrSw2uJwhfUdpsTktjIFAm3MzPEBnByVv57u4YF8ckOw6TeE6GzduQnJXd5F +rS4cfYG7gI8fov1tyKwUKYeLKRl5lpj4GlGRLRzn6Hu6YuspIZR7fv+9/1Rj2eIs +NQ2iUd49L+JCeZse/WX+YOXYu83Lk5IJK84TovQYwJQOdiZh/oX330st1CALuBxk +Kneu6NngxhSypHY2qnjot1sCAwEAAaNTMFEwHQYDVR0OBBYEFDjtm5a7RbAFeYuQ +QfFYci+eTOeXMB8GA1UdIwQYMBaAFDjtm5a7RbAFeYuQQfFYci+eTOeXMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGo1n03gEYMsBLLaOcY7dDwn +behhLE7UP3eWRw2gpmbKfilk+dljYWsOdiQeXktE/LxyFiuBNwefI7JrypFifzio +udqYyQAJ2pvMogij+TPajaDhJxmMWqRizcAo/6cXekSCufnRbbTBENUG2ZNHRuyn +zsYFZtpxDO9LF0uutE2P6NJQpKKrCo/NGMV4AF0vy1tKp6h2fBU3K9Yn+1RihIyS +Y+sLoNiorJloqZ8qn2cULbax/xi/IcccdRJfoIjmIuSl9wUwl+lkeGB9Rlwm5iFJ +LO+mQ15hUEpbjrXF3IdY+4MjDqFOETC0KuI72yjUGPZqWe+WAhBcni3VNzs2Ik4= +-----END CERTIFICATE-----` +) + +var _ = Describe("TlsConnector", func() { + var testDataFolder string + var certificateFolder string + var playerID = int32(0) + + BeforeEach(func() { + var err error + testDataFolder, err = ioutil.TempDir("", "testData") + certificateFolder = testDataFolder + "Player-Data" + err = os.Mkdir(certificateFolder, os.ModeDir) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(fmt.Sprintf("%s/C%d.pem", certificateFolder, playerID), []byte(pemFileClient), os.ModePerm) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(fmt.Sprintf("%s/C%d.key", certificateFolder, playerID), []byte(keyFileClient), os.ModePerm) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(fmt.Sprintf("%s/P%d.pem", certificateFolder, playerID), []byte(pemFileServer), os.ModePerm) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(fmt.Sprintf("%s/P%d.key", certificateFolder, playerID), []byte(keyFileServer), os.ModePerm) + if err != nil { + panic(err) + } + }) + + AfterEach(func() { + err := os.RemoveAll(testDataFolder) + if err != nil { + panic(err) + } + }) + + Context("when trying to upgrade to a TLS connection", func() { + + var ( + tlsConnector func(conn net.Conn, playerID int32) (net.Conn, error) + client, server net.Conn + ) + + BeforeEach(func() { + tlsConnector = NewTlsConnectorWithPath(certificateFolder) + client, server = net.Pipe() + }) + + It("establishes a TLS Connection and allows to send something over the connection", func() { + // Arrange + serverPemFileLocation := fmt.Sprintf("%s/P%d.pem", certificateFolder, playerID) + serverKeyFileLocation := fmt.Sprintf("%s/P%d.key", certificateFolder, playerID) + serverCertificate, err := tls.LoadX509KeyPair(serverPemFileLocation, serverKeyFileLocation) + + if err != nil { + panic(err) + } + + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{serverCertificate}, + } + + serverTlsConnection := tls.Server(server, serverConfig) + go serverTlsConnection.Handshake() + + // Act + tlsConnection, err := tlsConnector(client, playerID) + + contentToSend := []byte{byte(1)} + go tlsConnection.Write(contentToSend) + + contentToReceive := make([]byte, 1) + serverTlsConnection.Read(contentToReceive) + + // Assert + Expect(err).NotTo(HaveOccurred()) + Expect(tlsConnection).ToNot(BeNil()) + Expect(contentToReceive).To(Equal(contentToSend)) + }) + + Context("and no certificate files for the playerID exist", func() { + playerID := int32(1) + + It("errors when trying to load the certificate key pair", func() { + // Act + tlsConnection, err := tlsConnector(client, playerID) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("C1.pem")) + Expect(tlsConnection).To(BeNil()) + }) + }) + + Context("and the server does not have the matching certificate", func() { + playerID := int32(0) + It("will throw a TLS Error", func() { + // Arrange + serverConfig := &tls.Config{ + //No Server Certificates -> Client certificate won't match + Certificates: []tls.Certificate{}, + } + + serverTlsConnection := tls.Server(server, serverConfig) + go serverTlsConnection.Handshake() + + // Act + tlsConnection, err := tlsConnector(client, playerID) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("remote error: tls: unrecognized name")) + Expect(tlsConnection).To(BeNil()) + }) + }) + }) +}) diff --git a/pkg/ephemeral/player.go b/pkg/ephemeral/player.go index 07cde86a..2f755522 100644 --- a/pkg/ephemeral/player.go +++ b/pkg/ephemeral/player.go @@ -223,5 +223,6 @@ func (c *Callbacker) sendEvent(name, topic string, e interface{}) { }, } c.pb.PublishWithBody(name, topic, event, c.playerParams.GameID) - c.logger.Debugf("Sending event %v to topic %s\n", event.Name, topic) + c.logger.Debugw("Sending event", "event", event, "topic", topic) + c.logger.Debugf("Sending event.name %v to topic %s\n", event.Name, topic) } diff --git a/pkg/ephemeral/server.go b/pkg/ephemeral/server.go index 18bc141d..400c4a70 100644 --- a/pkg/ephemeral/server.go +++ b/pkg/ephemeral/server.go @@ -343,7 +343,7 @@ func (s *Server) getPodName() (string, error) { //use something like os.Getenv("HOST_NAME")? - name, _, err := cmder.CallCMD([]string{"hostname"}, "/") + name, _, err := cmder.CallCMD(context.TODO(), []string{"hostname"}, "/") if err != nil { return "", err } diff --git a/pkg/ephemeral/spdz.go b/pkg/ephemeral/spdz.go index 1faf729e..1dbc6f0b 100644 --- a/pkg/ephemeral/spdz.go +++ b/pkg/ephemeral/spdz.go @@ -7,6 +7,7 @@ package ephemeral import ( + "context" d "github.com/carbynestack/ephemeral/pkg/discovery" pb "github.com/carbynestack/ephemeral/pkg/discovery/transport/proto" . "github.com/carbynestack/ephemeral/pkg/ephemeral/io" @@ -233,9 +234,9 @@ func (s *SPDZEngine) Compile(ctx *CtxConfig) error { var stdoutSlice []byte var stderrSlice []byte - command := fmt.Sprintf("./compile.py %s", appName) - stdoutSlice, stderrSlice, err = s.cmder.CallCMD([]string{command}, s.baseDir) - + command := fmt.Sprintf("./compile.py -M %s", appName) + // TODO: ctx.context is nil at this time. + stdoutSlice, stderrSlice, err = s.cmder.CallCMD(context.TODO(), []string{command}, s.baseDir) stdOut := string(stdoutSlice) stdErr := string(stderrSlice) s.logger.Debugw("Compiled Successfully", "Command", command, "StdOut", stdOut, "StdErr", stdErr) @@ -254,7 +255,7 @@ func (s *SPDZEngine) getFeedPort() string { func (s *SPDZEngine) startMPC(ctx *CtxConfig) { command := []string{fmt.Sprintf("./Player-Online.x %s %s -N %s --ip-file-name %s", fmt.Sprint(s.config.PlayerID), appName, fmt.Sprint(ctx.Spdz.PlayerCount), ipFile)} s.logger.Infow("Starting Player-Online.x", GameID, ctx.Act.GameID, "command", command) - stdout, stderr, err := s.cmder.CallCMD(command, s.baseDir) + stdout, stderr, err := s.cmder.CallCMD(ctx.Context, command, s.baseDir) if err != nil { err := fmt.Errorf("error while executing the user code: %v", err) ctx.ErrCh <- err diff --git a/pkg/ephemeral/spdz_test.go b/pkg/ephemeral/spdz_test.go index 373104d9..0277d1b5 100644 --- a/pkg/ephemeral/spdz_test.go +++ b/pkg/ephemeral/spdz_test.go @@ -47,7 +47,7 @@ var _ = Describe("Spdz", func() { fileName = fmt.Sprintf("/tmp/program-%d.mpc", random) }) AfterEach(func() { - cmder.CallCMD([]string{fmt.Sprintf("rm %s", fileName)}, "./") + cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", fileName)}, "./") }) Context("writing succeeds", func() { It("writes the source code on the disk and runs the compiler", func() { @@ -63,7 +63,7 @@ var _ = Describe("Spdz", func() { } err := s.Compile(conf) Expect(err).NotTo(HaveOccurred()) - out, _, err := cmder.CallCMD([]string{fmt.Sprintf("cat %s", s.sourceCodePath)}, "./") + out, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("cat %s", s.sourceCodePath)}, "./") Expect(err).NotTo(HaveOccurred()) Expect(string(out)).To(Equal("a")) }) diff --git a/pkg/utils/os.go b/pkg/utils/os.go index f44f4d58..58478b51 100644 --- a/pkg/utils/os.go +++ b/pkg/utils/os.go @@ -8,17 +8,20 @@ package utils import ( "bytes" + "context" "errors" + "fmt" "io/ioutil" "os" "os/exec" "path/filepath" + "sync" ) // Executor is an interface for calling a command and process its output. type Executor interface { // CallCMD executes the command and returns the output's STDOUT, STDERR streams as well as any errors - CallCMD(cmd []string, dir string) ([]byte, []byte, error) + CallCMD(theContext context.Context, cmd []string, dir string) ([]byte, []byte, error) } var ( @@ -45,7 +48,7 @@ type Commander struct { // Run is a facade command that runs a single command from the current directory. func (c *Commander) Run(cmd string) ([]byte, []byte, error) { - return c.CallCMD([]string{cmd}, "./") + return c.CallCMD(context.TODO(), []string{cmd}, "./") } // CallCMD calls a specified command in sh and returns its stdout and stderr as a byte slice and potentially an error. @@ -53,10 +56,10 @@ func (c *Commander) Run(cmd string) ([]byte, []byte, error) { // ``` // If the command fails to run or doesn't complete successfully, the error is of type *ExitError. Other error types may be returned for I/O problems. // ``` -func (c *Commander) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (c *Commander) CallCMD(theContext context.Context, cmd []string, dir string) ([]byte, []byte, error) { baseCmd := c.Options baseCmd = append(baseCmd, cmd...) - command := exec.Command(c.Command, baseCmd...) + command := exec.CommandContext(theContext, c.Command, baseCmd...) stderrBuffer := bytes.NewBuffer([]byte{}) stdoutBuffer := bytes.NewBuffer([]byte{}) @@ -68,8 +71,23 @@ func (c *Commander) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { if err != nil { return nil, nil, err } - // Check if the command finished successfully. - err = command.Wait() + + var waitGroup sync.WaitGroup + waitGroup.Add(1) + go func() { + // Check if the command finished successfully. + err = command.Wait() + defer waitGroup.Done() + + if err != nil { + println(fmt.Sprintf("Error occured!")) + println(fmt.Sprintf("StdOut: %s", stdoutBuffer.Bytes())) + println(fmt.Sprintf("StdErr: %s", stderrBuffer.Bytes())) + } + }() + + waitGroup.Wait() + if err != nil { switch err.(type) { case *exec.ExitError: diff --git a/pkg/utils/os_test.go b/pkg/utils/os_test.go index c292d0df..e137dbe7 100644 --- a/pkg/utils/os_test.go +++ b/pkg/utils/os_test.go @@ -7,6 +7,7 @@ package utils_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -82,7 +83,7 @@ var _ = Describe("OS utils", func() { } }) AfterEach(func() { - cmder.CallCMD([]string{fmt.Sprintf("rm %s", fileName)}, "./") + cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", fileName)}, "./") }) It("reads file content", func() { data := []byte(`a`)