[tor-commits] [snowflake/main] Turn the proxy code into a library
cohosh at torproject.org
cohosh at torproject.org
Tue Oct 26 18:17:37 UTC 2021
commit 50e4f4fd61596bab254cb34e850c9ae63d82f891
Author: idk <hankhill19580 at gmail.com>
Date: Mon Oct 25 22:51:40 2021 -0400
Turn the proxy code into a library
Allow other Go programs to easily import the Snowflake proxy library and
start/stop a Snowflake proxy.
---
proxy/{ => lib}/proxy-go_test.go | 8 +-
proxy/{ => lib}/snowflake.go | 205 ++++++++++++++++++++++-----------------
proxy/{ => lib}/tokens.go | 2 +-
proxy/{ => lib}/tokens_test.go | 2 +-
proxy/{ => lib}/util.go | 18 +++-
proxy/{ => lib}/webrtcconn.go | 2 +-
proxy/main.go | 48 +++++++++
7 files changed, 185 insertions(+), 100 deletions(-)
diff --git a/proxy/proxy-go_test.go b/proxy/lib/proxy-go_test.go
similarity index 98%
rename from proxy/proxy-go_test.go
rename to proxy/lib/proxy-go_test.go
index 6fb5a0b9..af71648 100644
--- a/proxy/proxy-go_test.go
+++ b/proxy/lib/proxy-go_test.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
import (
"bytes"
@@ -365,7 +365,7 @@ func TestBrokerInteractions(t *testing.T) {
b,
}
- sdp := broker.pollOffer(sampleOffer)
+ sdp := broker.pollOffer(sampleOffer, nil)
expectedSDP, _ := strconv.Unquote(sampleSDP)
So(sdp.SDP, ShouldResemble, expectedSDP)
})
@@ -379,7 +379,7 @@ func TestBrokerInteractions(t *testing.T) {
b,
}
- sdp := broker.pollOffer(sampleOffer)
+ sdp := broker.pollOffer(sampleOffer, nil)
So(sdp, ShouldBeNil)
})
Convey("sends answer to broker", func() {
@@ -478,7 +478,7 @@ func TestUtilityFuncs(t *testing.T) {
Convey("CopyLoop", t, func() {
c1, s1 := net.Pipe()
c2, s2 := net.Pipe()
- go CopyLoop(s1, s2)
+ go copyLoop(s1, s2, nil)
go func() {
bytes := []byte("Hello!")
c1.Write(bytes)
diff --git a/proxy/snowflake.go b/proxy/lib/snowflake.go
similarity index 72%
rename from proxy/snowflake.go
rename to proxy/lib/snowflake.go
index 7d7f9a2..e35eabd 100644
--- a/proxy/snowflake.go
+++ b/proxy/lib/snowflake.go
@@ -1,10 +1,9 @@
-package main
+package snowflake
import (
"bytes"
"crypto/rand"
"encoding/base64"
- "flag"
"fmt"
"io"
"io/ioutil"
@@ -12,27 +11,44 @@ import (
"net"
"net/http"
"net/url"
- "os"
"strings"
"sync"
"time"
"git.torproject.org/pluggable-transports/snowflake.git/common/messages"
- "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/common/util"
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v3"
)
-const defaultBrokerURL = "https://snowflake-broker.torproject.net/"
-const defaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe"
-const defaultRelayURL = "wss://snowflake.torproject.net/"
-const defaultSTUNURL = "stun:stun.stunprotocol.org:3478"
+// DefaultBrokerURL is the torproject.net broker, https://snowflake-broker.torproject.net/
+// Changing this will change the default broker. The recommended way of changing
+// the broker that gets used is by setting SnowflakeProxy.RawBrokerURL.
+const DefaultBrokerURL = "https://snowflake-broker.torproject.net/"
+
+// DefaultProbeURL is the torproject.net NAT probe, https://snowflake-broker.torproject.net:8443/probe
+// SnowflakeProxy has no field for it: Start always passes this value to
+// checkNATType when determining the proxy's NAT type.
+const DefaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe"
+
+// DefaultRelayURL is the torproject.net Websocket Relay, wss://snowflake.torproject.net/
+// Changing this will change the default Relay URL. The recommended way of changing
+// the relay that gets used is by setting SnowflakeProxy.RelayURL.
+const DefaultRelayURL = "wss://snowflake.torproject.net/"
+
+// DefaultSTUNURL is a stunprotocol.org STUN URL, stun:stun.stunprotocol.org:3478
+// Changing this will change the default STUN URL. The recommended way of changing
+// the STUN server that gets used is by setting SnowflakeProxy.StunURL.
+const DefaultSTUNURL = "stun:stun.stunprotocol.org:3478"
const pollInterval = 5 * time.Second
+
const (
- NATUnknown = "unknown"
- NATRestricted = "restricted"
+ // NATUnknown is the initial NAT type reported to the broker (see currentNATType).
+ NATUnknown = "unknown"
+ // NATRestricted represents a restricted NAT.
+ NATRestricted = "restricted"
+ // NATUnrestricted represents an unrestricted NAT.
NATUnrestricted = "unrestricted"
)
@@ -43,7 +59,6 @@ const dataChannelTimeout = 20 * time.Second
const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
var broker *SignalingServer
-var relayURL string
var currentNATType = NATUnknown
@@ -57,6 +72,18 @@ var (
client http.Client
)
+// SnowflakeProxy configures a Snowflake proxy embedded in another Go
+// application; fill in the exported fields, then call Start (and later Stop).
+type SnowflakeProxy struct {
+ Capacity uint // maximum concurrent clients; passed to newTokens by Start
+ StunURL string // used by Start as the only WebRTC ICE server
+ RawBrokerURL string // broker that Start polls for client offers (via newSignalingServer)
+ KeepLocalAddresses bool // keep local LAN address ICE candidates (passed to newSignalingServer)
+ RelayURL string // websocket relay that each client data channel is copied to
+ LogOutput io.Writer // NOTE(review): not read by Start in this change; callers set log.SetOutput themselves, as main.go does
+ shutdown chan struct{} // made by Start, closed by Stop
+}
+
// Checks whether an IP address is a remote address for the client
func isRemoteAddress(ip net.IP) bool {
return !(util.IsLocal(ip) || ip.IsUnspecified() || ip.IsLoopback())
@@ -81,6 +108,7 @@ func limitedRead(r io.Reader, limit int64) ([]byte, error) {
return p, err
}
+// SignalingServer keeps track of the SignalingServer in use by the Snowflake
type SignalingServer struct {
url *url.URL
transport http.RoundTripper
@@ -102,6 +130,7 @@ func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServe
return s, nil
}
+// Post sends a POST request to the SignalingServer
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
req, err := http.NewRequest("POST", path, payload)
@@ -121,7 +150,7 @@ func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
return limitedRead(resp.Body, readLimit)
}
-func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
+func (s *SignalingServer) pollOffer(sid string, shutdown chan struct{}) *webrtc.SessionDescription {
brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"})
ticker := time.NewTicker(pollInterval)
@@ -129,31 +158,36 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
// Run the loop once before hitting the ticker
for ; true; <-ticker.C {
- numClients := int((tokens.count() / 8) * 8) // Round down to 8
- body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
- if err != nil {
- log.Printf("Error encoding poll message: %s", err.Error())
+ select {
+ case <-shutdown:
return nil
- }
- resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body))
- if err != nil {
- log.Printf("error polling broker: %s", err.Error())
- }
+ default:
+ numClients := int((tokens.count() / 8) * 8) // Round down to 8
+ body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
+ if err != nil {
+ log.Printf("Error encoding poll message: %s", err.Error())
+ return nil
+ }
+ resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body))
+ if err != nil {
+ log.Printf("error polling broker: %s", err.Error())
+ }
- offer, _, err := messages.DecodePollResponse(resp)
- if err != nil {
- log.Printf("Error reading broker response: %s", err.Error())
- log.Printf("body: %s", resp)
- return nil
- }
- if offer != "" {
- offer, err := util.DeserializeSessionDescription(offer)
+ offer, _, err := messages.DecodePollResponse(resp)
if err != nil {
- log.Printf("Error processing session description: %s", err.Error())
+ log.Printf("Error reading broker response: %s", err.Error())
+ log.Printf("body: %s", resp)
return nil
}
- return offer
+ if offer != "" {
+ offer, err := util.DeserializeSessionDescription(offer)
+ if err != nil {
+ log.Printf("Error processing session description: %s", err.Error())
+ return nil
+ }
+ return offer
+ }
}
}
return nil
@@ -192,33 +226,41 @@ func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) erro
return nil
}
-func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
- var wg sync.WaitGroup
+func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct{}) { // copies both ways until either direction ends or shutdown is closed, then closes c1 and c2
+ var once sync.Once // both copyers reach close(done); only the first may close it
+ defer c2.Close() // closing both ends on return presumably unblocks the copyer still running
+ defer c1.Close()
+ done := make(chan struct{}) // closed by whichever copyer finishes first
copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
- defer wg.Done()
// Ignore io.ErrClosedPipe because it is likely caused by the
// termination of copyer in the other direction.
if _, err := io.Copy(dst, src); err != nil && err != io.ErrClosedPipe {
log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
}
- dst.Close()
- src.Close()
+ once.Do(func() {
+ close(done)
+ })
}
- wg.Add(2)
+
go copyer(c1, c2)
go copyer(c2, c1)
- wg.Wait()
+
+ select {
+ case <-done:
+ case <-shutdown: // a nil shutdown (as passed by the tests) never fires
+ }
+ log.Println("copy loop ended")
}
// We pass conn.RemoteAddr() as an additional parameter, rather than calling
// conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside of conn.pc.RemoteDescription() (called by
// RemoteAddr). https://bugs.torproject.org/18628#comment:8
-func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
+func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
defer conn.Close()
defer tokens.ret()
- u, err := url.Parse(relayURL)
+ u, err := url.Parse(sf.RelayURL)
if err != nil {
log.Fatalf("invalid relay url: %s", err)
}
@@ -241,7 +283,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
wsConn := websocketconn.New(ws)
log.Printf("connected to relay")
defer wsConn.Close()
- CopyLoop(conn, wsConn)
+ copyLoop(conn, wsConn, sf.shutdown)
log.Printf("datachannelHandler ends")
}
@@ -249,7 +291,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
// candidates is complete and the answer is available in LocalDescription.
// Installs an OnDataChannel callback that creates a webRTCConn and passes it to
// datachannelHandler.
-func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
+func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
config webrtc.Configuration,
dataChan chan struct{},
handler func(conn *webRTCConn, remoteAddr net.Addr)) (*webrtc.PeerConnection, error) {
@@ -333,7 +375,7 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
// Create a new PeerConnection. Blocks until the gathering of ICE
// candidates is complete and the answer is available in LocalDescription.
-func makeNewPeerConnection(config webrtc.Configuration,
+func (sf *SnowflakeProxy) makeNewPeerConnection(config webrtc.Configuration,
dataChan chan struct{}) (*webrtc.PeerConnection, error) {
pc, err := webrtc.NewPeerConnection(config)
@@ -383,15 +425,15 @@ func makeNewPeerConnection(config webrtc.Configuration,
return pc, nil
}
-func runSession(sid string) {
- offer := broker.pollOffer(sid)
+func (sf *SnowflakeProxy) runSession(sid string) {
+ offer := broker.pollOffer(sid, sf.shutdown)
if offer == nil {
log.Printf("bad offer from broker")
tokens.ret()
return
}
dataChan := make(chan struct{})
- pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler)
+ pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, sf.datachannelHandler)
if err != nil {
log.Printf("error making WebRTC connection: %s", err)
tokens.ret()
@@ -421,53 +463,28 @@ func runSession(sid string) {
}
}
-func main() {
- var capacity uint
- var stunURL string
- var logFilename string
- var rawBrokerURL string
- var unsafeLogging bool
- var keepLocalAddresses bool
-
- flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients")
- flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
- flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
- flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
- flag.StringVar(&logFilename, "log", "", "log filename")
- flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
- flag.BoolVar(&keepLocalAddresses, "keep-local-addresses", false, "keep local LAN address ICE candidates")
- flag.Parse()
-
- var logOutput io.Writer = os.Stderr
+// Start configures and starts a Snowflake, fully formed and special. In the
+// case of an empty map, defaults are configured automatically and can be
+// found in the GoDoc and in main.go
+func (sf *SnowflakeProxy) Start() {
+
+ sf.shutdown = make(chan struct{})
+
log.SetFlags(log.LstdFlags | log.LUTC)
- if logFilename != "" {
- f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
- if err != nil {
- log.Fatal(err)
- }
- defer f.Close()
- logOutput = io.MultiWriter(os.Stderr, f)
- }
- if unsafeLogging {
- log.SetOutput(logOutput)
- } else {
- // We want to send the log output through our scrubber first
- log.SetOutput(&safelog.LogScrubber{Output: logOutput})
- }
log.Println("starting")
var err error
- broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses)
+ broker, err = newSignalingServer(sf.RawBrokerURL, sf.KeepLocalAddresses)
if err != nil {
log.Fatal(err)
}
- _, err = url.Parse(stunURL)
+ _, err = url.Parse(sf.StunURL)
if err != nil {
log.Fatalf("invalid stun url: %s", err)
}
- _, err = url.Parse(relayURL)
+ _, err = url.Parse(sf.RelayURL)
if err != nil {
log.Fatalf("invalid relay url: %s", err)
}
@@ -475,27 +492,37 @@ func main() {
config = webrtc.Configuration{
ICEServers: []webrtc.ICEServer{
{
- URLs: []string{stunURL},
+ URLs: []string{sf.StunURL},
},
},
}
- tokens = newTokens(capacity)
+ tokens = newTokens(sf.Capacity)
// use probetest to determine NAT compatability
- checkNATType(config, defaultProbeURL)
+ sf.checkNATType(config, DefaultProbeURL)
log.Printf("NAT type: %s", currentNATType)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for ; true; <-ticker.C {
- tokens.get()
- sessionID := genSessionID()
- runSession(sessionID)
+ select {
+ case <-sf.shutdown:
+ return
+ default:
+ tokens.get()
+ sessionID := genSessionID()
+ sf.runSession(sessionID)
+ }
}
}
-func checkNATType(config webrtc.Configuration, probeURL string) {
+// Stop closes sf.shutdown, signalling Start's poll loop and active sessions to stop.
+func (sf *SnowflakeProxy) Stop() {
+ close(sf.shutdown) // panics if called before Start (nil channel) or more than once
+}
+
+func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
probe, err := newSignalingServer(probeURL, false)
if err != nil {
@@ -504,7 +531,7 @@ func checkNATType(config webrtc.Configuration, probeURL string) {
// create offer
dataChan := make(chan struct{})
- pc, err := makeNewPeerConnection(config, dataChan)
+ pc, err := sf.makeNewPeerConnection(config, dataChan)
if err != nil {
log.Printf("error making WebRTC connection: %s", err)
return
diff --git a/proxy/tokens.go b/proxy/lib/tokens.go
similarity index 97%
rename from proxy/tokens.go
rename to proxy/lib/tokens.go
index fedb8f7..1331778 100644
--- a/proxy/tokens.go
+++ b/proxy/lib/tokens.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
import (
"sync/atomic"
diff --git a/proxy/tokens_test.go b/proxy/lib/tokens_test.go
similarity index 96%
rename from proxy/tokens_test.go
rename to proxy/lib/tokens_test.go
index 622cc05..702a887 100644
--- a/proxy/tokens_test.go
+++ b/proxy/lib/tokens_test.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
import (
"testing"
diff --git a/proxy/util.go b/proxy/lib/util.go
similarity index 71%
rename from proxy/util.go
rename to proxy/lib/util.go
index d737056..c6613d9 100644
--- a/proxy/util.go
+++ b/proxy/lib/util.go
@@ -1,21 +1,28 @@
-package main
+package snowflake
import (
"fmt"
"time"
)
+// BytesLogger records inbound and outbound byte counts and reports a
+// throughput summary. BytesNullLogger is a no-op implementation.
type BytesLogger interface {
AddOutbound(int)
AddInbound(int)
ThroughputSummary() string
}
-// Default BytesLogger does nothing.
+// BytesNullLogger is a BytesLogger whose methods do nothing.
type BytesNullLogger struct{}
-func (b BytesNullLogger) AddOutbound(amount int) {}
-func (b BytesNullLogger) AddInbound(amount int) {}
+// AddOutbound discards amount.
+func (b BytesNullLogger) AddOutbound(amount int) {}
+
+// AddInbound discards amount.
+func (b BytesNullLogger) AddInbound(amount int) {}
+
+// ThroughputSummary always returns the empty string.
func (b BytesNullLogger) ThroughputSummary() string { return "" }
// BytesSyncLogger uses channels to safely log from multiple sources with output
@@ -50,14 +57,17 @@ func (b *BytesSyncLogger) log() {
}
}
+// AddOutbound sends amount on outboundChan to be added to the outbound total.
func (b *BytesSyncLogger) AddOutbound(amount int) {
b.outboundChan <- amount
}
+// AddInbound sends amount on inboundChan to be added to the inbound total.
func (b *BytesSyncLogger) AddInbound(amount int) {
b.inboundChan <- amount
}
+// ThroughputSummary returns a formatted summary of the throughput totals.
func (b *BytesSyncLogger) ThroughputSummary() string {
var inUnit, outUnit string
units := []string{"B", "KB", "MB", "GB"}
diff --git a/proxy/webrtcconn.go b/proxy/lib/webrtcconn.go
similarity index 99%
rename from proxy/webrtcconn.go
rename to proxy/lib/webrtcconn.go
index 5d95919..5c6192b 100644
--- a/proxy/webrtcconn.go
+++ b/proxy/lib/webrtcconn.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
import (
"fmt"
diff --git a/proxy/main.go b/proxy/main.go
new file mode 100644
index 0000000..12b3752
--- /dev/null
+++ b/proxy/main.go
@@ -0,0 +1,48 @@
+package main
+
+import (
+ "flag"
+ "io"
+ "log"
+ "os"
+
+ "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+ "git.torproject.org/pluggable-transports/snowflake.git/proxy/lib"
+)
+
+// main parses command-line flags, configures log output and runs the proxy.
+func main() {
+ capacity := flag.Uint("capacity", 0, "maximum concurrent clients") // 0 is the default the proxy used before the library split
+ stunURL := flag.String("stun", snowflake.DefaultSTUNURL, "stun URL")
+ logFilename := flag.String("log", "", "log filename")
+ rawBrokerURL := flag.String("broker", snowflake.DefaultBrokerURL, "broker URL")
+ unsafeLogging := flag.Bool("unsafe-logging", false, "prevent logs from being scrubbed")
+ keepLocalAddresses := flag.Bool("keep-local-addresses", false, "keep local LAN address ICE candidates")
+ relayURL := flag.String("relay", snowflake.DefaultRelayURL, "websocket relay URL")
+ flag.Parse()
+
+ sf := snowflake.SnowflakeProxy{
+ Capacity: *capacity,
+ StunURL: *stunURL,
+ RawBrokerURL: *rawBrokerURL,
+ KeepLocalAddresses: *keepLocalAddresses,
+ RelayURL: *relayURL,
+ LogOutput: os.Stderr,
+ }
+
+ if *logFilename != "" {
+ f, err := os.OpenFile(*logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer f.Close()
+ sf.LogOutput = io.MultiWriter(os.Stderr, f)
+ }
+ if *unsafeLogging {
+ log.SetOutput(sf.LogOutput)
+ } else { // default: send log output through the scrubber first
+ log.SetOutput(&safelog.LogScrubber{Output: sf.LogOutput})
+ }
+
+ sf.Start()
+}
More information about the tor-commits mailing list