[tor-commits] [snowflake/main] Implement server as a v2.1 PT Go API

cohosh at torproject.org cohosh at torproject.org
Wed May 12 13:11:17 UTC 2021


commit 11f0846264d4033e7a7dc7824febb6ad7140762f
Author: Cecylia Bocovich <cohosh at torproject.org>
Date:   Sat Mar 20 18:24:00 2021 -0400

    Implement server as a v2.1 PT Go API
---
 server/lib/http.go                   | 211 +++++++++++++++++
 server/lib/server_test.go            |  55 +++++
 server/lib/snowflake.go              | 242 ++++++++++++++++++++
 server/{ => lib}/turbotunnel.go      |   2 +-
 server/{ => lib}/turbotunnel_test.go |   2 +-
 server/server.go                     | 426 ++++-------------------------------
 server/server_test.go                | 153 -------------
 7 files changed, 551 insertions(+), 540 deletions(-)

diff --git a/server/lib/http.go b/server/lib/http.go
new file mode 100644
index 0000000..b1c453c
--- /dev/null
+++ b/server/lib/http.go
@@ -0,0 +1,211 @@
+package lib
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"time"
+
+	"git.torproject.org/pluggable-transports/snowflake.git/common/encapsulation"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
+	"github.com/gorilla/websocket"
+)
+
+// Server-side tuning constants.
+const (
+	// requestTimeout bounds how long the HTTP server spends reading a
+	// request; it is used as http.Server.ReadTimeout.
+	requestTimeout = 10 * time.Second
+
+	// clientMapTimeout is how long outgoing packets for a client are kept
+	// in memory while no WebSocket connection for that client is active.
+	// One client session may be carried by several successive WebSocket
+	// connections, so packets that cannot be sent right away are held for
+	// a little while, but not indefinitely.
+	clientMapTimeout = 1 * time.Minute
+
+	// clientIDAddrMapCapacity is the size of the ClientID-to-IP-address map
+	// that turbotunnelMode fills in, so that a client session that outlives
+	// any single WebSocket connection still has a reasonable IP address.
+	clientIDAddrMapCapacity = 1024
+
+	// listenAndServeErrorTimeout is how long to wait for ListenAndServe or
+	// ListenAndServeTLS to return an error before assuming it will not.
+	listenAndServeErrorTimeout = 100 * time.Millisecond
+)
+
+// upgrader turns incoming HTTP requests into WebSocket connections. It
+// accepts requests regardless of their Origin.
+var upgrader = websocket.Upgrader{
+	CheckOrigin: func(*http.Request) bool { return true },
+}
+
+// clientIDAddrMap holds short-lived ClientID -> IP address mappings.
+// tor wants a USERADDR string, passed to pt.DialOr, naming the client's
+// remote IP address (for metrics purposes, etc.). ServeHTTP knows that
+// address, but the code that eventually hands a stream to the ORPort only
+// sees a KCP session. The ClientID, attached to the WebSocket connection and
+// to every session, links the two: turbotunnelMode stores the address under
+// the ClientID and acceptStreams looks it up again.
+var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity)
+
+// overrideReadConn is a net.Conn whose reads are served by an arbitrary
+// io.Reader rather than by the embedded connection; every other method goes
+// to the embedded net.Conn. Compare to recordingConn at
+// https://dave.cheney.net/2015/05/22/struct-composition-with-go.
+type overrideReadConn struct {
+	net.Conn
+	io.Reader
+}
+
+// Read reads from the overriding Reader. The explicit method also resolves
+// the ambiguity between the two embedded Read methods.
+func (c *overrideReadConn) Read(b []byte) (n int, err error) {
+	n, err = c.Reader.Read(b)
+	return n, err
+}
+
+// HTTPHandler is the http.Handler for the snowflake server. ServeHTTP
+// upgrades each request to a WebSocket and dispatches it to turbotunnelMode
+// or oneshotMode.
+type HTTPHandler struct {
+	// pconn is the adapter layer between stream-oriented WebSocket
+	// connections and the packet-oriented KCP layer.
+	pconn *turbotunnel.QueuePacketConn
+	// ln receives the connections of oneshot-mode clients (oneshotMode
+	// calls ln.QueueConn), so it must be set before such a client arrives.
+	ln    *SnowflakeListener
+}
+
+// ServeHTTP upgrades the request to a WebSocket connection and reads the
+// turbotunnel token that may open the stream. Connections carrying the token
+// go to turbotunnelMode; all others go to oneshotMode with the consumed bytes
+// put back in front of the stream. A valid client_ip query parameter becomes
+// the connection's client address.
+func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ws, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		log.Println(err)
+		return
+	}
+	conn := websocketconn.New(ws)
+	defer conn.Close()
+
+	// The client's address, as reported by the proxy in client_ip.
+	addr := clientAddr(r.URL.Query().Get("client_ip"))
+
+	var token [len(turbotunnel.Token)]byte
+	if _, err = io.ReadFull(conn, token[:]); err != nil {
+		// A bare EOF means an unused connection. Clients open those
+		// routinely while keeping a pool of proxies, so it is not logged.
+		if err != io.EOF {
+			log.Printf("reading token: %v", err)
+		}
+		return
+	}
+
+	if bytes.Equal(token[:], turbotunnel.Token[:]) {
+		err = turbotunnelMode(conn, addr, handler.pconn)
+	} else {
+		// No token: this client predates turbotunnel. Replay the bytes
+		// already consumed ahead of the rest of the stream and fall back
+		// to the old one-session-per-WebSocket mode.
+		replay := io.MultiReader(bytes.NewReader(token[:]), conn)
+		err = oneshotMode(&overrideReadConn{Conn: conn, Reader: replay}, addr, handler.ln)
+	}
+	if err != nil {
+		log.Println(err)
+	}
+}
+
+// oneshotMode handles clients that did not send turbotunnel.Token at the start
+// of their stream. These clients use the WebSocket as a raw pipe, and expect
+// their session to begin and end when this single WebSocket does.
+//
+// The connection is queued on ln for Accept, and oneshotMode then blocks until
+// whoever accepted it calls Close, or until ln is closed. Blocking matters:
+// the caller (ServeHTTP) closes the WebSocket as soon as this returns, so
+// returning right after queueing would kill the session before it is used.
+func oneshotMode(conn net.Conn, addr net.Addr, ln *SnowflakeListener) error {
+	oc := &oneshotConn{Conn: conn, done: make(chan struct{}, 1)}
+	err := ln.QueueConn(&SnowflakeClientConn{Conn: oc, address: addr})
+	if err != nil {
+		return err
+	}
+	select {
+	case <-oc.done:
+	case <-ln.closed:
+	}
+	return nil
+}
+
+// oneshotConn is a net.Conn that reports, through done, that Close was called.
+type oneshotConn struct {
+	net.Conn
+	// done has capacity 1. The first Close deposits a token in it and later
+	// Closes find it full and skip, so Close never blocks and never panics.
+	done chan struct{}
+}
+
+// Close closes the wrapped connection and wakes the waiting oneshotMode call.
+func (c *oneshotConn) Close() error {
+	err := c.Conn.Close()
+	select {
+	case c.done <- struct{}{}:
+	default:
+	}
+	return err
+}
+
+// turbotunnelMode handles clients that sent turbotunnel.Token at the start of
+// their stream. These clients expect to send and receive encapsulated packets,
+// with a long-lived session identified by ClientID.
+//
+// It returns once either direction of the WebSocket fails, after which the
+// caller closes conn. The helper goroutines are arranged so that neither is
+// left blocked afterwards: errCh has room for both of their errors, and the
+// write loop stops waiting for packets as soon as the read loop exits.
+func turbotunnelMode(conn net.Conn, addr net.Addr, pconn *turbotunnel.QueuePacketConn) error {
+	// Read the ClientID prefix. Every packet encapsulated in this WebSocket
+	// connection pertains to the same ClientID.
+	var clientID turbotunnel.ClientID
+	_, err := io.ReadFull(conn, clientID[:])
+	if err != nil {
+		return fmt.Errorf("reading ClientID: %v", err)
+	}
+
+	// Store a short-term mapping from the ClientID to the client IP
+	// address attached to this WebSocket connection. tor will want us to
+	// provide a client IP address when we call pt.DialOr. But a KCP session
+	// does not necessarily correspond to any single IP address--it's
+	// composed of packets that are carried in possibly multiple WebSocket
+	// streams. We apply the heuristic that the IP address of the most
+	// recent WebSocket connection that has had to do with a session, at the
+	// time the session is established, is the IP address that should be
+	// credited for the entire KCP session.
+	clientIDAddrMap.Set(clientID, addr.String())
+
+	// Buffered for both loops, so the one that fails second can still
+	// report its error after nobody is receiving.
+	errCh := make(chan error, 2)
+	// done is closed when the read loop exits. It tells the write loop to
+	// stop waiting on the outgoing queue, which may otherwise stay empty
+	// forever.
+	done := make(chan struct{})
+
+	// The remainder of the WebSocket stream consists of encapsulated
+	// packets. We read them one by one and feed them into the
+	// QueuePacketConn on which kcp.ServeConn was set up, which eventually
+	// leads to KCP-level sessions in the acceptSessions function.
+	go func() {
+		defer close(done)
+		for {
+			p, err := encapsulation.ReadData(conn)
+			if err != nil {
+				errCh <- err
+				return
+			}
+			pconn.QueueIncoming(p, clientID)
+		}
+	}()
+
+	// At the same time, grab packets addressed to this ClientID and
+	// encapsulate them into the downstream.
+	go func() {
+		// Buffer encapsulation.WriteData operations to keep length
+		// prefixes in the same send as the data that follows.
+		bw := bufio.NewWriter(conn)
+		outgoing := pconn.OutgoingQueue(clientID)
+		for {
+			select {
+			case <-done:
+				return
+			case p, ok := <-outgoing:
+				if !ok {
+					return
+				}
+				_, err := encapsulation.WriteData(bw, p)
+				if err == nil {
+					err = bw.Flush()
+				}
+				if err != nil {
+					errCh <- err
+					return
+				}
+			}
+		}
+	}()
+
+	// Wait until one of the above loops reports an error. The caller then
+	// closes the WebSocket connection, which ends the read loop; the read
+	// loop's exit in turn ends the write loop.
+	<-errCh
+
+	return nil
+}
+
+// ClientMapAddr is a net.Addr backed by a plain address string. In this
+// package it holds either the "IP:port" string built by clientAddr or "" when
+// no usable client address is known.
+type ClientMapAddr string
+
+// Network returns the network name, "snowflake".
+func (a ClientMapAddr) Network() string { return "snowflake" }
+
+// String returns the address string unchanged.
+func (a ClientMapAddr) String() string { return string(a) }
+
+// clientAddr converts a client_ip value into the net.Addr used as a client's
+// remote address. A valid, specified IP becomes "IP:1" (bracketed for IPv6),
+// since USERADDR requires a port. Anything else -- an empty string, text that
+// is not a bare IP, or the unspecified address 0.0.0.0 / :: -- yields
+// ClientMapAddr("").
+func clientAddr(clientIPParam string) net.Addr {
+	// net.ParseIP returns nil for "", so no separate empty check is needed.
+	ip := net.ParseIP(clientIPParam)
+	// Some proxies erroneously report 0.0.0.0: https://bugs.torproject.org/33157.
+	if ip == nil || ip.IsUnspecified() {
+		return ClientMapAddr("")
+	}
+	// USERADDR requires a port number, so attach a stub one.
+	stub := net.TCPAddr{IP: ip, Port: 1}
+	return ClientMapAddr(stub.String())
+}
diff --git a/server/lib/server_test.go b/server/lib/server_test.go
new file mode 100644
index 0000000..65d31d1
--- /dev/null
+++ b/server/lib/server_test.go
@@ -0,0 +1,55 @@
+package lib
+
+import (
+	"net"
+	"strconv"
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+// TestClientAddr checks clientAddr on valid and invalid client_ip values.
+// Valid IPv4/IPv6 addresses must come back as "host:port" with the same IP and
+// a nonzero port. Inputs that are empty, not a bare IP (including bracketed
+// IPv6 such as "[::]"), or the unspecified address 0.0.0.0 must come back as "".
+func TestClientAddr(t *testing.T) {
+	Convey("Testing clientAddr", t, func() {
+		// good tests: each result must split into host and port, the host
+		// must parse to the expected IP, and the port must be a nonzero
+		// number.
+		for _, test := range []struct {
+			input    string
+			expected net.IP
+		}{
+			{"1.2.3.4", net.ParseIP("1.2.3.4")},
+			{"1:2::3:4", net.ParseIP("1:2::3:4")},
+		} {
+			useraddr := clientAddr(test.input).String()
+			host, port, err := net.SplitHostPort(useraddr)
+			if err != nil {
+				t.Errorf("clientAddr(%q) → SplitHostPort error %v", test.input, err)
+				continue
+			}
+			if !test.expected.Equal(net.ParseIP(host)) {
+				t.Errorf("clientAddr(%q) → host %q, not %v", test.input, host, test.expected)
+			}
+			portNo, err := strconv.Atoi(port)
+			if err != nil {
+				t.Errorf("clientAddr(%q) → port %q", test.input, port)
+				continue
+			}
+			if portNo == 0 {
+				t.Errorf("clientAddr(%q) → port %d", test.input, portNo)
+			}
+		}
+
+		// bad tests: each of these must yield the empty address string.
+		for _, input := range []string{
+			"",
+			"abc",
+			"1.2.3.4.5",
+			"[12::34]",
+			"0.0.0.0",
+			"[::]",
+		} {
+			useraddr := clientAddr(input).String()
+			if useraddr != "" {
+				t.Errorf("clientAddr(%q) → %q, not %q", input, useraddr, "")
+			}
+		}
+	})
+}
diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go
new file mode 100644
index 0000000..319acd8
--- /dev/null
+++ b/server/lib/snowflake.go
@@ -0,0 +1,242 @@
+package lib
+
+import (
+	"crypto/tls"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"sync"
+	"time"
+
+	"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
+	"github.com/xtaci/kcp-go/v5"
+	"github.com/xtaci/smux"
+	"golang.org/x/net/http2"
+)
+
+// Transport is a structure with methods that conform to the Go PT v2.1 API
+// https://github.com/Pluggable-Transports/Pluggable-Transports-spec/blob/master/releases/PTSpecV2.1/Pluggable%20Transport%20Specification%20v2.1%20-%20Go%20Transport%20API.pdf
+//
+// A Transport's listeners serve HTTPS with certificates from getCertificate,
+// or plain HTTP when getCertificate is nil.
+type Transport struct {
+	getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+}
+
+// NewSnowflakeServer returns a Transport whose listeners obtain TLS
+// certificates from getCertificate. Passing nil disables TLS.
+func NewSnowflakeServer(getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *Transport {
+	t := new(Transport)
+	t.getCertificate = getCertificate
+	return t
+}
+
+// Listen starts an HTTP server on addr (HTTPS when the Transport has a
+// certificate source) that accepts snowflake WebSocket connections. It returns
+// a SnowflakeListener whose Accept yields both turbotunnel smux streams and
+// legacy oneshot connections.
+//
+// If the HTTP server reports an error (e.g. "address already in use") within
+// listenAndServeErrorTimeout, that error is returned and nothing is left
+// running.
+func (t *Transport) Listen(addr net.Addr) (*SnowflakeListener, error) {
+	listener := &SnowflakeListener{
+		addr:   addr,
+		queue:  make(chan net.Conn, 65534),
+		closed: make(chan struct{}),
+	}
+
+	handler := HTTPHandler{
+		// pconn is shared among all connections to this server. It
+		// overlays packet-based client sessions on top of ephemeral
+		// WebSocket connections.
+		pconn: turbotunnel.NewQueuePacketConn(addr, clientMapTimeout),
+
+		// ln receives oneshot-mode connections; oneshotMode calls
+		// QueueConn on it, so it must not be left nil.
+		ln: listener,
+	}
+	server := &http.Server{
+		Addr:        addr.String(),
+		Handler:     &handler,
+		ReadTimeout: requestTimeout,
+	}
+	// We need to override server.TLSConfig.GetCertificate--but first
+	// server.TLSConfig needs to be non-nil. If we just create our own new
+	// &tls.Config, it will lack the default settings that the net/http
+	// package sets up for things like HTTP/2. Therefore we first call
+	// http2.ConfigureServer for its side effect of initializing
+	// server.TLSConfig properly. An alternative would be to make a dummy
+	// net.Listener, call Serve on it, and let it return.
+	// https://github.com/golang/go/issues/16588#issuecomment-237386446
+	err := http2.ConfigureServer(server, nil)
+	if err != nil {
+		return nil, err
+	}
+	server.TLSConfig.GetCertificate = t.getCertificate
+
+	// Another unfortunate effect of the inseparable net/http ListenAndServe
+	// is that we can't check for Listen errors like "permission denied" and
+	// "address already in use" without potentially entering the infinite
+	// loop of Serve. The hack we apply here is to wait a short time,
+	// listenAndServeErrorTimeout, to see if an error is returned (because
+	// it's better if the error message goes to the tor log through
+	// SMETHOD-ERROR than if it only goes to the snowflake log).
+	//
+	// errChan is buffered so that an error arriving after we stop waiting
+	// (including the one returned when the server is eventually closed)
+	// does not block the server goroutine forever.
+	errChan := make(chan error, 1)
+	go func() {
+		if t.getCertificate == nil {
+			// TLS is disabled
+			log.Printf("listening with plain HTTP on %s", addr)
+			err := server.ListenAndServe()
+			if err != nil {
+				log.Printf("error in ListenAndServe: %s", err)
+			}
+			errChan <- err
+		} else {
+			log.Printf("listening with HTTPS on %s", addr)
+			err := server.ListenAndServeTLS("", "")
+			if err != nil {
+				log.Printf("error in ListenAndServeTLS: %s", err)
+			}
+			errChan <- err
+		}
+	}()
+
+	select {
+	case err = <-errChan:
+	case <-time.After(listenAndServeErrorTimeout):
+	}
+	if err != nil {
+		// The server failed to start; report it instead of returning a
+		// listener that can never receive connections.
+		handler.pconn.Close()
+		return nil, err
+	}
+
+	listener.server = server
+
+	// Start a KCP engine, set up to read and write its packets over the
+	// WebSocket connections that arrive at the web server.
+	// handler.ServeHTTP is responsible for encapsulation/decapsulation of
+	// packets on behalf of KCP. KCP takes those packets and turns them into
+	// sessions which appear in the acceptSessions function.
+	ln, err := kcp.ServeConn(nil, 0, 0, handler.pconn)
+	if err != nil {
+		server.Close()
+		handler.pconn.Close()
+		return nil, err
+	}
+	listener.ln = ln
+
+	go func() {
+		defer ln.Close()
+		err := listener.acceptSessions(ln)
+		if err != nil {
+			log.Printf("acceptSessions: %v", err)
+		}
+	}()
+
+	return listener, nil
+}
+
+// SnowflakeListener is returned by Transport.Listen and provides the
+// net.Listener methods Accept, Addr, and Close. It owns the HTTP server and
+// the KCP listener, and queues the client connections they produce until
+// Accept is called.
+type SnowflakeListener struct {
+	addr      net.Addr
+	queue     chan net.Conn // connections waiting to be returned by Accept
+	server    *http.Server
+	ln        *kcp.Listener
+	closed    chan struct{} // closed by Close; must be made before use
+	closeOnce sync.Once
+}
+
+// Accept blocks until a client connection is available and returns it.
+// Connections come from a queue so that both smux streams from turbotunnel
+// sessions and legacy non-turbotunnel connections can be delivered here.
+// After Close, Accept returns io.ErrClosedPipe; because select chooses
+// randomly among ready cases, a connection still in the queue may
+// occasionally be returned instead.
+func (l *SnowflakeListener) Accept() (net.Conn, error) {
+	var conn net.Conn
+	select {
+	case conn = <-l.queue:
+	case <-l.closed:
+		// The listener is closed; no more connections are handed out.
+		return nil, io.ErrClosedPipe
+	}
+	return conn, nil
+}
+
+// Addr returns the address that Transport.Listen was given.
+func (l *SnowflakeListener) Addr() net.Addr { return l.addr }
+
+// Close shuts down the HTTP server and the KCP listener, and wakes any
+// Accept or QueueConn call blocked on the listener. Only the first call does
+// any work; it closes both components even if the first one fails and
+// returns the first error encountered. Later calls return nil.
+func (l *SnowflakeListener) Close() error {
+	var err error
+	l.closeOnce.Do(func() {
+		close(l.closed)
+		serverErr := l.server.Close()
+		lnErr := l.ln.Close()
+		if serverErr != nil {
+			err = serverErr
+		} else {
+			err = lnErr
+		}
+	})
+	return err
+}
+
+// acceptStreams layers an smux.Session on the KCP connection and awaits streams
+// on it. Each stream is passed to our SnowflakeListener accept queue, wrapped
+// so that its RemoteAddr reports the client address recorded for the session.
+// It returns the first non-temporary AcceptStream error, or the QueueConn
+// error if the listener has been closed.
+func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error {
+	// Look up the IP address associated with this KCP session, via the
+	// ClientID that is returned by the session's RemoteAddr method.
+	addr, ok := clientIDAddrMap.Get(conn.RemoteAddr().(turbotunnel.ClientID))
+	if !ok {
+		// This means that the map is tending to run over capacity, not
+		// just that there was not client_ip on the incoming connection.
+		// We store "" in the map in the absence of client_ip. This log
+		// message means you should increase clientIDAddrMapCapacity.
+		log.Printf("no address in clientID-to-IP map (capacity %d)", clientIDAddrMapCapacity)
+	}
+
+	smuxConfig := smux.DefaultConfig()
+	smuxConfig.Version = 2
+	smuxConfig.KeepAliveTimeout = 10 * time.Minute
+	sess, err := smux.Server(conn, smuxConfig)
+	if err != nil {
+		return err
+	}
+	defer sess.Close()
+
+	for {
+		stream, err := sess.AcceptStream()
+		if err != nil {
+			if err, ok := err.(net.Error); ok && err.Temporary() {
+				continue
+			}
+			return err
+		}
+		// The map holds the already-formatted "IP:port" string produced by
+		// clientAddr (or ""), so wrap it directly. Running it through
+		// clientAddr again would fail to parse it as a bare IP and
+		// silently discard the client's address.
+		err = l.QueueConn(&SnowflakeClientConn{Conn: stream, address: ClientMapAddr(addr)})
+		if err != nil {
+			// The listener is closed and nobody will Accept this stream;
+			// close it rather than leak it. Its Close error is irrelevant.
+			stream.Close()
+			return err
+		}
+	}
+}
+
+// acceptSessions accepts KCP sessions from ln, tunes each one, and serves its
+// smux streams with acceptStreams in a separate goroutine that closes the
+// session when acceptStreams returns. The packets behind these sessions are
+// supplied by handler.ServeHTTP, through the QueuePacketConn that ln was
+// built on. It returns the first non-temporary error from AcceptKCP.
+func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error {
+	for {
+		sess, err := ln.AcceptKCP()
+		if err != nil {
+			if ne, ok := err.(net.Error); ok && ne.Temporary() {
+				continue
+			}
+			return err
+		}
+
+		// Let KCP coalesce the payloads of consecutive sends.
+		sess.SetStreamMode(true)
+		// Use large send and receive windows to remove KCP bottlenecks:
+		// https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
+		sess.SetWindowSize(65535, 65535)
+		// Keep the default nodelay, interval, and resend settings, but pass
+		// nc=1 to turn off the dynamic congestion window, so throughput is
+		// limited only by the maximum of the local and remote static windows.
+		sess.SetNoDelay(0, 0, 0, 1)
+
+		go func(s *kcp.UDPSession) {
+			defer s.Close()
+			if err := l.acceptStreams(s); err != nil && err != io.ErrClosedPipe {
+				log.Printf("acceptStreams: %v", err)
+			}
+		}(sess)
+	}
+}
+
+// QueueConn hands conn to the listener's accept queue, blocking while the
+// queue is full. If the listener is closed while it waits, it returns an
+// error instead. When both outcomes are possible at once, select picks either.
+func (l *SnowflakeListener) QueueConn(conn net.Conn) error {
+	select {
+	case l.queue <- conn:
+		return nil
+	case <-l.closed:
+		return fmt.Errorf("accepted connection on closed listener")
+	}
+}
+
+// SnowflakeClientConn wraps the net.Conn of a oneshot connection or a
+// turbotunnel stream so that RemoteAddr reports the client address recorded
+// for it, instead of the underlying connection's own address. All other
+// methods go to the wrapped connection.
+type SnowflakeClientConn struct {
+	net.Conn
+	address net.Addr
+}
+
+// RemoteAddr returns the client address recorded for this connection.
+func (c *SnowflakeClientConn) RemoteAddr() net.Addr { return c.address }
diff --git a/server/turbotunnel.go b/server/lib/turbotunnel.go
similarity index 99%
rename from server/turbotunnel.go
rename to server/lib/turbotunnel.go
index 1d00897..bb16fa3 100644
--- a/server/turbotunnel.go
+++ b/server/lib/turbotunnel.go
@@ -1,4 +1,4 @@
-package main
+package lib
 
 import (
 	"sync"
diff --git a/server/turbotunnel_test.go b/server/lib/turbotunnel_test.go
similarity index 99%
rename from server/turbotunnel_test.go
rename to server/lib/turbotunnel_test.go
index c4bf02b..ba4cf60 100644
--- a/server/turbotunnel_test.go
+++ b/server/lib/turbotunnel_test.go
@@ -1,4 +1,4 @@
-package main
+package lib
 
 import (
 	"encoding/binary"
diff --git a/server/server.go b/server/server.go
index 620cd50..b61d5b4 100644
--- a/server/server.go
+++ b/server/server.go
@@ -3,9 +3,6 @@
 package main
 
 import (
-	"bufio"
-	"bytes"
-	"crypto/tls"
 	"flag"
 	"fmt"
 	"io"
@@ -19,38 +16,15 @@ import (
 	"strings"
 	"sync"
 	"syscall"
-	"time"
 
-	pt "git.torproject.org/pluggable-transports/goptlib.git"
-	"git.torproject.org/pluggable-transports/snowflake.git/common/encapsulation"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
-	"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
-	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
-	"github.com/gorilla/websocket"
-	"github.com/xtaci/kcp-go/v5"
-	"github.com/xtaci/smux"
 	"golang.org/x/crypto/acme/autocert"
-	"golang.org/x/net/http2"
+
+	pt "git.torproject.org/pluggable-transports/goptlib.git"
+	sf "git.torproject.org/pluggable-transports/snowflake.git/server/lib"
 )
 
 const ptMethodName = "snowflake"
-const requestTimeout = 10 * time.Second
-
-// How long to remember outgoing packets for a client, when we don't currently
-// have an active WebSocket connection corresponding to that client. Because a
-// client session may span multiple WebSocket connections, we keep packets we
-// aren't able to send immediately in memory, for a little while but not
-// indefinitely.
-const clientMapTimeout = 1 * time.Minute
-
-// How big to make the map of ClientIDs to IP addresses. The map is used in
-// turbotunnelMode to store a reasonable IP address for a client session that
-// may outlive any single WebSocket connection.
-const clientIDAddrMapCapacity = 1024
-
-// How long to wait for ListenAndServe or ListenAndServeTLS to return an error
-// before deciding that it's not going to return.
-const listenAndServeErrorTimeout = 100 * time.Millisecond
 
 var ptInfo pt.ServerInfo
 
@@ -92,366 +66,30 @@ func proxy(local *net.TCPConn, conn net.Conn) {
 	wg.Wait()
 }
 
-// Return an address string suitable to pass into pt.DialOr.
-func clientAddr(clientIPParam string) string {
-	if clientIPParam == "" {
-		return ""
-	}
-	// Check if client addr is a valid IP
-	clientIP := net.ParseIP(clientIPParam)
-	if clientIP == nil {
-		return ""
-	}
-	// Check if client addr is 0.0.0.0 or [::]. Some proxies erroneously
-	// report an address of 0.0.0.0: https://bugs.torproject.org/33157.
-	if clientIP.IsUnspecified() {
-		return ""
-	}
-	// Add a dummy port number. USERADDR requires a port number.
-	return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
-}
-
-var upgrader = websocket.Upgrader{
-	CheckOrigin: func(r *http.Request) bool { return true },
-}
-
-// clientIDAddrMap stores short-term mappings from ClientIDs to IP addresses.
-// When we call pt.DialOr, tor wants us to provide a USERADDR string that
-// represents the remote IP address of the client (for metrics purposes, etc.).
-// This data structure bridges the gap between ServeHTTP, which knows about IP
-// addresses, and handleStream, which is what calls pt.DialOr. The common piece
-// of information linking both ends of the chain is the ClientID, which is
-// attached to the WebSocket connection and every session.
-var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity)
-
-// overrideReadConn is a net.Conn with an overridden Read method. Compare to
-// recordingConn at
-// https://dave.cheney.net/2015/05/22/struct-composition-with-go.
-type overrideReadConn struct {
-	net.Conn
-	io.Reader
-}
-
-func (conn *overrideReadConn) Read(p []byte) (int, error) {
-	return conn.Reader.Read(p)
-}
-
-type HTTPHandler struct {
-	// pconn is the adapter layer between stream-oriented WebSocket
-	// connections and the packet-oriented KCP layer.
-	pconn *turbotunnel.QueuePacketConn
-}
-
-func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ws, err := upgrader.Upgrade(w, r, nil)
-	if err != nil {
-		log.Println(err)
-		return
-	}
-
-	conn := websocketconn.New(ws)
-	defer conn.Close()
-
-	// Pass the address of client as the remote address of incoming connection
-	clientIPParam := r.URL.Query().Get("client_ip")
-	addr := clientAddr(clientIPParam)
-
-	var token [len(turbotunnel.Token)]byte
-	_, err = io.ReadFull(conn, token[:])
-	if err != nil {
-		// Don't bother logging EOF: that happens with an unused
-		// connection, which clients make frequently as they maintain a
-		// pool of proxies.
-		if err != io.EOF {
-			log.Printf("reading token: %v", err)
-		}
-		return
-	}
-
-	switch {
-	case bytes.Equal(token[:], turbotunnel.Token[:]):
-		err = turbotunnelMode(conn, addr, handler.pconn)
-	default:
-		// We didn't find a matching token, which means that we are
-		// dealing with a client that doesn't know about such things.
-		// "Unread" the token by constructing a new Reader and pass it
-		// to the old one-session-per-WebSocket mode.
-		conn2 := &overrideReadConn{Conn: conn, Reader: io.MultiReader(bytes.NewReader(token[:]), conn)}
-		err = oneshotMode(conn2, addr)
-	}
-	if err != nil {
-		log.Println(err)
-		return
-	}
-}
-
-// oneshotMode handles clients that did not send turbotunnel.Token at the start
-// of their stream. These clients use the WebSocket as a raw pipe, and expect
-// their session to begin and end when this single WebSocket does.
-func oneshotMode(conn net.Conn, addr string) error {
-	statsChannel <- addr != ""
-	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-	if err != nil {
-		return fmt.Errorf("failed to connect to ORPort: %s", err)
-	}
-	defer or.Close()
-
-	proxy(or, conn)
-
-	return nil
-}
-
-// turbotunnelMode handles clients that sent turbotunnel.Token at the start of
-// their stream. These clients expect to send and receive encapsulated packets,
-// with a long-lived session identified by ClientID.
-func turbotunnelMode(conn net.Conn, addr string, pconn *turbotunnel.QueuePacketConn) error {
-	// Read the ClientID prefix. Every packet encapsulated in this WebSocket
-	// connection pertains to the same ClientID.
-	var clientID turbotunnel.ClientID
-	_, err := io.ReadFull(conn, clientID[:])
-	if err != nil {
-		return fmt.Errorf("reading ClientID: %v", err)
-	}
-
-	// Store a a short-term mapping from the ClientID to the client IP
-	// address attached to this WebSocket connection. tor will want us to
-	// provide a client IP address when we call pt.DialOr. But a KCP session
-	// does not necessarily correspond to any single IP address--it's
-	// composed of packets that are carried in possibly multiple WebSocket
-	// streams. We apply the heuristic that the IP address of the most
-	// recent WebSocket connection that has had to do with a session, at the
-	// time the session is established, is the IP address that should be
-	// credited for the entire KCP session.
-	clientIDAddrMap.Set(clientID, addr)
-
-	errCh := make(chan error)
-
-	// The remainder of the WebSocket stream consists of encapsulated
-	// packets. We read them one by one and feed them into the
-	// QueuePacketConn on which kcp.ServeConn was set up, which eventually
-	// leads to KCP-level sessions in the acceptSessions function.
-	go func() {
-		for {
-			p, err := encapsulation.ReadData(conn)
-			if err != nil {
-				errCh <- err
-				break
-			}
-			pconn.QueueIncoming(p, clientID)
-		}
-	}()
-
-	// At the same time, grab packets addressed to this ClientID and
-	// encapsulate them into the downstream.
-	go func() {
-		// Buffer encapsulation.WriteData operations to keep length
-		// prefixes in the same send as the data that follows.
-		bw := bufio.NewWriter(conn)
-		for p := range pconn.OutgoingQueue(clientID) {
-			_, err := encapsulation.WriteData(bw, p)
-			if err == nil {
-				err = bw.Flush()
-			}
-			if err != nil {
-				errCh <- err
-				break
-			}
-		}
-	}()
-
-	// Wait until one of the above loops terminates. The closing of the
-	// WebSocket connection will terminate the other one.
-	<-errCh
-
-	return nil
-}
-
-// handleStream bidirectionally connects a client stream with the ORPort.
-func handleStream(stream net.Conn, addr string) error {
-	statsChannel <- addr != ""
-	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-	if err != nil {
-		return fmt.Errorf("connecting to ORPort: %v", err)
-	}
-	defer or.Close()
-
-	proxy(or, stream)
-
-	return nil
-}
-
-// acceptStreams layers an smux.Session on the KCP connection and awaits streams
-// on it. Passes each stream to handleStream.
-func acceptStreams(conn *kcp.UDPSession) error {
-	// Look up the IP address associated with this KCP session, via the
-	// ClientID that is returned by the session's RemoteAddr method.
-	addr, ok := clientIDAddrMap.Get(conn.RemoteAddr().(turbotunnel.ClientID))
-	if !ok {
-		// This means that the map is tending to run over capacity, not
-		// just that there was not client_ip on the incoming connection.
-		// We store "" in the map in the absence of client_ip. This log
-		// message means you should increase clientIDAddrMapCapacity.
-		log.Printf("no address in clientID-to-IP map (capacity %d)", clientIDAddrMapCapacity)
-	}
-
-	smuxConfig := smux.DefaultConfig()
-	smuxConfig.Version = 2
-	smuxConfig.KeepAliveTimeout = 10 * time.Minute
-	sess, err := smux.Server(conn, smuxConfig)
-	if err != nil {
-		return err
-	}
-
+// acceptLoop accepts snowflake client connections from ln until Accept fails
+// with a non-temporary error. Each connection is proxied to tor's ORPort in
+// its own goroutine, and that goroutine closes the connection and its ORPort
+// connection as soon as the proxy finishes.
+func acceptLoop(ln net.Listener) {
 	for {
-		stream, err := sess.AcceptStream()
+		conn, err := ln.Accept()
 		if err != nil {
 			if err, ok := err.(net.Error); ok && err.Temporary() {
 				continue
 			}
-			return err
+			log.Printf("Snowflake accept error: %s", err)
+			break
 		}
-		go func() {
-			defer stream.Close()
-			err := handleStream(stream, addr)
-			if err != nil {
-				log.Printf("handleStream: %v", err)
-			}
-		}()
-	}
-}
+		// Serve each connection in its own goroutine. Dialing the ORPort
+		// must not stall the accept loop, and the deferred Close calls
+		// must run when this connection is done, not when acceptLoop
+		// returns (a defer directly in the loop body would leak every
+		// connection).
+		go func() {
+			defer conn.Close()
 
-// acceptSessions listens for incoming KCP connections and passes them to
-// acceptStreams. It is handler.ServeHTTP that provides the network interface
-// that drives this function.
-func acceptSessions(ln *kcp.Listener) error {
-	for {
-		conn, err := ln.AcceptKCP()
-		if err != nil {
-			if err, ok := err.(net.Error); ok && err.Temporary() {
-				continue
-			}
-			return err
-		}
-		// Permit coalescing the payloads of consecutive sends.
-		conn.SetStreamMode(true)
-		// Set the maximum send and receive window sizes to a high number
-		// Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
-		conn.SetWindowSize(65535, 65535)
-		// Disable the dynamic congestion window (limit only by the
-		// maximum of local and remote static windows).
-		conn.SetNoDelay(
-			0, // default nodelay
-			0, // default interval
-			0, // default resend
-			1, // nc=1 => congestion window off
-		)
-		go func() {
-			defer conn.Close()
-			err := acceptStreams(conn)
-			if err != nil && err != io.ErrClosedPipe {
-				log.Printf("acceptStreams: %v", err)
-			}
-		}()
+			addr := conn.RemoteAddr().String()
+			statsChannel <- addr != ""
+			or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
+			if err != nil {
+				log.Printf("failed to connect to ORPort: %s", err)
+				return
+			}
+			defer or.Close()
+
+			proxy(or, conn)
+		}()
 	}
 }
 
-func initServer(addr *net.TCPAddr,
-	getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error),
-	listenAndServe func(*http.Server, chan<- error)) (*http.Server, error) {
-	// We're not capable of listening on port 0 (i.e., an ephemeral port
-	// unknown in advance). The reason is that while the net/http package
-	// exposes ListenAndServe and ListenAndServeTLS, those functions never
-	// return, so there's no opportunity to find out what the port number
-	// is, in between the Listen and Serve steps.
-	// https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
-	if addr.Port == 0 {
-		return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
-	}
-
-	handler := HTTPHandler{
-		// pconn is shared among all connections to this server. It
-		// overlays packet-based client sessions on top of ephemeral
-		// WebSocket connections.
-		pconn: turbotunnel.NewQueuePacketConn(addr, clientMapTimeout),
-	}
-	server := &http.Server{
-		Addr:        addr.String(),
-		Handler:     &handler,
-		ReadTimeout: requestTimeout,
-	}
-	// We need to override server.TLSConfig.GetCertificate--but first
-	// server.TLSConfig needs to be non-nil. If we just create our own new
-	// &tls.Config, it will lack the default settings that the net/http
-	// package sets up for things like HTTP/2. Therefore we first call
-	// http2.ConfigureServer for its side effect of initializing
-	// server.TLSConfig properly. An alternative would be to make a dummy
-	// net.Listener, call Serve on it, and let it return.
-	// https://github.com/golang/go/issues/16588#issuecomment-237386446
-	err := http2.ConfigureServer(server, nil)
-	if err != nil {
-		return server, err
-	}
-	server.TLSConfig.GetCertificate = getCertificate
-
-	// Another unfortunate effect of the inseparable net/http ListenAndServe
-	// is that we can't check for Listen errors like "permission denied" and
-	// "address already in use" without potentially entering the infinite
-	// loop of Serve. The hack we apply here is to wait a short time,
-	// listenAndServeErrorTimeout, to see if an error is returned (because
-	// it's better if the error message goes to the tor log through
-	// SMETHOD-ERROR than if it only goes to the snowflake log).
-	errChan := make(chan error)
-	go listenAndServe(server, errChan)
-	select {
-	case err = <-errChan:
-		break
-	case <-time.After(listenAndServeErrorTimeout):
-		break
-	}
-
-	// Start a KCP engine, set up to read and write its packets over the
-	// WebSocket connections that arrive at the web server.
-	// handler.ServeHTTP is responsible for encapsulation/decapsulation of
-	// packets on behalf of KCP. KCP takes those packets and turns them into
-	// sessions which appear in the acceptSessions function.
-	ln, err := kcp.ServeConn(nil, 0, 0, handler.pconn)
-	if err != nil {
-		server.Close()
-		return server, err
-	}
-	go func() {
-		defer ln.Close()
-		err := acceptSessions(ln)
-		if err != nil {
-			log.Printf("acceptSessions: %v", err)
-		}
-	}()
-
-	return server, err
-}
-
-func startServer(addr *net.TCPAddr) (*http.Server, error) {
-	return initServer(addr, nil, func(server *http.Server, errChan chan<- error) {
-		log.Printf("listening with plain HTTP on %s", addr)
-		err := server.ListenAndServe()
-		if err != nil {
-			log.Printf("error in ListenAndServe: %s", err)
-		}
-		errChan <- err
-	})
-}
-
-func startServerTLS(addr *net.TCPAddr, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) (*http.Server, error) {
-	return initServer(addr, getCertificate, func(server *http.Server, errChan chan<- error) {
-		log.Printf("listening with HTTPS on %s", addr)
-		err := server.ListenAndServeTLS("", "")
-		if err != nil {
-			log.Printf("error in ListenAndServeTLS: %s", err)
-		}
-		errChan <- err
-	})
-}
-
 func getCertificateCacheDir() (string, error) {
 	stateDir, err := pt.MakeStateDir()
 	if err != nil {
@@ -535,7 +173,7 @@ func main() {
 	// https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
 	needHTTP01Listener := !disableTLS
 
-	servers := make([]*http.Server, 0)
+	listeners := make([]net.Listener, 0)
 	for _, bindaddr := range ptInfo.Bindaddrs {
 		if bindaddr.MethodName != ptMethodName {
 			pt.SmethodError(bindaddr.MethodName, "no such method")
@@ -560,29 +198,47 @@ func main() {
 			go func() {
 				log.Fatal(server.Serve(lnHTTP01))
 			}()
-			servers = append(servers, server)
+			listeners = append(listeners, lnHTTP01)
 			needHTTP01Listener = false
 		}
 
-		var server *http.Server
+		// We're not capable of listening on port 0 (i.e., an ephemeral port
+		// unknown in advance). The reason is that while the net/http package
+		// exposes ListenAndServe and ListenAndServeTLS, those functions never
+		// return, so there's no opportunity to find out what the port number
+		// is, in between the Listen and Serve steps.
+		// https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
+		if bindaddr.Addr.Port == 0 {
+			err := fmt.Errorf(
+				"cannot listen on port %d; configure a port using ServerTransportListenAddr",
+				bindaddr.Addr.Port)
+			log.Printf("error opening listener: %s", err)
+			pt.SmethodError(bindaddr.MethodName, err.Error())
+			continue
+		}
+
+		var transport *sf.Transport
 		args := pt.Args{}
 		if disableTLS {
 			args.Add("tls", "no")
-			server, err = startServer(bindaddr.Addr)
+			transport = sf.NewSnowflakeServer(nil)
 		} else {
 			args.Add("tls", "yes")
 			for _, hostname := range acmeHostnames {
 				args.Add("hostname", hostname)
 			}
-			server, err = startServerTLS(bindaddr.Addr, certManager.GetCertificate)
+			transport = sf.NewSnowflakeServer(certManager.GetCertificate)
 		}
+		ln, err := transport.Listen(bindaddr.Addr)
 		if err != nil {
 			log.Printf("error opening listener: %s", err)
 			pt.SmethodError(bindaddr.MethodName, err.Error())
 			continue
 		}
+		defer ln.Close()
+		go acceptLoop(ln)
 		pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
-		servers = append(servers, server)
+		listeners = append(listeners, ln)
 	}
 	pt.SmethodsDone()
 
@@ -606,7 +262,7 @@ func main() {
 
 	// Signal received, shut down.
 	log.Printf("caught signal %q, exiting", sig)
-	for _, server := range servers {
-		server.Close()
+	for _, ln := range listeners {
+		ln.Close()
 	}
 }
diff --git a/server/server_test.go b/server/server_test.go
deleted file mode 100644
index ba00d16..0000000
--- a/server/server_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package main
-
-import (
-	"net"
-	"net/http"
-	"strconv"
-	"testing"
-
-	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
-	"github.com/gorilla/websocket"
-	. "github.com/smartystreets/goconvey/convey"
-)
-
-func TestClientAddr(t *testing.T) {
-	Convey("Testing clientAddr", t, func() {
-		// good tests
-		for _, test := range []struct {
-			input    string
-			expected net.IP
-		}{
-			{"1.2.3.4", net.ParseIP("1.2.3.4")},
-			{"1:2::3:4", net.ParseIP("1:2::3:4")},
-		} {
-			useraddr := clientAddr(test.input)
-			host, port, err := net.SplitHostPort(useraddr)
-			if err != nil {
-				t.Errorf("clientAddr(%q) → SplitHostPort error %v", test.input, err)
-				continue
-			}
-			if !test.expected.Equal(net.ParseIP(host)) {
-				t.Errorf("clientAddr(%q) → host %q, not %v", test.input, host, test.expected)
-			}
-			portNo, err := strconv.Atoi(port)
-			if err != nil {
-				t.Errorf("clientAddr(%q) → port %q", test.input, port)
-				continue
-			}
-			if portNo == 0 {
-				t.Errorf("clientAddr(%q) → port %d", test.input, portNo)
-			}
-		}
-
-		// bad tests
-		for _, input := range []string{
-			"",
-			"abc",
-			"1.2.3.4.5",
-			"[12::34]",
-			"0.0.0.0",
-			"[::]",
-		} {
-			useraddr := clientAddr(input)
-			if useraddr != "" {
-				t.Errorf("clientAddr(%q) → %q, not %q", input, useraddr, "")
-			}
-		}
-	})
-}
-
-type StubHandler struct{}
-
-func (handler *StubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ws, _ := upgrader.Upgrade(w, r, nil)
-
-	conn := websocketconn.New(ws)
-	defer conn.Close()
-
-	//dial stub OR
-	or, _ := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889})
-
-	proxy(or, conn)
-}
-
-func Test(t *testing.T) {
-	Convey("Websocket server", t, func() {
-		//Set up the snowflake web server
-		ipStr, portStr, _ := net.SplitHostPort(":8888")
-		port, _ := strconv.ParseUint(portStr, 10, 16)
-		addr := &net.TCPAddr{IP: net.ParseIP(ipStr), Port: int(port)}
-		Convey("We don't listen on port 0", func() {
-			addr = &net.TCPAddr{IP: net.ParseIP(ipStr), Port: 0}
-			server, err := initServer(addr, nil,
-				func(server *http.Server, errChan chan<- error) {
-					return
-				})
-			So(err, ShouldNotBeNil)
-			So(server, ShouldBeNil)
-		})
-
-		Convey("Plain HTTP server accepts connections", func(c C) {
-			server, err := startServer(addr)
-			So(err, ShouldBeNil)
-
-			ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
-			wsConn := websocketconn.New(ws)
-			So(err, ShouldEqual, nil)
-			So(wsConn, ShouldNotEqual, nil)
-
-			server.Close()
-			wsConn.Close()
-
-		})
-		Convey("Handler proxies data", func(c C) {
-
-			laddr := &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889}
-
-			go func() {
-
-				//stub OR
-				listener, err := net.ListenTCP("tcp", laddr)
-				c.So(err, ShouldBeNil)
-				conn, err := listener.Accept()
-				c.So(err, ShouldBeNil)
-
-				b := make([]byte, 5)
-				n, err := conn.Read(b)
-				c.So(err, ShouldBeNil)
-				c.So(n, ShouldEqual, 5)
-				c.So(b, ShouldResemble, []byte("Hello"))
-
-				n, err = conn.Write([]byte("world!"))
-				c.So(n, ShouldEqual, 6)
-				c.So(err, ShouldBeNil)
-			}()
-
-			//overwite handler
-			server, err := initServer(addr, nil,
-				func(server *http.Server, errChan chan<- error) {
-					server.ListenAndServe()
-				})
-			So(err, ShouldBeNil)
-
-			var handler StubHandler
-			server.Handler = &handler
-
-			ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
-			So(err, ShouldEqual, nil)
-			wsConn := websocketconn.New(ws)
-			So(wsConn, ShouldNotEqual, nil)
-
-			wsConn.Write([]byte("Hello"))
-			b := make([]byte, 6)
-			n, err := wsConn.Read(b)
-			So(n, ShouldEqual, 6)
-			So(b, ShouldResemble, []byte("world!"))
-
-			wsConn.Close()
-			server.Close()
-
-		})
-
-	})
-}



More information about the tor-commits mailing list