[tor-commits] [snowflake/master] Use gorilla websocket in proxy-go too

arlo at torproject.org arlo at torproject.org
Mon Nov 25 19:38:41 UTC 2019


commit 30b5ef8a9e9c7a5b306e9285d1a8db323f8f22b2
Author: Arlo Breault <arlolra at gmail.com>
Date:   Wed Nov 20 19:33:28 2019 -0500

    Use gorilla websocket in proxy-go too
    
    Trac: 32465
---
 common/websocketconn/websocketconn.go      | 89 +++++++++++++++++++++++++++
 common/websocketconn/websocketconn_test.go | 30 +++++++++
 proxy-go/proxy-go_test.go                  | 19 ------
 proxy-go/snowflake.go                      | 25 ++------
 server/server.go                           | 99 ++----------------------------
 5 files changed, 128 insertions(+), 134 deletions(-)

diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go
new file mode 100644
index 0000000..399cbaa
--- /dev/null
+++ b/common/websocketconn/websocketconn.go
@@ -0,0 +1,89 @@
+package websocketconn
+
+import (
+	"io"
+	"log"
+	"sync"
+	"time"
+
+	"github.com/gorilla/websocket"
+)
+
+// WebSocketConn makes an underlying gorilla WebSocket connection look like an
+// io.ReadWriteCloser: incoming data messages are read as one continuous byte
+// stream, and each Write is sent as a single binary message.
+//
+// Use it through a pointer; Read, Write and Close have pointer receivers.
+type WebSocketConn struct {
+	// Ws is the wrapped WebSocket connection.
+	Ws *websocket.Conn
+
+	// r reads the rest of the message currently being consumed by Read, or
+	// is nil when the next Read has to start a new message.
+	r io.Reader
+}
+
+// Compile-time check that *WebSocketConn satisfies io.ReadWriteCloser.
+var _ io.ReadWriteCloser = (*WebSocketConn)(nil)
+
+// Read implements io.Reader. It returns the payloads of successive WebSocket
+// data messages as one continuous byte stream; message boundaries are not
+// preserved. Opcodes other than binary and text are skipped.
+//
+// A message that yields no bytes at its end is not reported as (0, nil):
+// io.Reader discourages that result for a non-empty b, so Read moves on to
+// the next message instead.
+func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
+	for {
+		if conn.r == nil {
+			// Start reading a new message.
+			var opCode int
+			var r io.Reader
+			if opCode, r, err = conn.Ws.NextReader(); err != nil {
+				return 0, err
+			}
+			if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
+				continue
+			}
+			conn.r = r
+		}
+
+		n, err = conn.r.Read(b)
+		if err != io.EOF {
+			return n, err
+		}
+		// The current message is finished; the next call starts a new one.
+		// EOF is swallowed because the stream as a whole has not ended.
+		conn.r = nil
+		if n > 0 || len(b) == 0 {
+			return n, nil
+		}
+		// Empty (or already fully consumed) message: keep going rather than
+		// returning (0, nil), which some callers count as a lack of progress.
+	}
+}
+
+// Implements io.Writer.
+func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
+	var w io.WriteCloser
+	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
+		return
+	}
+	if n, err = w.Write(b); err != nil {
+		return
+	}
+	err = w.Close()
+	return
+}
+
+// Close implements io.Closer. It first tries to send a WebSocket Close
+// control frame with an empty payload (write deadline one second from now),
+// then closes the underlying connection and returns that result.
+func (conn *WebSocketConn) Close() error {
+	deadline := time.Now().Add(time.Second)
+	// Best effort only: failing to send the Close frame (for example on an
+	// already-broken connection) must not prevent the connection from being
+	// closed, so the error is deliberately discarded.
+	_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, deadline)
+	return conn.Ws.Close()
+}
+
+// Create a new WebSocketConn.
+func NewWebSocketConn(ws *websocket.Conn) WebSocketConn {
+	var conn WebSocketConn
+	conn.Ws = ws
+	return conn
+}
+
+// CopyLoop relays data in both directions between c1 and c2: bytes read from
+// c2 are written to c1 and bytes read from c1 are written to c2. When either
+// direction stops (EOF or error), it closes both connections, which is
+// expected to unblock the opposite direction. CopyLoop returns only after
+// both directions have finished.
+func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
+	var wg sync.WaitGroup
+	relay := func(to, from io.ReadWriteCloser) {
+		defer wg.Done()
+		_, err := io.Copy(to, from)
+		if err != nil {
+			log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
+		}
+		// Close errors are ignored: both ends are being torn down anyway.
+		to.Close()
+		from.Close()
+	}
+	for _, pair := range [2][2]io.ReadWriteCloser{{c1, c2}, {c2, c1}} {
+		wg.Add(1)
+		go relay(pair[0], pair[1])
+	}
+	wg.Wait()
+}
diff --git a/common/websocketconn/websocketconn_test.go b/common/websocketconn/websocketconn_test.go
new file mode 100644
index 0000000..3293165
--- /dev/null
+++ b/common/websocketconn/websocketconn_test.go
@@ -0,0 +1,30 @@
+package websocketconn
+
+import (
+	"io"
+	"net"
+	"testing"
+	"time"
+
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+// TestWebsocketConn checks that CopyLoop relays data from one connection to
+// the other, closes the other connection once one side is closed, and then
+// returns instead of leaking its goroutines.
+func TestWebsocketConn(t *testing.T) {
+	Convey("CopyLoop", t, func() {
+		c1, s1 := net.Pipe()
+		c2, s2 := net.Pipe()
+		defer c1.Close()
+		defer c2.Close()
+
+		// Run CopyLoop in the background and record when it returns.
+		done := make(chan struct{})
+		go func() {
+			CopyLoop(s1, s2)
+			close(done)
+		}()
+
+		msg := []byte("Hello!")
+		writeErr := make(chan error, 1)
+		go func() {
+			_, err := c1.Write(msg)
+			writeErr <- err
+		}()
+
+		// ReadFull guards against the data arriving in several pieces.
+		buf := make([]byte, len(msg))
+		n, err := io.ReadFull(c2, buf)
+		So(n, ShouldEqual, len(msg))
+		So(err, ShouldBeNil)
+		So(buf, ShouldResemble, msg)
+		So(<-writeErr, ShouldBeNil)
+		s1.Close()
+
+		// Check that copy loop has closed the other connection...
+		_, err = s2.Write(buf)
+		So(err, ShouldNotBeNil)
+
+		// ...and that CopyLoop itself has returned.
+		returned := false
+		select {
+		case <-done:
+			returned = true
+		case <-time.After(5 * time.Second):
+		}
+		So(returned, ShouldBeTrue)
+	})
+}
diff --git a/proxy-go/proxy-go_test.go b/proxy-go/proxy-go_test.go
index ebe4381..538957b 100644
--- a/proxy-go/proxy-go_test.go
+++ b/proxy-go/proxy-go_test.go
@@ -374,23 +374,4 @@ func TestUtilityFuncs(t *testing.T) {
 		sid2 := genSessionID()
 		So(sid1, ShouldNotEqual, sid2)
 	})
-	Convey("CopyLoop", t, func() {
-		c1, s1 := net.Pipe()
-		c2, s2 := net.Pipe()
-		go CopyLoop(s1, s2)
-		go func() {
-			bytes := []byte("Hello!")
-			c1.Write(bytes)
-		}()
-		bytes := make([]byte, 6)
-		n, err := c2.Read(bytes)
-		So(n, ShouldEqual, 6)
-		So(err, ShouldEqual, nil)
-		So(bytes, ShouldResemble, []byte("Hello!"))
-		s1.Close()
-
-		//Check that copy loop has closed other connection
-		_, err = s2.Write(bytes)
-		So(err, ShouldNotBeNil)
-	})
 }
diff --git a/proxy-go/snowflake.go b/proxy-go/snowflake.go
index c4b2f0b..0e14eb2 100644
--- a/proxy-go/snowflake.go
+++ b/proxy-go/snowflake.go
@@ -21,8 +21,9 @@ import (
 
 	"git.torproject.org/pluggable-transports/snowflake.git/common/messages"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
+	"github.com/gorilla/websocket"
 	"github.com/pion/webrtc"
-	"golang.org/x/net/websocket"
 )
 
 const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/"
@@ -239,22 +240,6 @@ func (b *Broker) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
 	return nil
 }
 
-func CopyLoop(c1 net.Conn, c2 net.Conn) {
-	var wg sync.WaitGroup
-	copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
-		defer wg.Done()
-		if _, err := io.Copy(dst, src); err != nil {
-			log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
-		}
-		dst.Close()
-		src.Close()
-	}
-	wg.Add(2)
-	go copyer(c1, c2)
-	go copyer(c2, c1)
-	wg.Wait()
-}
-
 // We pass conn.RemoteAddr() as an additional parameter, rather than calling
 // conn.RemoteAddr() inside this function, as a workaround for a hang that
 // otherwise occurs inside of conn.pc.RemoteDescription() (called by
@@ -279,15 +264,15 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
 		log.Printf("no remote address given in websocket")
 	}
 
-	wsConn, err := websocket.Dial(u.String(), "", relayURL)
+	ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
 	if err != nil {
 		log.Printf("error dialing relay: %s", err)
 		return
 	}
+	wsConn := websocketconn.NewWebSocketConn(ws)
 	log.Printf("connected to relay")
 	defer wsConn.Close()
-	wsConn.PayloadType = websocket.BinaryFrame
-	CopyLoop(conn, wsConn)
+	websocketconn.CopyLoop(conn, &wsConn)
 	log.Printf("datachannelHandler ends")
 }
 
diff --git a/server/server.go b/server/server.go
index ce804fc..d950ddc 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,12 +15,12 @@ import (
 	"os/signal"
 	"path/filepath"
 	"strings"
-	"sync"
 	"syscall"
 	"time"
 
 	pt "git.torproject.org/pluggable-transports/goptlib.git"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
 	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/net/http2"
@@ -50,97 +50,6 @@ additional HTTP listener on port 80 to work with ACME.
 	flag.PrintDefaults()
 }
 
-// An abstraction that makes an underlying WebSocket connection look like an
-// io.ReadWriteCloser.
-type webSocketConn struct {
-	Ws *websocket.Conn
-	r  io.Reader
-}
-
-// Implements io.Reader.
-func (conn *webSocketConn) Read(b []byte) (n int, err error) {
-	var opCode int
-	if conn.r == nil {
-		// New message
-		var r io.Reader
-		for {
-			if opCode, r, err = conn.Ws.NextReader(); err != nil {
-				return
-			}
-			if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
-				continue
-			}
-
-			conn.r = r
-			break
-		}
-	}
-
-	n, err = conn.r.Read(b)
-	if err == io.EOF {
-		// Message finished
-		conn.r = nil
-		err = nil
-	}
-	return
-}
-
-// Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (n int, err error) {
-	var w io.WriteCloser
-	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
-		return
-	}
-	if n, err = w.Write(b); err != nil {
-		return
-	}
-	err = w.Close()
-	return
-}
-
-// Implements io.Closer.
-func (conn *webSocketConn) Close() error {
-	// Ignore any error in trying to write a Close frame.
-	_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
-	return conn.Ws.Close()
-}
-
-// Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.Conn) webSocketConn {
-	var conn webSocketConn
-	conn.Ws = ws
-	return conn
-}
-
-// Copy from WebSocket to socket and vice versa.
-func proxy(local *net.TCPConn, conn *webSocketConn) {
-	var wg sync.WaitGroup
-	wg.Add(2)
-
-	go func() {
-		if _, err := io.Copy(conn, local); err != nil {
-			log.Printf("error copying ORPort to WebSocket %v", err)
-		}
-		if err := local.CloseRead(); err != nil {
-			log.Printf("error closing read after copying ORPort to WebSocket %v", err)
-		}
-		conn.Close()
-		wg.Done()
-	}()
-	go func() {
-		if _, err := io.Copy(local, conn); err != nil {
-			log.Printf("error copying WebSocket to ORPort")
-		}
-		if err := local.CloseWrite(); err != nil {
-			log.Printf("error closing write after copying WebSocket to ORPort %v", err)
-		}
-		conn.Close()
-		wg.Done()
-	}()
-
-	wg.Wait()
-}
-
 // Return an address string suitable to pass into pt.DialOr.
 func clientAddr(clientIPParam string) string {
 	if clientIPParam == "" {
@@ -166,8 +75,8 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	conn := newWebSocketConn(ws)
-	defer conn.Close()
+	wsConn := websocketconn.NewWebSocketConn(ws)
+	defer wsConn.Close()
 
 	// Pass the address of client as the remote address of incoming connection
 	clientIPParam := r.URL.Query().Get("client_ip")
@@ -184,7 +93,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 	defer or.Close()
 
-	proxy(or, &conn)
+	websocketconn.CopyLoop(or, &wsConn)
 }
 
 func initServer(addr *net.TCPAddr,





More information about the tor-commits mailing list