diff --git a/go.mod b/go.mod index 2e3333acc16..4a05d30857a 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/aead/siphash v1.0.1 // indirect github.com/decred/dcrd/crypto/blake256 v1.0.0 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect diff --git a/go.sum b/go.sum index 1e39ef32632..e77dfa2f5f5 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= diff --git a/rpcclient/chain_test.go b/rpcclient/chain_test.go index e32d547ce3b..749277d2a23 100644 --- a/rpcclient/chain_test.go +++ b/rpcclient/chain_test.go @@ -1,6 +1,15 @@ package rpcclient -import "testing" +import ( + "github.com/gorilla/websocket" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +var upgrader = websocket.Upgrader{} // TestUnmarshalGetBlockChainInfoResult ensures that the SoftForks and // UnifiedSoftForks fields of GetBlockChainInfoResult are properly unmarshaled @@ -90,3 +99,122 @@ func TestUnmarshalGetBlockChainInfoResultSoftForks(t *testing.T) { } } } + +func TestClientConnectedToWSServerRunner(t *testing.T) { + type TestTableItem struct { + Name string + TestCase func(t *testing.T) + } + + testTable := []TestTableItem{ + TestTableItem{ + Name: "TestGetChainTxStatsAsyncSuccessTx", + TestCase: func(t *testing.T) { + client, serverReceivedChannel, cleanup := makeClient(t) + defer cleanup() + client.GetChainTxStatsAsync() + + message := <-serverReceivedChannel + if message != "{\"jsonrpc\":\"1.0\",\"method\":\"getchaintxstats\",\"params\":[],\"id\":1}" { + t.Fatalf("received unexpected message: %s", message) + } + }, + }, + TestTableItem{ + Name: "TestGetChainTxStatsAsyncShutdownError", + TestCase: func(t *testing.T) { + client, _, cleanup := makeClient(t) + defer cleanup() + + // a bit of a hack here: since there are multiple places where we read + // from the shutdown channel, and it is not buffered, ensure that a shutdown + // message is sent every time it is read from, this will ensure that + // when client.GetChainTxStatsAsync() gets called, it hits the non-blocking + // read from the shutdown channel + go func() { + type shutdownMessage struct{} + for { + client.shutdown <- shutdownMessage{} + } + }() + + var response *Response = nil + + for response == nil { + respChan := client.GetChainTxStatsAsync() + select { + case response = <-respChan: + default: + } + } + + if response.err == nil || response.err.Error() != "the client has been shutdown" { + t.Fatalf("unexpected error: %s", response.err.Error()) + } + }, + }, + } + + // since these tests rely on concurrency, ensure there is a resonable timeout + // that they should run within + for _, testCase := range testTable { + done := make(chan bool) + + go func() { + t.Run(testCase.Name, testCase.TestCase) + done <- true + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("timeout exceeded for: %s", testCase.Name) + } + } +} + +func makeClient(t *testing.T) (*Client, chan string, func()) { + serverReceivedChannel := make(chan string) + s := httptest.NewServer(http.HandlerFunc(makeUpgradeOneConnect(serverReceivedChannel))) + url := strings.TrimPrefix(s.URL, "http://") + + config := ConnConfig{ + DisableTLS: true, + User: "username", + Pass: "password", + Host: url, + } + + client, err := New(&config, nil) + if err != nil { + t.Fatalf("error when creating new client %s", err.Error()) + } + return client, serverReceivedChannel, func() { + s.Close() + } +} + +func makeUpgradeOneConnect(ch chan string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + + go func() { + ch <- string(message) + }() + + err = c.WriteMessage(mt, []byte("blahhhhhh")) + if err != nil { + break + } + } + } +}