[tor-commits] [websocket/master] Stop setting GOPATH.
dcf at torproject.org
dcf at torproject.org
Fri Dec 13 07:14:20 UTC 2013
commit fa2f68b5d91e07670622ab7054a2d96d85183c0e
Author: David Fifield <david at bamsoftware.com>
Date: Wed Dec 11 01:05:26 2013 -0800
Stop setting GOPATH.
This is something the user is supposed to set for themselves and we
shouldn't mess with it. I went back to relative imports as that seems to
be supported by gccgo now.
http://golang.org/doc/code.html really wants you to take all the Go code
you download and put it in $GOPATH, with imports relative to $GOPATH.
Downloading a repo elsewhere and then doing "go build", like we're used
to with other source packages, doesn't work as well.
---
Makefile | 23 +-
pt/.gitignore | 2 +
pt/examples/dummy-client/dummy-client.go | 137 ++++++
pt/examples/dummy-server/dummy-server.go | 121 +++++
pt/pt.go | 611 ++++++++++++++++++++++++++
pt/pt_test.go | 61 +++
pt/socks/socks.go | 107 +++++
src/pt/.gitignore | 2 -
src/pt/examples/dummy-client/dummy-client.go | 137 ------
src/pt/examples/dummy-server/dummy-server.go | 121 -----
src/pt/pt.go | 611 --------------------------
src/pt/pt_test.go | 61 ---
src/pt/socks/socks.go | 107 -----
src/websocket-client/websocket-client.go | 254 -----------
src/websocket-server/websocket-server.go | 285 ------------
src/websocket/websocket.go | 431 ------------------
websocket-client/websocket-client.go | 254 +++++++++++
websocket-server/websocket-server.go | 285 ++++++++++++
websocket/websocket.go | 431 ++++++++++++++++++
19 files changed, 2018 insertions(+), 2023 deletions(-)
diff --git a/Makefile b/Makefile
index 2ce4ffa..95aa243 100644
--- a/Makefile
+++ b/Makefile
@@ -2,33 +2,28 @@ DESTDIR =
PREFIX = /usr/local
BINDIR = $(PREFIX)/bin
-PROGRAMS = websocket-client websocket-server
-
-export GOPATH = $(CURDIR)
GOBUILDFLAGS =
# Alternate flags to use gccgo, allowing cross-compiling for x86 from
# x86_64, and presumably better optimization. Install this package:
# apt-get install gccgo-multilib
# GOBUILDFLAGS = -compiler gccgo -gccgoflags "-O3 -m32 -static-libgo"
-all: websocket-server
+all: websocket-server/websocket-server
-%: $(GOPATH)/src/%/*.go
- go build $(GOBUILDFLAGS) "$*"
+websocket-server/websocket-server: websocket-server/*.go websocket/*.go pt/*.go
+ cd websocket-server && go build $(GOBUILDFLAGS)
-# websocket-client has a special rule because "go get" is necessary.
-websocket-client: $(GOPATH)/src/websocket-client/*.go
- go get -d $(GOBUILDFLAGS) websocket-client
- go build $(GOBUILDFLAGS) websocket-client
+websocket-client/websocket-client: websocket-client/*.go websocket/*.go pt/*.go
+ cd websocket-client && go build $(GOBUILDFLAGS)
-install:
+install: websocket-server/websocket-server
mkdir -p "$(DESTDIR)$(BINDIR)"
- cp -f websocket-server "$(DESTDIR)$(BINDIR)"
+ cp -f websocket-server/websocket-server "$(DESTDIR)$(BINDIR)"
clean:
- rm -f $(PROGRAMS)
+ rm -f websocket-server/websocket-server websocket-client/websocket-client
fmt:
- go fmt $(PROGRAMS)
+ go fmt ./websocket-server ./websocket-client ./websocket ./pt
.PHONY: all install clean fmt
diff --git a/pt/.gitignore b/pt/.gitignore
new file mode 100644
index 0000000..d4d5132
--- /dev/null
+++ b/pt/.gitignore
@@ -0,0 +1,2 @@
+/examples/dummy-client/dummy-client
+/examples/dummy-server/dummy-server
diff --git a/pt/examples/dummy-client/dummy-client.go b/pt/examples/dummy-client/dummy-client.go
new file mode 100644
index 0000000..3cf7b45
--- /dev/null
+++ b/pt/examples/dummy-client/dummy-client.go
@@ -0,0 +1,137 @@
+// Usage (in torrc):
+// UseBridges 1
+// Bridge dummy X.X.X.X:YYYY
+// ClientTransportPlugin dummy exec dummy-client
+// Because this transport doesn't do anything to the traffic, you can use any
+// ordinary relay's ORPort in the Bridge line.
+
+package main
+
+import (
+ "io"
+ "net"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+)
+
+import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
+import "git.torproject.org/pluggable-transports/websocket.git/src/pt/socks"
+
+var ptInfo pt.ClientInfo
+
+// When a connection handler starts, +1 is written to this channel; when it
+// ends, -1 is written.
+var handlerChan = make(chan int)
+
+func copyLoop(a, b net.Conn) {
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ io.Copy(b, a)
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(a, b)
+ wg.Done()
+ }()
+
+ wg.Wait()
+}
+
+func handleConnection(local net.Conn) error {
+ defer local.Close()
+
+ handlerChan <- 1
+ defer func() {
+ handlerChan <- -1
+ }()
+
+ var remote net.Conn
+ err := socks.AwaitSocks4aConnect(local.(*net.TCPConn), func(dest string) (*net.TCPAddr, error) {
+ var err error
+ // set remote in outer function environment
+ remote, err = net.Dial("tcp", dest)
+ if err != nil {
+ return nil, err
+ }
+ return remote.RemoteAddr().(*net.TCPAddr), nil
+ })
+ if err != nil {
+ return err
+ }
+ defer remote.Close()
+ copyLoop(local, remote)
+
+ return nil
+}
+
+func acceptLoop(ln net.Listener) error {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ return err
+ }
+ go handleConnection(conn)
+ }
+ return nil
+}
+
+func startListener(addr string) (net.Listener, error) {
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ go acceptLoop(ln)
+ return ln, nil
+}
+
+func main() {
+ ptInfo = pt.ClientSetup([]string{"dummy"})
+
+ listeners := make([]net.Listener, 0)
+ for _, methodName := range ptInfo.MethodNames {
+ ln, err := startListener("127.0.0.1:0")
+ if err != nil {
+ pt.CmethodError(methodName, err.Error())
+ continue
+ }
+ pt.Cmethod(methodName, "socks4", ln.Addr())
+ listeners = append(listeners, ln)
+ }
+ pt.CmethodsDone()
+
+ var numHandlers int = 0
+ var sig os.Signal
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
+ // wait for first signal
+ sig = nil
+ for sig == nil {
+ select {
+ case n := <-handlerChan:
+ numHandlers += n
+ case sig = <-sigChan:
+ }
+ }
+ for _, ln := range listeners {
+ ln.Close()
+ }
+
+ if sig == syscall.SIGTERM {
+ return
+ }
+
+ // wait for second signal or no more handlers
+ sig = nil
+ for sig == nil && numHandlers != 0 {
+ select {
+ case n := <-handlerChan:
+ numHandlers += n
+ case sig = <-sigChan:
+ }
+ }
+}
diff --git a/pt/examples/dummy-server/dummy-server.go b/pt/examples/dummy-server/dummy-server.go
new file mode 100644
index 0000000..26314d0
--- /dev/null
+++ b/pt/examples/dummy-server/dummy-server.go
@@ -0,0 +1,121 @@
+// Usage (in torrc):
+// BridgeRelay 1
+// ORPort 9001
+// ExtORPort 6669
+// ServerTransportPlugin dummy exec dummy-server
+
+package main
+
+import (
+ "io"
+ "net"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+)
+
+import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
+
+var ptInfo pt.ServerInfo
+
+// When a connection handler starts, +1 is written to this channel; when it
+// ends, -1 is written.
+var handlerChan = make(chan int)
+
+func copyLoop(a, b net.Conn) {
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ io.Copy(b, a)
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(a, b)
+ wg.Done()
+ }()
+
+ wg.Wait()
+}
+
+func handleConnection(conn net.Conn) {
+ handlerChan <- 1
+ defer func() {
+ handlerChan <- -1
+ }()
+
+ or, err := pt.ConnectOr(&ptInfo, conn, "dummy")
+ if err != nil {
+ return
+ }
+ copyLoop(conn, or)
+}
+
+func acceptLoop(ln net.Listener) error {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ return err
+ }
+ go handleConnection(conn)
+ }
+ return nil
+}
+
+func startListener(addr *net.TCPAddr) (net.Listener, error) {
+ ln, err := net.ListenTCP("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ go acceptLoop(ln)
+ return ln, nil
+}
+
+func main() {
+ ptInfo = pt.ServerSetup([]string{"dummy"})
+
+ listeners := make([]net.Listener, 0)
+ for _, bindAddr := range ptInfo.BindAddrs {
+ ln, err := startListener(bindAddr.Addr)
+ if err != nil {
+ pt.SmethodError(bindAddr.MethodName, err.Error())
+ continue
+ }
+ pt.Smethod(bindAddr.MethodName, ln.Addr())
+ listeners = append(listeners, ln)
+ }
+ pt.SmethodsDone()
+
+ var numHandlers int = 0
+ var sig os.Signal
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
+ // wait for first signal
+ sig = nil
+ for sig == nil {
+ select {
+ case n := <-handlerChan:
+ numHandlers += n
+ case sig = <-sigChan:
+ }
+ }
+ for _, ln := range listeners {
+ ln.Close()
+ }
+
+ if sig == syscall.SIGTERM {
+ return
+ }
+
+ // wait for second signal or no more handlers
+ sig = nil
+ for sig == nil && numHandlers != 0 {
+ select {
+ case n := <-handlerChan:
+ numHandlers += n
+ case sig = <-sigChan:
+ }
+ }
+}
diff --git a/pt/pt.go b/pt/pt.go
new file mode 100644
index 0000000..526f3b7
--- /dev/null
+++ b/pt/pt.go
@@ -0,0 +1,611 @@
+// Tor pluggable transports library.
+//
+// Sample client usage:
+//
+// import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
+// var ptInfo pt.ClientInfo
+// ptInfo = pt.ClientSetup([]string{"foo"})
+// for _, methodName := range ptInfo.MethodNames {
+// ln, err := startSocksListener()
+// if err != nil {
+// pt.CmethodError(methodName, err.Error())
+// continue
+// }
+// pt.Cmethod(methodName, "socks4", ln.Addr())
+// }
+// pt.CmethodsDone()
+//
+// Sample server usage:
+//
+// import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
+// var ptInfo pt.ServerInfo
+// ptInfo = pt.ServerSetup([]string{"foo", "bar"})
+// for _, bindAddr := range ptInfo.BindAddrs {
+// ln, err := startListener(bindAddr.Addr, bindAddr.MethodName)
+// if err != nil {
+// pt.SmethodError(bindAddr.MethodName, err.Error())
+// continue
+// }
+// pt.Smethod(bindAddr.MethodName, ln.Addr())
+// }
+// pt.SmethodsDone()
+// func handler(conn net.Conn, methodName string) {
+// or, err := pt.ConnectOr(&ptInfo, conn, methodName)
+// if err != nil {
+// return
+// }
+// // Do something with or and conn.
+// }
+
+package pt
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strings"
+ "time"
+)
+
+func getenv(key string) string {
+ return os.Getenv(key)
+}
+
+// Abort with an ENV-ERROR if the environment variable isn't set.
+func getenvRequired(key string) string {
+ value := os.Getenv(key)
+ if value == "" {
+ EnvError(fmt.Sprintf("no %s environment variable", key))
+ }
+ return value
+}
+
+// Escape a string so it contains no byte values over 127 and doesn't contain
+// any of the characters '\x00' or '\n'.
+func escape(s string) string {
+ var buf bytes.Buffer
+ for _, b := range []byte(s) {
+ if b == '\n' {
+ buf.WriteString("\\n")
+ } else if b == '\\' {
+ buf.WriteString("\\\\")
+ } else if 0 < b && b < 128 {
+ buf.WriteByte(b)
+ } else {
+ fmt.Fprintf(&buf, "\\x%02x", b)
+ }
+ }
+ return buf.String()
+}
+
+// Print a pluggable transports protocol line to stdout. The line consists of an
+// unescaped keyword, followed by any number of escaped strings.
+func Line(keyword string, v ...string) {
+ var buf bytes.Buffer
+ buf.WriteString(keyword)
+ for _, x := range v {
+ buf.WriteString(" " + escape(x))
+ }
+ fmt.Println(buf.String())
+ os.Stdout.Sync()
+}
+
+// All of the *Error functions call os.Exit(1).
+
+// Emit an ENV-ERROR with explanation text.
+func EnvError(msg string) {
+ Line("ENV-ERROR", msg)
+ os.Exit(1)
+}
+
+// Emit a VERSION-ERROR with explanation text.
+func VersionError(msg string) {
+ Line("VERSION-ERROR", msg)
+ os.Exit(1)
+}
+
+// Emit a CMETHOD-ERROR with explanation text.
+func CmethodError(methodName, msg string) {
+ Line("CMETHOD-ERROR", methodName, msg)
+ os.Exit(1)
+}
+
+// Emit an SMETHOD-ERROR with explanation text.
+func SmethodError(methodName, msg string) {
+ Line("SMETHOD-ERROR", methodName, msg)
+ os.Exit(1)
+}
+
+// Emit a CMETHOD line. socks must be "socks4" or "socks5". Call this once for
+// each listening client SOCKS port.
+func Cmethod(name string, socks string, addr net.Addr) {
+ Line("CMETHOD", name, socks, addr.String())
+}
+
+// Emit a CMETHODS DONE line. Call this after opening all client listeners.
+func CmethodsDone() {
+ Line("CMETHODS", "DONE")
+}
+
+// Emit an SMETHOD line. Call this once for each listening server port.
+func Smethod(name string, addr net.Addr) {
+ Line("SMETHOD", name, addr.String())
+}
+
+// Emit an SMETHODS DONE line. Call this after opening all server listeners.
+func SmethodsDone() {
+ Line("SMETHODS", "DONE")
+}
+
+// Get a pluggable transports version offered by Tor and understood by us, if
+// any. The only version we understand is "1". This function reads the
+// environment variable TOR_PT_MANAGED_TRANSPORT_VER.
+func getManagedTransportVer() string {
+ const transportVersion = "1"
+ for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") {
+ if offered == transportVersion {
+ return offered
+ }
+ }
+ return ""
+}
+
+// Get the intersection of the method names offered by Tor and those in
+// methodNames. This function reads the environment variable
+// TOR_PT_CLIENT_TRANSPORTS.
+func getClientTransports(methodNames []string) []string {
+ clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS")
+ if clientTransports == "*" {
+ return methodNames
+ }
+ result := make([]string, 0)
+ for _, requested := range strings.Split(clientTransports, ",") {
+ for _, methodName := range methodNames {
+ if requested == methodName {
+ result = append(result, methodName)
+ break
+ }
+ }
+ }
+ return result
+}
+
+// This structure is returned by ClientSetup. It consists of a list of method
+// names.
+type ClientInfo struct {
+ MethodNames []string
+}
+
+// Check the client pluggable transports environments, emitting an error message
+// and exiting the program if any error is encountered. Returns a subset of
+// methodNames requested by Tor.
+func ClientSetup(methodNames []string) ClientInfo {
+ var info ClientInfo
+
+ ver := getManagedTransportVer()
+ if ver == "" {
+ VersionError("no-version")
+ } else {
+ Line("VERSION", ver)
+ }
+
+ info.MethodNames = getClientTransports(methodNames)
+ if len(info.MethodNames) == 0 {
+ CmethodsDone()
+ os.Exit(1)
+ }
+
+ return info
+}
+
+// A combination of a method name and an address, as extracted from
+// TOR_PT_SERVER_BINDADDR.
+type BindAddr struct {
+ MethodName string
+ Addr *net.TCPAddr
+}
+
+// Resolve an address string into a net.TCPAddr.
+func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) {
+ addr, err := net.ResolveTCPAddr("tcp", bindAddr)
+ if err == nil {
+ return addr, nil
+ }
+ // Before bug #7011 was fixed, tor didn't put brackets around IPv6
+ // addresses. Split at the last colon, assuming it is a port
+ // separator, and try adding the brackets.
+ parts := strings.Split(bindAddr, ":")
+ if len(parts) <= 2 {
+ return nil, err
+ }
+ bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1]
+ return net.ResolveTCPAddr("tcp", bindAddr)
+}
+
+// Return a new slice, the members of which are those members of addrs having a
+// MethodName in methodNames.
+func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr {
+ var result []BindAddr
+
+ for _, ba := range addrs {
+ for _, methodName := range methodNames {
+ if ba.MethodName == methodName {
+ result = append(result, ba)
+ break
+ }
+ }
+ }
+
+ return result
+}
+
+// Return a slice of BindAddrs parsed from TOR_PT_SERVER_BINDADDR, keeping only
+// the entries whose method names appear in TOR_PT_SERVER_TRANSPORTS (unless it
+// is "*") and in methodNames.
+func getServerBindAddrs(methodNames []string) []BindAddr {
+ var result []BindAddr
+
+ // Get the list of all requested bindaddrs.
+ var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR")
+ for _, spec := range strings.Split(serverBindAddr, ",") {
+ var bindAddr BindAddr
+
+ parts := strings.SplitN(spec, "-", 2)
+ if len(parts) != 2 {
+ EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain \"-\"", spec))
+ }
+ bindAddr.MethodName = parts[0]
+ addr, err := resolveBindAddr(parts[1])
+ if err != nil {
+ EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error()))
+ }
+ bindAddr.Addr = addr
+ result = append(result, bindAddr)
+ }
+
+ // Filter by TOR_PT_SERVER_TRANSPORTS.
+ serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS")
+ if serverTransports != "*" {
+ result = filterBindAddrs(result, strings.Split(serverTransports, ","))
+ }
+
+ // Finally filter by what we understand.
+ result = filterBindAddrs(result, methodNames)
+
+ return result
+}
+
+// Read and validate the contents of an auth cookie file. Returns the 32-byte
+// cookie. See section 4.2.1.2 of pt-spec.txt.
+func readAuthCookieFile(filename string) ([]byte, error) {
+ authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a")
+ header := make([]byte, 32)
+ cookie := make([]byte, 32)
+
+ f, err := os.Open(filename)
+ if err != nil {
+ return cookie, err
+ }
+ defer f.Close()
+
+ n, err := io.ReadFull(f, header)
+ if err != nil {
+ return cookie, err
+ }
+ n, err = io.ReadFull(f, cookie)
+ if err != nil {
+ return cookie, err
+ }
+ // Check that the file ends here.
+ n, err = f.Read(make([]byte, 1))
+ if n != 0 {
+ return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes"))
+ } else if err != io.EOF {
+ return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file"))
+ }
+
+ if !bytes.Equal(header, authCookieHeader) {
+ return cookie, errors.New(fmt.Sprintf("missing auth cookie header"))
+ }
+
+ return cookie, nil
+}
+
+// This structure is returned by ServerSetup. It consists of a list of
+// BindAddrs, an address for the ORPort, an address for the extended ORPort (if
+// any), and an authentication cookie (if any).
+type ServerInfo struct {
+ BindAddrs []BindAddr
+ OrAddr *net.TCPAddr
+ ExtendedOrAddr *net.TCPAddr
+ AuthCookie []byte
+}
+
+// Check the server pluggable transports environments, emitting an error message
+// and exiting the program if any error is encountered. Resolves the various
+// requested bind addresses, the server ORPort and extended ORPort, and reads
+// the auth cookie file. Returns a ServerInfo struct.
+func ServerSetup(methodNames []string) ServerInfo {
+ var info ServerInfo
+ var err error
+
+ ver := getManagedTransportVer()
+ if ver == "" {
+ VersionError("no-version")
+ } else {
+ Line("VERSION", ver)
+ }
+
+ var orPort = getenvRequired("TOR_PT_ORPORT")
+ info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort)
+ if err != nil {
+ EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error()))
+ }
+
+ info.BindAddrs = getServerBindAddrs(methodNames)
+ if len(info.BindAddrs) == 0 {
+ SmethodsDone()
+ os.Exit(1)
+ }
+
+ var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT")
+ if extendedOrPort != "" {
+ info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort)
+ if err != nil {
+ EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error()))
+ }
+ }
+
+ var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE")
+ if authCookieFilename != "" {
+ info.AuthCookie, err = readAuthCookieFile(authCookieFilename)
+ if err != nil {
+ EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error()))
+ }
+ }
+
+ return info
+}
+
+// See 217-ext-orport-auth.txt section 4.2.1.3.
+func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte {
+ h := hmac.New(sha256.New, info.AuthCookie)
+ io.WriteString(h, "ExtORPort authentication server-to-client hash")
+ h.Write(clientNonce)
+ h.Write(serverNonce)
+ return h.Sum([]byte{})
+}
+
+// See 217-ext-orport-auth.txt section 4.2.1.3.
+func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte {
+ h := hmac.New(sha256.New, info.AuthCookie)
+ io.WriteString(h, "ExtORPort authentication client-to-server hash")
+ h.Write(clientNonce)
+ h.Write(serverNonce)
+ return h.Sum([]byte{})
+}
+
+func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error {
+ r := bufio.NewReader(s)
+
+ // Read auth types. 217-ext-orport-auth.txt section 4.1.
+ var authTypes [256]bool
+ var count int
+ for count = 0; count < 256; count++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ return err
+ }
+ if b == 0 {
+ break
+ }
+ authTypes[b] = true
+ }
+ if count >= 256 {
+ return errors.New(fmt.Sprintf("read 256 auth types without seeing \\x00"))
+ }
+
+ // We support only type 1, SAFE_COOKIE.
+ if !authTypes[1] {
+ return errors.New(fmt.Sprintf("server didn't offer auth type 1"))
+ }
+ _, err := s.Write([]byte{1})
+ if err != nil {
+ return err
+ }
+
+ clientNonce := make([]byte, 32)
+ clientHash := make([]byte, 32)
+ serverNonce := make([]byte, 32)
+ serverHash := make([]byte, 32)
+
+ _, err = io.ReadFull(rand.Reader, clientNonce)
+ if err != nil {
+ return err
+ }
+ _, err = s.Write(clientNonce)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.ReadFull(r, serverHash)
+ if err != nil {
+ return err
+ }
+ _, err = io.ReadFull(r, serverNonce)
+ if err != nil {
+ return err
+ }
+
+ expectedServerHash := computeServerHash(info, clientNonce, serverNonce)
+ if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 {
+ return errors.New(fmt.Sprintf("mismatch in server hash"))
+ }
+
+ clientHash = computeClientHash(info, clientNonce, serverNonce)
+ _, err = s.Write(clientHash)
+ if err != nil {
+ return err
+ }
+
+ status := make([]byte, 1)
+ _, err = io.ReadFull(r, status)
+ if err != nil {
+ return err
+ }
+ if status[0] != 1 {
+ return errors.New(fmt.Sprintf("server rejected authentication"))
+ }
+
+ if r.Buffered() != 0 {
+ return errors.New(fmt.Sprintf("%d bytes left after extended OR port authentication", r.Buffered()))
+ }
+
+ return nil
+}
+
+// See section 3.1 of 196-transport-control-ports.txt.
+const (
+ extOrCmdDone = 0x0000
+ extOrCmdUserAddr = 0x0001
+ extOrCmdTransport = 0x0002
+ extOrCmdOkay = 0x1000
+ extOrCmdDeny = 0x1001
+)
+
+func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error {
+ var buf bytes.Buffer
+ if len(body) > 65535 {
+ return errors.New("command exceeds maximum length of 65535")
+ }
+ err := binary.Write(&buf, binary.BigEndian, cmd)
+ if err != nil {
+ return err
+ }
+ err = binary.Write(&buf, binary.BigEndian, uint16(len(body)))
+ if err != nil {
+ return err
+ }
+ err = binary.Write(&buf, binary.BigEndian, body)
+ if err != nil {
+ return err
+ }
+ _, err = s.Write(buf.Bytes())
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Send a USERADDR command on s. See section 3.1.2.1 of
+// 196-transport-control-ports.txt.
+func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error {
+ return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String()))
+}
+
+// Send a TRANSPORT command on s. See section 3.1.2.2 of
+// 196-transport-control-ports.txt.
+func extOrPortSendTransport(s *net.TCPConn, methodName string) error {
+ return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName))
+}
+
+// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt.
+func extOrPortSendDone(s *net.TCPConn) error {
+ return extOrPortWriteCommand(s, extOrCmdDone, []byte{})
+}
+
+func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) {
+ var bodyLen uint16
+ data := make([]byte, 4)
+
+ _, err = io.ReadFull(s, data)
+ if err != nil {
+ return
+ }
+ buf := bytes.NewBuffer(data)
+ err = binary.Read(buf, binary.BigEndian, &cmd)
+ if err != nil {
+ return
+ }
+ err = binary.Read(buf, binary.BigEndian, &bodyLen)
+ if err != nil {
+ return
+ }
+ body = make([]byte, bodyLen)
+ _, err = io.ReadFull(s, body)
+ if err != nil {
+ return
+ }
+
+ return cmd, body, err
+}
+
+// Send USERADDR and TRANSPORT commands followed by a DONE command. Wait for an
+// OKAY or DENY response command from the server. Returns nil if and only if
+// OKAY is received.
+func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error {
+ var err error
+
+ err = extOrPortSendUserAddr(s, conn)
+ if err != nil {
+ return err
+ }
+ err = extOrPortSendTransport(s, methodName)
+ if err != nil {
+ return err
+ }
+ err = extOrPortSendDone(s)
+ if err != nil {
+ return err
+ }
+ cmd, _, err := extOrPortRecvCommand(s)
+ if err != nil {
+ return err
+ }
+ if cmd == extOrCmdDeny {
+ return errors.New("server returned DENY after our USERADDR and DONE")
+ } else if cmd != extOrCmdOkay {
+ return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd))
+ }
+
+ return nil
+}
+
+// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an
+// open *net.TCPConn. If connecting to the extended OR port, extended OR port
+// authentication à la 217-ext-orport-auth.txt is done before returning; an
+// error is returned if authentication fails.
+func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) {
+ if info.ExtendedOrAddr == nil {
+ return net.DialTCP("tcp", nil, info.OrAddr)
+ }
+
+ s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr)
+ if err != nil {
+ return nil, err
+ }
+ s.SetDeadline(time.Now().Add(5 * time.Second))
+ err = extOrPortAuthenticate(s, info)
+ if err != nil {
+ s.Close()
+ return nil, err
+ }
+ err = extOrPortSetup(s, conn, methodName)
+ if err != nil {
+ s.Close()
+ return nil, err
+ }
+ s.SetDeadline(time.Time{})
+
+ return s, nil
+}
diff --git a/pt/pt_test.go b/pt/pt_test.go
new file mode 100644
index 0000000..cc7924a
--- /dev/null
+++ b/pt/pt_test.go
@@ -0,0 +1,61 @@
+package pt
+
+import "os"
+import "testing"
+
+func stringIsSafe(s string) bool {
+ for _, c := range []byte(s) {
+ if c == '\x00' || c == '\n' || c > 127 {
+ return false
+ }
+ }
+ return true
+}
+
+func TestEscape(t *testing.T) {
+ tests := [...]string{
+ "",
+ "abc",
+ "a\nb",
+ "a\\b",
+ "ab\\",
+ "ab\\\n",
+ "ab\n\\",
+ }
+
+ check := func(input string) {
+ output := escape(input)
+ if !stringIsSafe(output) {
+ t.Errorf("escape(%q) â %q", input, output)
+ }
+ }
+ for _, input := range tests {
+ check(input)
+ }
+ for b := 0; b < 256; b++ {
+ // check one-byte string with each byte value 0–255
+ check(string([]byte{byte(b)}))
+ // check UTF-8 encoding of each character 0–255
+ check(string(b))
+ }
+}
+
+func TestGetManagedTransportVer(t *testing.T) {
+ tests := [...]struct {
+ input, expected string
+ }{
+ {"1", "1"},
+ {"1,1", "1"},
+ {"1,2", "1"},
+ {"2,1", "1"},
+ {"2", ""},
+ }
+
+ for _, test := range tests {
+ os.Setenv("TOR_PT_MANAGED_TRANSPORT_VER", test.input)
+ output := getManagedTransportVer()
+ if output != test.expected {
+ t.Errorf("%q â %q (expected %q)", test.input, output, test.expected)
+ }
+ }
+}
diff --git a/pt/socks/socks.go b/pt/socks/socks.go
new file mode 100644
index 0000000..788d53c
--- /dev/null
+++ b/pt/socks/socks.go
@@ -0,0 +1,107 @@
+// SOCKS4a server library.
+
+package socks
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+)
+
+const (
+ socksVersion = 0x04
+ socksCmdConnect = 0x01
+ socksResponseVersion = 0x00
+ socksRequestGranted = 0x5a
+ socksRequestFailed = 0x5b
+)
+
+// Read a SOCKS4a connect request, and call the given connect callback with the
+// requested destination string. If the callback returns an error, sends a SOCKS
+// request failed message. Otherwise, sends a SOCKS request granted message for
+// the destination address returned by the callback.
+func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error {
+ dest, err := readSocks4aConnect(conn)
+ if err != nil {
+ sendSocks4aResponseFailed(conn)
+ return err
+ }
+ destAddr, err := connect(dest)
+ if err != nil {
+ sendSocks4aResponseFailed(conn)
+ return err
+ }
+ sendSocks4aResponseGranted(conn, destAddr)
+ return nil
+}
+
+// Read a SOCKS4a connect request. Returns a "host:port" string.
+func readSocks4aConnect(s io.Reader) (string, error) {
+ r := bufio.NewReader(s)
+
+ var h [8]byte
+ n, err := io.ReadFull(r, h[:])
+ if err != nil {
+ return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err))
+ }
+ if h[0] != socksVersion {
+ return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion))
+ }
+ if h[1] != socksCmdConnect {
+ return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect))
+ }
+
+ _, err = r.ReadBytes('\x00')
+ if err != nil {
+ return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err))
+ }
+
+ var port int
+ var host string
+
+ port = int(h[2])<<8 | int(h[3])<<0
+ if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 {
+ hostBytes, err := r.ReadBytes('\x00')
+ if err != nil {
+ return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err))
+ }
+ host = string(hostBytes[:len(hostBytes)-1])
+ } else {
+ host = net.IPv4(h[4], h[5], h[6], h[7]).String()
+ }
+
+ if r.Buffered() != 0 {
+ return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered()))
+ }
+
+ return fmt.Sprintf("%s:%d", host, port), nil
+}
+
+// Send a SOCKS4a response with the given code and address.
+func sendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error {
+ var resp [8]byte
+ resp[0] = socksResponseVersion
+ resp[1] = code
+ resp[2] = byte((addr.Port >> 8) & 0xff)
+ resp[3] = byte((addr.Port >> 0) & 0xff)
+ resp[4] = addr.IP[0]
+ resp[5] = addr.IP[1]
+ resp[6] = addr.IP[2]
+ resp[7] = addr.IP[3]
+ _, err := w.Write(resp[:])
+ return err
+}
+
+var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0}
+
+// Send a SOCKS4a response code 0x5a.
+func sendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error {
+ return sendSocks4aResponse(w, socksRequestGranted, addr)
+}
+
+// Send a SOCKS4a response code 0x5b (with an all-zero address).
+func sendSocks4aResponseFailed(w io.Writer) error {
+ return sendSocks4aResponse(w, socksRequestFailed, &emptyAddr)
+}
diff --git a/src/pt/.gitignore b/src/pt/.gitignore
deleted file mode 100644
index d4d5132..0000000
--- a/src/pt/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-/examples/dummy-client/dummy-client
-/examples/dummy-server/dummy-server
diff --git a/src/pt/examples/dummy-client/dummy-client.go b/src/pt/examples/dummy-client/dummy-client.go
deleted file mode 100644
index 3cf7b45..0000000
--- a/src/pt/examples/dummy-client/dummy-client.go
+++ /dev/null
@@ -1,137 +0,0 @@
-// Usage (in torrc):
-// UseBridges 1
-// Bridge dummy X.X.X.X:YYYY
-// ClientTransportPlugin dummy exec dummy-client
-// Because this transport doesn't do anything to the traffic, you can use any
-// ordinary relay's ORPort in the Bridge line.
-
-package main
-
-import (
- "io"
- "net"
- "os"
- "os/signal"
- "sync"
- "syscall"
-)
-
-import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
-import "git.torproject.org/pluggable-transports/websocket.git/src/pt/socks"
-
-var ptInfo pt.ClientInfo
-
-// When a connection handler starts, +1 is written to this channel; when it
-// ends, -1 is written.
-var handlerChan = make(chan int)
-
-func copyLoop(a, b net.Conn) {
- var wg sync.WaitGroup
- wg.Add(2)
-
- go func() {
- io.Copy(b, a)
- wg.Done()
- }()
- go func() {
- io.Copy(a, b)
- wg.Done()
- }()
-
- wg.Wait()
-}
-
-func handleConnection(local net.Conn) error {
- defer local.Close()
-
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
-
- var remote net.Conn
- err := socks.AwaitSocks4aConnect(local.(*net.TCPConn), func(dest string) (*net.TCPAddr, error) {
- var err error
- // set remote in outer function environment
- remote, err = net.Dial("tcp", dest)
- if err != nil {
- return nil, err
- }
- return remote.RemoteAddr().(*net.TCPAddr), nil
- })
- if err != nil {
- return err
- }
- defer remote.Close()
- copyLoop(local, remote)
-
- return nil
-}
-
-func acceptLoop(ln net.Listener) error {
- for {
- conn, err := ln.Accept()
- if err != nil {
- return err
- }
- go handleConnection(conn)
- }
- return nil
-}
-
-func startListener(addr string) (net.Listener, error) {
- ln, err := net.Listen("tcp", addr)
- if err != nil {
- return nil, err
- }
- go acceptLoop(ln)
- return ln, nil
-}
-
-func main() {
- ptInfo = pt.ClientSetup([]string{"dummy"})
-
- listeners := make([]net.Listener, 0)
- for _, methodName := range ptInfo.MethodNames {
- ln, err := startListener("127.0.0.1:0")
- if err != nil {
- pt.CmethodError(methodName, err.Error())
- continue
- }
- pt.Cmethod(methodName, "socks4", ln.Addr())
- listeners = append(listeners, ln)
- }
- pt.CmethodsDone()
-
- var numHandlers int = 0
- var sig os.Signal
- sigChan := make(chan os.Signal, 1)
- signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
- // wait for first signal
- sig = nil
- for sig == nil {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
- for _, ln := range listeners {
- ln.Close()
- }
-
- if sig == syscall.SIGTERM {
- return
- }
-
- // wait for second signal or no more handlers
- sig = nil
- for sig == nil && numHandlers != 0 {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
-}
diff --git a/src/pt/examples/dummy-server/dummy-server.go b/src/pt/examples/dummy-server/dummy-server.go
deleted file mode 100644
index 26314d0..0000000
--- a/src/pt/examples/dummy-server/dummy-server.go
+++ /dev/null
@@ -1,121 +0,0 @@
-// Usage (in torrc):
-// BridgeRelay 1
-// ORPort 9001
-// ExtORPort 6669
-// ServerTransportPlugin dummy exec dummy-server
-
-package main
-
-import (
- "io"
- "net"
- "os"
- "os/signal"
- "sync"
- "syscall"
-)
-
-import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
-
-var ptInfo pt.ServerInfo
-
-// When a connection handler starts, +1 is written to this channel; when it
-// ends, -1 is written.
-var handlerChan = make(chan int)
-
-func copyLoop(a, b net.Conn) {
- var wg sync.WaitGroup
- wg.Add(2)
-
- go func() {
- io.Copy(b, a)
- wg.Done()
- }()
- go func() {
- io.Copy(a, b)
- wg.Done()
- }()
-
- wg.Wait()
-}
-
-func handleConnection(conn net.Conn) {
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
-
- or, err := pt.ConnectOr(&ptInfo, conn, "dummy")
- if err != nil {
- return
- }
- copyLoop(conn, or)
-}
-
-func acceptLoop(ln net.Listener) error {
- for {
- conn, err := ln.Accept()
- if err != nil {
- return err
- }
- go handleConnection(conn)
- }
- return nil
-}
-
-func startListener(addr *net.TCPAddr) (net.Listener, error) {
- ln, err := net.ListenTCP("tcp", addr)
- if err != nil {
- return nil, err
- }
- go acceptLoop(ln)
- return ln, nil
-}
-
-func main() {
- ptInfo = pt.ServerSetup([]string{"dummy"})
-
- listeners := make([]net.Listener, 0)
- for _, bindAddr := range ptInfo.BindAddrs {
- ln, err := startListener(bindAddr.Addr)
- if err != nil {
- pt.SmethodError(bindAddr.MethodName, err.Error())
- continue
- }
- pt.Smethod(bindAddr.MethodName, ln.Addr())
- listeners = append(listeners, ln)
- }
- pt.SmethodsDone()
-
- var numHandlers int = 0
- var sig os.Signal
- sigChan := make(chan os.Signal, 1)
- signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
- // wait for first signal
- sig = nil
- for sig == nil {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
- for _, ln := range listeners {
- ln.Close()
- }
-
- if sig == syscall.SIGTERM {
- return
- }
-
- // wait for second signal or no more handlers
- sig = nil
- for sig == nil && numHandlers != 0 {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
-}
diff --git a/src/pt/pt.go b/src/pt/pt.go
deleted file mode 100644
index 526f3b7..0000000
--- a/src/pt/pt.go
+++ /dev/null
@@ -1,611 +0,0 @@
-// Tor pluggable transports library.
-//
-// Sample client usage:
-//
-// import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
-// var ptInfo pt.ClientInfo
-// ptInfo = pt.ClientSetup([]string{"foo"})
-// for _, methodName := range ptInfo.MethodNames {
-// ln, err := startSocksListener()
-// if err != nil {
-// pt.CmethodError(methodName, err.Error())
-// continue
-// }
-// pt.Cmethod(methodName, "socks4", ln.Addr())
-// }
-// pt.CmethodsDone()
-//
-// Sample server usage:
-//
-// import "git.torproject.org/pluggable-transports/websocket.git/src/pt"
-// var ptInfo pt.ServerInfo
-// ptInfo = pt.ServerSetup([]string{"foo", "bar"})
-// for _, bindAddr := range ptInfo.BindAddrs {
-// ln, err := startListener(bindAddr.Addr, bindAddr.MethodName)
-// if err != nil {
-// pt.SmethodError(bindAddr.MethodName, err.Error())
-// continue
-// }
-// pt.Smethod(bindAddr.MethodName, ln.Addr())
-// }
-// pt.SmethodsDone()
-// func handler(conn net.Conn, methodName string) {
-// or, err := pt.ConnectOr(&ptInfo, conn, methodName)
-// if err != nil {
-// return
-// }
-// // Do something with or and conn.
-// }
-
-package pt
-
-import (
- "bufio"
- "bytes"
- "crypto/hmac"
- "crypto/rand"
- "crypto/sha256"
- "crypto/subtle"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "os"
- "strings"
- "time"
-)
-
-func getenv(key string) string {
- return os.Getenv(key)
-}
-
-// Abort with an ENV-ERROR if the environment variable isn't set.
-func getenvRequired(key string) string {
- value := os.Getenv(key)
- if value == "" {
- EnvError(fmt.Sprintf("no %s environment variable", key))
- }
- return value
-}
-
-// Escape a string so it contains no byte values over 127 and doesn't contain
-// any of the characters '\x00' or '\n'.
-func escape(s string) string {
- var buf bytes.Buffer
- for _, b := range []byte(s) {
- if b == '\n' {
- buf.WriteString("\\n")
- } else if b == '\\' {
- buf.WriteString("\\\\")
- } else if 0 < b && b < 128 {
- buf.WriteByte(b)
- } else {
- fmt.Fprintf(&buf, "\\x%02x", b)
- }
- }
- return buf.String()
-}
-
-// Print a pluggable transports protocol line to stdout. The line consists of an
-// unescaped keyword, followed by any number of escaped strings.
-func Line(keyword string, v ...string) {
- var buf bytes.Buffer
- buf.WriteString(keyword)
- for _, x := range v {
- buf.WriteString(" " + escape(x))
- }
- fmt.Println(buf.String())
- os.Stdout.Sync()
-}
-
-// All of the *Error functions call os.Exit(1).
-
-// Emit an ENV-ERROR with explanation text.
-func EnvError(msg string) {
- Line("ENV-ERROR", msg)
- os.Exit(1)
-}
-
-// Emit a VERSION-ERROR with explanation text.
-func VersionError(msg string) {
- Line("VERSION-ERROR", msg)
- os.Exit(1)
-}
-
-// Emit a CMETHOD-ERROR with explanation text.
-func CmethodError(methodName, msg string) {
- Line("CMETHOD-ERROR", methodName, msg)
- os.Exit(1)
-}
-
-// Emit an SMETHOD-ERROR with explanation text.
-func SmethodError(methodName, msg string) {
- Line("SMETHOD-ERROR", methodName, msg)
- os.Exit(1)
-}
-
-// Emit a CMETHOD line. socks must be "socks4" or "socks5". Call this once for
-// each listening client SOCKS port.
-func Cmethod(name string, socks string, addr net.Addr) {
- Line("CMETHOD", name, socks, addr.String())
-}
-
-// Emit a CMETHODS DONE line. Call this after opening all client listeners.
-func CmethodsDone() {
- Line("CMETHODS", "DONE")
-}
-
-// Emit an SMETHOD line. Call this once for each listening server port.
-func Smethod(name string, addr net.Addr) {
- Line("SMETHOD", name, addr.String())
-}
-
-// Emit an SMETHODS DONE line. Call this after opening all server listeners.
-func SmethodsDone() {
- Line("SMETHODS", "DONE")
-}
-
-// Get a pluggable transports version offered by Tor and understood by us, if
-// any. The only version we understand is "1". This function reads the
-// environment variable TOR_PT_MANAGED_TRANSPORT_VER.
-func getManagedTransportVer() string {
- const transportVersion = "1"
- for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") {
- if offered == transportVersion {
- return offered
- }
- }
- return ""
-}
-
-// Get the intersection of the method names offered by Tor and those in
-// methodNames. This function reads the environment variable
-// TOR_PT_CLIENT_TRANSPORTS.
-func getClientTransports(methodNames []string) []string {
- clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS")
- if clientTransports == "*" {
- return methodNames
- }
- result := make([]string, 0)
- for _, requested := range strings.Split(clientTransports, ",") {
- for _, methodName := range methodNames {
- if requested == methodName {
- result = append(result, methodName)
- break
- }
- }
- }
- return result
-}
-
-// This structure is returned by ClientSetup. It consists of a list of method
-// names.
-type ClientInfo struct {
- MethodNames []string
-}
-
-// Check the client pluggable transports environments, emitting an error message
-// and exiting the program if any error is encountered. Returns a subset of
-// methodNames requested by Tor.
-func ClientSetup(methodNames []string) ClientInfo {
- var info ClientInfo
-
- ver := getManagedTransportVer()
- if ver == "" {
- VersionError("no-version")
- } else {
- Line("VERSION", ver)
- }
-
- info.MethodNames = getClientTransports(methodNames)
- if len(info.MethodNames) == 0 {
- CmethodsDone()
- os.Exit(1)
- }
-
- return info
-}
-
-// A combination of a method name and an address, as extracted from
-// TOR_PT_SERVER_BINDADDR.
-type BindAddr struct {
- MethodName string
- Addr *net.TCPAddr
-}
-
-// Resolve an address string into a net.TCPAddr.
-func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) {
- addr, err := net.ResolveTCPAddr("tcp", bindAddr)
- if err == nil {
- return addr, nil
- }
- // Before the fixing of bug #7011, tor doesn't put brackets around IPv6
- // addresses. Split after the last colon, assuming it is a port
- // separator, and try adding the brackets.
- parts := strings.Split(bindAddr, ":")
- if len(parts) <= 2 {
- return nil, err
- }
- bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1]
- return net.ResolveTCPAddr("tcp", bindAddr)
-}
-
-// Return a new slice, the members of which are those members of addrs having a
-// MethodName in methodNames.
-func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr {
- var result []BindAddr
-
- for _, ba := range addrs {
- for _, methodName := range methodNames {
- if ba.MethodName == methodName {
- result = append(result, ba)
- break
- }
- }
- }
-
- return result
-}
-
-// Return a map from method names to bind addresses. The map is the contents of
-// TOR_PT_SERVER_BINDADDR, with keys filtered by TOR_PT_SERVER_TRANSPORTS, and
-// further filtered by the methods in methodNames.
-func getServerBindAddrs(methodNames []string) []BindAddr {
- var result []BindAddr
-
- // Get the list of all requested bindaddrs.
- var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR")
- for _, spec := range strings.Split(serverBindAddr, ",") {
- var bindAddr BindAddr
-
- parts := strings.SplitN(spec, "-", 2)
- if len(parts) != 2 {
- EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain \"-\"", spec))
- }
- bindAddr.MethodName = parts[0]
- addr, err := resolveBindAddr(parts[1])
- if err != nil {
- EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error()))
- }
- bindAddr.Addr = addr
- result = append(result, bindAddr)
- }
-
- // Filter by TOR_PT_SERVER_TRANSPORTS.
- serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS")
- if serverTransports != "*" {
- result = filterBindAddrs(result, strings.Split(serverTransports, ","))
- }
-
- // Finally filter by what we understand.
- result = filterBindAddrs(result, methodNames)
-
- return result
-}
-
-// Read and validate the contents of an auth cookie file. Returns the 32-byte
-// cookie. See section 4.2.1.2 of pt-spec.txt.
-func readAuthCookieFile(filename string) ([]byte, error) {
- authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a")
- header := make([]byte, 32)
- cookie := make([]byte, 32)
-
- f, err := os.Open(filename)
- if err != nil {
- return cookie, err
- }
- defer f.Close()
-
- n, err := io.ReadFull(f, header)
- if err != nil {
- return cookie, err
- }
- n, err = io.ReadFull(f, cookie)
- if err != nil {
- return cookie, err
- }
- // Check that the file ends here.
- n, err = f.Read(make([]byte, 1))
- if n != 0 {
- return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes"))
- } else if err != io.EOF {
- return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file"))
- }
-
- if !bytes.Equal(header, authCookieHeader) {
- return cookie, errors.New(fmt.Sprintf("missing auth cookie header"))
- }
-
- return cookie, nil
-}
-
-// This structure is returned by ServerSetup. It consists of a list of
-// BindAddrs, an address for the ORPort, an address for the extended ORPort (if
-// any), and an authentication cookie (if any).
-type ServerInfo struct {
- BindAddrs []BindAddr
- OrAddr *net.TCPAddr
- ExtendedOrAddr *net.TCPAddr
- AuthCookie []byte
-}
-
-// Check the server pluggable transports environments, emitting an error message
-// and exiting the program if any error is encountered. Resolves the various
-// requested bind addresses, the server ORPort and extended ORPort, and reads
-// the auth cookie file. Returns a ServerInfo struct.
-func ServerSetup(methodNames []string) ServerInfo {
- var info ServerInfo
- var err error
-
- ver := getManagedTransportVer()
- if ver == "" {
- VersionError("no-version")
- } else {
- Line("VERSION", ver)
- }
-
- var orPort = getenvRequired("TOR_PT_ORPORT")
- info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort)
- if err != nil {
- EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error()))
- }
-
- info.BindAddrs = getServerBindAddrs(methodNames)
- if len(info.BindAddrs) == 0 {
- SmethodsDone()
- os.Exit(1)
- }
-
- var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT")
- if extendedOrPort != "" {
- info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort)
- if err != nil {
- EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error()))
- }
- }
-
- var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE")
- if authCookieFilename != "" {
- info.AuthCookie, err = readAuthCookieFile(authCookieFilename)
- if err != nil {
- EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error()))
- }
- }
-
- return info
-}
-
-// See 217-ext-orport-auth.txt section 4.2.1.3.
-func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte {
- h := hmac.New(sha256.New, info.AuthCookie)
- io.WriteString(h, "ExtORPort authentication server-to-client hash")
- h.Write(clientNonce)
- h.Write(serverNonce)
- return h.Sum([]byte{})
-}
-
-// See 217-ext-orport-auth.txt section 4.2.1.3.
-func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte {
- h := hmac.New(sha256.New, info.AuthCookie)
- io.WriteString(h, "ExtORPort authentication client-to-server hash")
- h.Write(clientNonce)
- h.Write(serverNonce)
- return h.Sum([]byte{})
-}
-
-func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error {
- r := bufio.NewReader(s)
-
- // Read auth types. 217-ext-orport-auth.txt section 4.1.
- var authTypes [256]bool
- var count int
- for count = 0; count < 256; count++ {
- b, err := r.ReadByte()
- if err != nil {
- return err
- }
- if b == 0 {
- break
- }
- authTypes[b] = true
- }
- if count >= 256 {
- return errors.New(fmt.Sprintf("read 256 auth types without seeing \\x00"))
- }
-
- // We support only type 1, SAFE_COOKIE.
- if !authTypes[1] {
- return errors.New(fmt.Sprintf("server didn't offer auth type 1"))
- }
- _, err := s.Write([]byte{1})
- if err != nil {
- return err
- }
-
- clientNonce := make([]byte, 32)
- clientHash := make([]byte, 32)
- serverNonce := make([]byte, 32)
- serverHash := make([]byte, 32)
-
- _, err = io.ReadFull(rand.Reader, clientNonce)
- if err != nil {
- return err
- }
- _, err = s.Write(clientNonce)
- if err != nil {
- return err
- }
-
- _, err = io.ReadFull(r, serverHash)
- if err != nil {
- return err
- }
- _, err = io.ReadFull(r, serverNonce)
- if err != nil {
- return err
- }
-
- expectedServerHash := computeServerHash(info, clientNonce, serverNonce)
- if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 {
- return errors.New(fmt.Sprintf("mismatch in server hash"))
- }
-
- clientHash = computeClientHash(info, clientNonce, serverNonce)
- _, err = s.Write(clientHash)
- if err != nil {
- return err
- }
-
- status := make([]byte, 1)
- _, err = io.ReadFull(r, status)
- if err != nil {
- return err
- }
- if status[0] != 1 {
- return errors.New(fmt.Sprintf("server rejected authentication"))
- }
-
- if r.Buffered() != 0 {
- return errors.New(fmt.Sprintf("%d bytes left after extended OR port authentication", r.Buffered()))
- }
-
- return nil
-}
-
-// See section 3.1 of 196-transport-control-ports.txt.
-const (
- extOrCmdDone = 0x0000
- extOrCmdUserAddr = 0x0001
- extOrCmdTransport = 0x0002
- extOrCmdOkay = 0x1000
- extOrCmdDeny = 0x1001
-)
-
-func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error {
- var buf bytes.Buffer
- if len(body) > 65535 {
- return errors.New("command exceeds maximum length of 65535")
- }
- err := binary.Write(&buf, binary.BigEndian, cmd)
- if err != nil {
- return err
- }
- err = binary.Write(&buf, binary.BigEndian, uint16(len(body)))
- if err != nil {
- return err
- }
- err = binary.Write(&buf, binary.BigEndian, body)
- if err != nil {
- return err
- }
- _, err = s.Write(buf.Bytes())
- if err != nil {
- return err
- }
-
- return nil
-}
-
-// Send a USERADDR command on s. See section 3.1.2.1 of
-// 196-transport-control-ports.txt.
-func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error {
- return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String()))
-}
-
-// Send a TRANSPORT command on s. See section 3.1.2.2 of
-// 196-transport-control-ports.txt.
-func extOrPortSendTransport(s *net.TCPConn, methodName string) error {
- return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName))
-}
-
-// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt.
-func extOrPortSendDone(s *net.TCPConn) error {
- return extOrPortWriteCommand(s, extOrCmdDone, []byte{})
-}
-
-func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) {
- var bodyLen uint16
- data := make([]byte, 4)
-
- _, err = io.ReadFull(s, data)
- if err != nil {
- return
- }
- buf := bytes.NewBuffer(data)
- err = binary.Read(buf, binary.BigEndian, &cmd)
- if err != nil {
- return
- }
- err = binary.Read(buf, binary.BigEndian, &bodyLen)
- if err != nil {
- return
- }
- body = make([]byte, bodyLen)
- _, err = io.ReadFull(s, body)
- if err != nil {
- return
- }
-
- return cmd, body, err
-}
-
-// Send USERADDR and TRANSPORT commands followed by a DONE command. Wait for an
-// OKAY or DENY response command from the server. Returns nil if and only if
-// OKAY is received.
-func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error {
- var err error
-
- err = extOrPortSendUserAddr(s, conn)
- if err != nil {
- return err
- }
- err = extOrPortSendTransport(s, methodName)
- if err != nil {
- return err
- }
- err = extOrPortSendDone(s)
- if err != nil {
- return err
- }
- cmd, _, err := extOrPortRecvCommand(s)
- if err != nil {
- return err
- }
- if cmd == extOrCmdDeny {
- return errors.New("server returned DENY after our USERADDR and DONE")
- } else if cmd != extOrCmdOkay {
- return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd))
- }
-
- return nil
-}
-
-// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an
-// open *net.TCPConn. If connecting to the extended OR port, extended OR port
-// authentication à la 217-ext-orport-auth.txt is done before returning; an
-// error is returned if authentication fails.
-func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) {
- if info.ExtendedOrAddr == nil {
- return net.DialTCP("tcp", nil, info.OrAddr)
- }
-
- s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr)
- if err != nil {
- return nil, err
- }
- s.SetDeadline(time.Now().Add(5 * time.Second))
- err = extOrPortAuthenticate(s, info)
- if err != nil {
- s.Close()
- return nil, err
- }
- err = extOrPortSetup(s, conn, methodName)
- if err != nil {
- s.Close()
- return nil, err
- }
- s.SetDeadline(time.Time{})
-
- return s, nil
-}
diff --git a/src/pt/pt_test.go b/src/pt/pt_test.go
deleted file mode 100644
index cc7924a..0000000
--- a/src/pt/pt_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package pt
-
-import "os"
-import "testing"
-
-func stringIsSafe(s string) bool {
- for _, c := range []byte(s) {
- if c == '\x00' || c == '\n' || c > 127 {
- return false
- }
- }
- return true
-}
-
-func TestEscape(t *testing.T) {
- tests := [...]string{
- "",
- "abc",
- "a\nb",
- "a\\b",
- "ab\\",
- "ab\\\n",
- "ab\n\\",
- }
-
- check := func(input string) {
- output := escape(input)
- if !stringIsSafe(output) {
- t.Errorf("escape(%q) → %q", input, output)
- }
- }
- for _, input := range tests {
- check(input)
- }
- for b := 0; b < 256; b++ {
- // check one-byte string with each byte value 0–255
- check(string([]byte{byte(b)}))
- // check UTF-8 encoding of each character 0–255
- check(string(b))
- }
-}
-
-func TestGetManagedTransportVer(t *testing.T) {
- tests := [...]struct {
- input, expected string
- }{
- {"1", "1"},
- {"1,1", "1"},
- {"1,2", "1"},
- {"2,1", "1"},
- {"2", ""},
- }
-
- for _, test := range tests {
- os.Setenv("TOR_PT_MANAGED_TRANSPORT_VER", test.input)
- output := getManagedTransportVer()
- if output != test.expected {
- t.Errorf("%q → %q (expected %q)", test.input, output, test.expected)
- }
- }
-}
diff --git a/src/pt/socks/socks.go b/src/pt/socks/socks.go
deleted file mode 100644
index 788d53c..0000000
--- a/src/pt/socks/socks.go
+++ /dev/null
@@ -1,107 +0,0 @@
-// SOCKS4a server library.
-
-package socks
-
-import (
- "bufio"
- "errors"
- "fmt"
- "io"
- "net"
-)
-
-const (
- socksVersion = 0x04
- socksCmdConnect = 0x01
- socksResponseVersion = 0x00
- socksRequestGranted = 0x5a
- socksRequestFailed = 0x5b
-)
-
-// Read a SOCKS4a connect request, and call the given connect callback with the
-// requested destination string. If the callback returns an error, sends a SOCKS
-// request failed message. Otherwise, sends a SOCKS request granted message for
-// the destination address returned by the callback.
-func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error {
- dest, err := readSocks4aConnect(conn)
- if err != nil {
- sendSocks4aResponseFailed(conn)
- return err
- }
- destAddr, err := connect(dest)
- if err != nil {
- sendSocks4aResponseFailed(conn)
- return err
- }
- sendSocks4aResponseGranted(conn, destAddr)
- return nil
-}
-
-// Read a SOCKS4a connect request. Returns a "host:port" string.
-func readSocks4aConnect(s io.Reader) (string, error) {
- r := bufio.NewReader(s)
-
- var h [8]byte
- n, err := io.ReadFull(r, h[:])
- if err != nil {
- return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err))
- }
- if h[0] != socksVersion {
- return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion))
- }
- if h[1] != socksCmdConnect {
- return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect))
- }
-
- _, err = r.ReadBytes('\x00')
- if err != nil {
- return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err))
- }
-
- var port int
- var host string
-
- port = int(h[2])<<8 | int(h[3])<<0
- if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 {
- hostBytes, err := r.ReadBytes('\x00')
- if err != nil {
- return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err))
- }
- host = string(hostBytes[:len(hostBytes)-1])
- } else {
- host = net.IPv4(h[4], h[5], h[6], h[7]).String()
- }
-
- if r.Buffered() != 0 {
- return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered()))
- }
-
- return fmt.Sprintf("%s:%d", host, port), nil
-}
-
-// Send a SOCKS4a response with the given code and address.
-func sendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error {
- var resp [8]byte
- resp[0] = socksResponseVersion
- resp[1] = code
- resp[2] = byte((addr.Port >> 8) & 0xff)
- resp[3] = byte((addr.Port >> 0) & 0xff)
- resp[4] = addr.IP[0]
- resp[5] = addr.IP[1]
- resp[6] = addr.IP[2]
- resp[7] = addr.IP[3]
- _, err := w.Write(resp[:])
- return err
-}
-
-var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0}
-
-// Send a SOCKS4a response code 0x5a.
-func sendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error {
- return sendSocks4aResponse(w, socksRequestGranted, addr)
-}
-
-// Send a SOCKS4a response code 0x5b (with an all-zero address).
-func sendSocks4aResponseFailed(w io.Writer) error {
- return sendSocks4aResponse(w, socksRequestFailed, &emptyAddr)
-}
diff --git a/src/websocket-client/websocket-client.go b/src/websocket-client/websocket-client.go
deleted file mode 100644
index 1c3b3b9..0000000
--- a/src/websocket-client/websocket-client.go
+++ /dev/null
@@ -1,254 +0,0 @@
-// Tor websocket client transport plugin.
-//
-// Usage:
-// ClientTransportPlugin websocket exec ./websocket-client
-
-package main
-
-import (
- "code.google.com/p/go.net/websocket"
- "flag"
- "fmt"
- "io"
- "net"
- "net/url"
- "os"
- "os/signal"
- "sync"
- "time"
-)
-
-import "pt"
-import "pt/socks"
-
-const ptMethodName = "websocket"
-const socksTimeout = 2 * time.Second
-const bufSiz = 1500
-
-var logFile = os.Stderr
-
-// When a connection handler starts, +1 is written to this channel; when it
-// ends, -1 is written.
-var handlerChan = make(chan int)
-
-var logMutex sync.Mutex
-
-func usage() {
- fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0])
- fmt.Printf("WebSocket client pluggable transport for Tor.\n")
- fmt.Printf("Works only as a managed proxy.\n")
- fmt.Printf("\n")
- fmt.Printf(" -h, --help show this help.\n")
- fmt.Printf(" --log FILE log messages to FILE (default stderr).\n")
- fmt.Printf(" --socks ADDR listen for SOCKS on ADDR.\n")
-}
-
-func Log(format string, v ...interface{}) {
- dateStr := time.Now().Format("2006-01-02 15:04:05")
- logMutex.Lock()
- defer logMutex.Unlock()
- msg := fmt.Sprintf(format, v...)
- fmt.Fprintf(logFile, "%s %s\n", dateStr, msg)
-}
-
-func proxy(local *net.TCPConn, ws *websocket.Conn) {
- var wg sync.WaitGroup
-
- wg.Add(2)
-
- // Local-to-WebSocket read loop.
- go func() {
- buf := make([]byte, bufSiz)
- var err error
- for {
- n, er := local.Read(buf[:])
- if n > 0 {
- ew := websocket.Message.Send(ws, buf[:n])
- if ew != nil {
- err = ew
- break
- }
- }
- if er != nil {
- err = er
- break
- }
- }
- if err != nil && err != io.EOF {
- Log("%s", err)
- }
- local.CloseRead()
- ws.Close()
-
- wg.Done()
- }()
-
- // WebSocket-to-local read loop.
- go func() {
- var buf []byte
- var err error
- for {
- er := websocket.Message.Receive(ws, &buf)
- if er != nil {
- err = er
- break
- }
- n, ew := local.Write(buf)
- if ew != nil {
- err = ew
- break
- }
- if n != len(buf) {
- err = io.ErrShortWrite
- break
- }
- }
- if err != nil && err != io.EOF {
- Log("%s", err)
- }
- local.CloseWrite()
- ws.Close()
-
- wg.Done()
- }()
-
- wg.Wait()
-}
-
-func handleConnection(conn *net.TCPConn) error {
- defer conn.Close()
-
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
-
- var ws *websocket.Conn
-
- conn.SetDeadline(time.Now().Add(socksTimeout))
- err := socks.AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) {
- // Disable deadline.
- conn.SetDeadline(time.Time{})
- Log("SOCKS request for %s", dest)
- destAddr, err := net.ResolveTCPAddr("tcp", dest)
- if err != nil {
- return nil, err
- }
- wsUrl := url.URL{Scheme: "ws", Host: dest}
- ws, err = websocket.Dial(wsUrl.String(), "", wsUrl.String())
- if err != nil {
- return nil, err
- }
- Log("WebSocket connection to %s", ws.Config().Location.String())
- return destAddr, nil
- })
- if err != nil {
- return err
- }
- defer ws.Close()
- proxy(conn, ws)
- return nil
-}
-
-func socksAcceptLoop(ln *net.TCPListener) error {
- for {
- socks, err := ln.AcceptTCP()
- if err != nil {
- return err
- }
- go func() {
- err := handleConnection(socks)
- if err != nil {
- Log("SOCKS from %s: %s", socks.RemoteAddr(), err)
- }
- }()
- }
- return nil
-}
-
-func startListener(addrStr string) (*net.TCPListener, error) {
- addr, err := net.ResolveTCPAddr("tcp", addrStr)
- if err != nil {
- return nil, err
- }
- ln, err := net.ListenTCP("tcp", addr)
- if err != nil {
- return nil, err
- }
- go func() {
- err := socksAcceptLoop(ln)
- if err != nil {
- Log("accept: %s", err)
- }
- }()
- return ln, nil
-}
-
-func main() {
- var logFilename string
- var socksAddrStrs = []string{"127.0.0.1:0"}
- var socksArg string
-
- flag.Usage = usage
- flag.StringVar(&logFilename, "log", "", "log file to write to")
- flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections")
- flag.Parse()
-
- if logFilename != "" {
- f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
- if err != nil {
- fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error())
- os.Exit(1)
- }
- logFile = f
- }
-
- if socksArg != "" {
- socksAddrStrs = []string{socksArg}
- }
-
- Log("starting")
- pt.ClientSetup([]string{ptMethodName})
-
- listeners := make([]*net.TCPListener, 0)
- for _, socksAddrStr := range socksAddrStrs {
- ln, err := startListener(socksAddrStr)
- if err != nil {
- pt.CmethodError(ptMethodName, err.Error())
- }
- pt.Cmethod(ptMethodName, "socks4", ln.Addr())
- Log("listening on %s", ln.Addr().String())
- listeners = append(listeners, ln)
- }
- pt.CmethodsDone()
-
- var numHandlers int = 0
-
- signalChan := make(chan os.Signal, 1)
- signal.Notify(signalChan, os.Interrupt)
- var sigint bool = false
- for !sigint {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case <-signalChan:
- Log("SIGINT")
- sigint = true
- }
- }
-
- for _, ln := range listeners {
- ln.Close()
- }
-
- sigint = false
- for numHandlers != 0 && !sigint {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case <-signalChan:
- Log("SIGINT")
- sigint = true
- }
- }
-}
diff --git a/src/websocket-server/websocket-server.go b/src/websocket-server/websocket-server.go
deleted file mode 100644
index 207be8d..0000000
--- a/src/websocket-server/websocket-server.go
+++ /dev/null
@@ -1,285 +0,0 @@
-// Tor websocket server transport plugin.
-//
-// Usage:
-// ServerTransportPlugin websocket exec ./websocket-server --port 9901
-
-package main
-
-import (
- "encoding/base64"
- "errors"
- "flag"
- "fmt"
- "io"
- "net"
- "net/http"
- "os"
- "os/signal"
- "sync"
- "syscall"
- "time"
-)
-
-import "pt"
-import "websocket"
-
-const ptMethodName = "websocket"
-const requestTimeout = 10 * time.Second
-
-// "4/3+1" accounts for possible base64 encoding.
-const maxMessageSize = 64*1024*4/3 + 1
-
-var logFile = os.Stderr
-
-var ptInfo pt.ServerInfo
-
-// When a connection handler starts, +1 is written to this channel; when it
-// ends, -1 is written.
-var handlerChan = make(chan int)
-
-func usage() {
- fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0])
- fmt.Printf("WebSocket server pluggable transport for Tor.\n")
- fmt.Printf("Works only as a managed proxy.\n")
- fmt.Printf("\n")
- fmt.Printf(" -h, --help show this help.\n")
- fmt.Printf(" --log FILE log messages to FILE (default stderr).\n")
- fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n")
-}
-
-var logMutex sync.Mutex
-
-func log(format string, v ...interface{}) {
- dateStr := time.Now().Format("2006-01-02 15:04:05")
- logMutex.Lock()
- defer logMutex.Unlock()
- msg := fmt.Sprintf(format, v...)
- fmt.Fprintf(logFile, "%s %s\n", dateStr, msg)
-}
-
-// An abstraction that makes an underlying WebSocket connection look like an
-// io.ReadWriteCloser. It internally takes care of things like base64 encoding
-// and decoding.
-type webSocketConn struct {
- Ws *websocket.WebSocket
- Base64 bool
- messageBuf []byte
-}
-
-// Implements io.Reader.
-func (conn *webSocketConn) Read(b []byte) (n int, err error) {
- for len(conn.messageBuf) == 0 {
- var m websocket.Message
- m, err = conn.Ws.ReadMessage()
- if err != nil {
- return
- }
- if m.Opcode == 8 {
- err = io.EOF
- return
- }
- if conn.Base64 {
- if m.Opcode != 1 {
- err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode))
- return
- }
- conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload)))
- var num int
- num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload)
- if err != nil {
- return
- }
- conn.messageBuf = conn.messageBuf[:num]
- } else {
- if m.Opcode != 2 {
- err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode))
- return
- }
- conn.messageBuf = m.Payload
- }
- }
-
- n = copy(b, conn.messageBuf)
- conn.messageBuf = conn.messageBuf[n:]
-
- return
-}
-
-// Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (n int, err error) {
- if conn.Base64 {
- buf := make([]byte, base64.StdEncoding.EncodedLen(len(b)))
- base64.StdEncoding.Encode(buf, b)
- err = conn.Ws.WriteMessage(1, buf)
- if err != nil {
- return
- }
- n = len(b)
- } else {
- err = conn.Ws.WriteMessage(2, b)
- n = len(b)
- }
- return
-}
-
-// Implements io.Closer.
-func (conn *webSocketConn) Close() error {
- // Ignore any error in trying to write a Close frame.
- _ = conn.Ws.WriteFrame(8, nil)
- return conn.Ws.Conn.Close()
-}
-
-// Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
- var conn webSocketConn
- conn.Ws = ws
- conn.Base64 = (ws.Subprotocol == "base64")
- return conn
-}
-
-// Copy from WebSocket to socket and vice versa.
-func proxy(local *net.TCPConn, conn *webSocketConn) {
- var wg sync.WaitGroup
-
- wg.Add(2)
-
- go func() {
- _, err := io.Copy(conn, local)
- if err != nil {
- log("error copying ORPort to WebSocket")
- }
- local.CloseRead()
- conn.Close()
- wg.Done()
- }()
-
- go func() {
- _, err := io.Copy(local, conn)
- if err != nil {
- log("error copying WebSocket to ORPort")
- }
- local.CloseWrite()
- conn.Close()
- wg.Done()
- }()
-
- wg.Wait()
-}
-
-func webSocketHandler(ws *websocket.WebSocket) {
- // Undo timeouts on HTTP request handling.
- ws.Conn.SetDeadline(time.Time{})
- conn := newWebSocketConn(ws)
-
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
-
- s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName)
- if err != nil {
- log("Failed to connect to ORPort: " + err.Error())
- return
- }
-
- proxy(s, &conn)
-}
-
-func startListener(addr *net.TCPAddr) (*net.TCPListener, error) {
- ln, err := net.ListenTCP("tcp", addr)
- if err != nil {
- return nil, err
- }
- go func() {
- var config websocket.Config
- config.Subprotocols = []string{"base64"}
- config.MaxMessageSize = maxMessageSize
- s := &http.Server{
- Handler: config.Handler(webSocketHandler),
- ReadTimeout: requestTimeout,
- }
- err = s.Serve(ln)
- if err != nil {
- log("http.Serve: " + err.Error())
- }
- }()
- return ln, nil
-}
-
-func main() {
- var logFilename string
- var port int
-
- flag.Usage = usage
- flag.StringVar(&logFilename, "log", "", "log file to write to")
- flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor")
- flag.Parse()
-
- if logFilename != "" {
- f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
- if err != nil {
- fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error())
- os.Exit(1)
- }
- logFile = f
- }
-
- log("starting")
- ptInfo = pt.ServerSetup([]string{ptMethodName})
-
- listeners := make([]*net.TCPListener, 0)
- for _, bindAddr := range ptInfo.BindAddrs {
- // Override tor's requested port (which is 0 if this transport
- // has not been run before) with the one requested by the --port
- // option.
- if port != 0 {
- bindAddr.Addr.Port = port
- }
-
- ln, err := startListener(bindAddr.Addr)
- if err != nil {
- pt.SmethodError(bindAddr.MethodName, err.Error())
- continue
- }
- pt.Smethod(bindAddr.MethodName, ln.Addr())
- log("listening on %s", ln.Addr().String())
- listeners = append(listeners, ln)
- }
- pt.SmethodsDone()
-
- var numHandlers int = 0
- var sig os.Signal
- sigChan := make(chan os.Signal, 1)
- signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
- sig = nil
- for sig == nil {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
- log("Got first signal %q with %d running handlers.", sig, numHandlers)
- for _, ln := range listeners {
- ln.Close()
- }
-
- if sig == syscall.SIGTERM {
- log("Caught signal %q, exiting.", sig)
- return
- }
-
- sig = nil
- for sig == nil && numHandlers != 0 {
- select {
- case n := <-handlerChan:
- numHandlers += n
- log("%d remaining handlers.", numHandlers)
- case sig = <-sigChan:
- }
- }
- if sig != nil {
- log("Got second signal %q with %d running handlers.", sig, numHandlers)
- }
-}
diff --git a/src/websocket/websocket.go b/src/websocket/websocket.go
deleted file mode 100644
index dc228d1..0000000
--- a/src/websocket/websocket.go
+++ /dev/null
@@ -1,431 +0,0 @@
-// WebSocket library. Only the RFC 6455 variety of WebSocket is supported.
-//
-// Reading and writing is strictly per-frame (or per-message). There is no way
-// to partially read a frame. Config.MaxMessageSize affords control of the
-// maximum buffering of messages.
-//
-// The reason for using this custom implementation instead of
-// code.google.com/p/go.net/websocket is that the latter has problems with long
-// messages and does not support server subprotocols.
-// "Denial of Service Protection in Go HTTP Servers"
-// https://code.google.com/p/go/issues/detail?id=2093
-// "go.websocket: Read/Copy fail with long frames"
-// https://code.google.com/p/go/issues/detail?id=2134
-// http://golang.org/pkg/net/textproto/#pkg-bugs
-// "To let callers manage exposure to denial of service attacks, Reader should
-// allow them to set and reset a limit on the number of bytes read from the
-// connection."
-// "websocket.Dial doesn't limit response header length as http.Get does"
-// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI
-//
-// Example usage:
-//
-// func doSomething(ws *WebSocket) {
-// }
-// var config websocket.Config
-// config.Subprotocols = []string{"base64"}
-// config.MaxMessageSize = 2500
-// http.Handle("/", config.Handler(doSomething))
-// err = http.ListenAndServe(":8080", nil)
-
-package websocket
-
-import (
- "bufio"
- "bytes"
- "crypto/rand"
- "crypto/sha1"
- "encoding/base64"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "net/http"
- "strings"
-)
-
-// Settings for potential WebSocket connections. Subprotocols is a list of
-// supported subprotocols as in RFC 6455 section 1.9. When answering client
-// requests, the first of the client's requests subprotocols that is also in
-// this list (if any) will be used as the subprotocol for the connection.
-// MaxMessageSize is a limit on buffering messages.
-type Config struct {
- Subprotocols []string
- MaxMessageSize int
-}
-
-// Representation of a WebSocket frame. The Payload is always without masking.
-type Frame struct {
- Fin bool
- Opcode byte
- Payload []byte
-}
-
-// Return true iff the frame's opcode says it is a control frame.
-func (frame *Frame) IsControl() bool {
- return (frame.Opcode & 0x08) != 0
-}
-
-// Representation of a WebSocket message. The Payload is always without masking.
-type Message struct {
- Opcode byte
- Payload []byte
-}
-
-// A WebSocket connection after hijacking from HTTP.
-type WebSocket struct {
- // Conn and ReadWriter from http.ResponseWriter.Hijack.
- Conn net.Conn
- Bufrw *bufio.ReadWriter
- // Whether we are a client or a server has implications for masking.
- IsClient bool
- // Set from a parent Config.
- MaxMessageSize int
- // The single selected subprotocol after negotiation, or "".
- Subprotocol string
- // Buffer for message payloads, which may be interrupted by control
- // messages.
- messageBuf bytes.Buffer
-}
-
-func applyMask(payload []byte, maskKey [4]byte) {
- for i := 0; i < len(payload); i++ {
- payload[i] = payload[i] ^ maskKey[i%4]
- }
-}
-
-func (ws *WebSocket) maxMessageSize() int {
- if ws.MaxMessageSize == 0 {
- return 64000
- }
- return ws.MaxMessageSize
-}
-
-// Read a single frame from the WebSocket.
-func (ws *WebSocket) ReadFrame() (frame Frame, err error) {
- var b byte
- err = binary.Read(ws.Bufrw, binary.BigEndian, &b)
- if err != nil {
- return
- }
- frame.Fin = (b & 0x80) != 0
- frame.Opcode = b & 0x0f
- err = binary.Read(ws.Bufrw, binary.BigEndian, &b)
- if err != nil {
- return
- }
- masked := (b & 0x80) != 0
-
- payloadLen := uint64(b & 0x7f)
- if payloadLen == 126 {
- var short uint16
- err = binary.Read(ws.Bufrw, binary.BigEndian, &short)
- if err != nil {
- return
- }
- payloadLen = uint64(short)
- } else if payloadLen == 127 {
- var long uint64
- err = binary.Read(ws.Bufrw, binary.BigEndian, &long)
- if err != nil {
- return
- }
- payloadLen = long
- }
- if payloadLen > uint64(ws.maxMessageSize()) {
- err = errors.New(fmt.Sprintf("frame payload length of %d exceeds maximum of %d", payloadLen, ws.MaxMessageSize))
- return
- }
-
- maskKey := [4]byte{}
- if masked {
- if ws.IsClient {
- err = errors.New("client got masked frame")
- return
- }
- err = binary.Read(ws.Bufrw, binary.BigEndian, &maskKey)
- if err != nil {
- return
- }
- } else {
- if !ws.IsClient {
- err = errors.New("server got unmasked frame")
- return
- }
- }
-
- frame.Payload = make([]byte, payloadLen)
- _, err = io.ReadFull(ws.Bufrw, frame.Payload)
- if err != nil {
- return
- }
- if masked {
- applyMask(frame.Payload, maskKey)
- }
-
- return frame, nil
-}
-
-// Read a single message from the WebSocket. Multiple fragmented frames are
-// combined into a single message before being returned. Non-control messages
-// may be interrupted by control frames. The control frames are returned as
-// individual messages before the message that they interrupt.
-func (ws *WebSocket) ReadMessage() (message Message, err error) {
- var opcode byte = 0
- for {
- var frame Frame
- frame, err = ws.ReadFrame()
- if err != nil {
- return
- }
- if frame.IsControl() {
- if !frame.Fin {
- err = errors.New("control frame has fin bit unset")
- return
- }
- message.Opcode = frame.Opcode
- message.Payload = frame.Payload
- return message, nil
- }
-
- if opcode == 0 {
- if frame.Opcode == 0 {
- err = errors.New("first frame has opcode 0")
- return
- }
- opcode = frame.Opcode
- } else {
- if frame.Opcode != 0 {
- err = errors.New(fmt.Sprintf("non-first frame has nonzero opcode %d", frame.Opcode))
- return
- }
- }
- if ws.messageBuf.Len()+len(frame.Payload) > ws.MaxMessageSize {
- err = errors.New(fmt.Sprintf("message payload length of %d exceeds maximum of %d",
- ws.messageBuf.Len()+len(frame.Payload), ws.MaxMessageSize))
- return
- }
- ws.messageBuf.Write(frame.Payload)
- if frame.Fin {
- break
- }
- }
- message.Opcode = opcode
- message.Payload = ws.messageBuf.Bytes()
- ws.messageBuf.Reset()
-
- return message, nil
-}
-
-// Write a single frame to the WebSocket stream. Destructively masks payload in
-// place if ws.IsClient. Frames are always unfragmented.
-func (ws *WebSocket) WriteFrame(opcode byte, payload []byte) (err error) {
- if opcode >= 16 {
- err = errors.New(fmt.Sprintf("opcode %d is >= 16", opcode))
- return
- }
- ws.Bufrw.WriteByte(0x80 | opcode)
-
- var maskBit byte
- var maskKey [4]byte
- if ws.IsClient {
- _, err = io.ReadFull(rand.Reader, maskKey[:])
- if err != nil {
- return
- }
- applyMask(payload, maskKey)
- maskBit = 0x80
- } else {
- maskBit = 0x00
- }
-
- if len(payload) < 126 {
- ws.Bufrw.WriteByte(maskBit | byte(len(payload)))
- } else if len(payload) <= 0xffff {
- ws.Bufrw.WriteByte(maskBit | 126)
- binary.Write(ws.Bufrw, binary.BigEndian, uint16(len(payload)))
- } else {
- ws.Bufrw.WriteByte(maskBit | 127)
- binary.Write(ws.Bufrw, binary.BigEndian, uint64(len(payload)))
- }
-
- if ws.IsClient {
- _, err = ws.Bufrw.Write(maskKey[:])
- if err != nil {
- return
- }
- }
- _, err = ws.Bufrw.Write(payload)
- if err != nil {
- return
- }
-
- ws.Bufrw.Flush()
-
- return
-}
-
-// Write a single message to the WebSocket stream. Destructively masks payload
-// in place if ws.IsClient. Messages are always sent as a single unfragmented
-// frame.
-func (ws *WebSocket) WriteMessage(opcode byte, payload []byte) (err error) {
- return ws.WriteFrame(opcode, payload)
-}
-
-// Split a string on commas and trim whitespace.
-func commaSplit(s string) []string {
- var result []string
- if strings.TrimSpace(s) == "" {
- return result
- }
- for _, e := range strings.Split(s, ",") {
- result = append(result, strings.TrimSpace(e))
- }
- return result
-}
-
-// Returns true iff one of the strings in haystack is needle (case-insensitive).
-func containsCase(haystack []string, needle string) bool {
- for _, e := range haystack {
- if strings.ToLower(e) == strings.ToLower(needle) {
- return true
- }
- }
- return false
-}
-
-// One-step SHA-1 hash of a string.
-func sha1Hash(data string) []byte {
- h := sha1.New()
- h.Write([]byte(data))
- return h.Sum(nil)
-}
-
-func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) {
- w.Header().Set("Connection", "close")
- bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", code, http.StatusText(code)))
- w.Header().Write(bufrw)
- bufrw.WriteString("\r\n")
- bufrw.Flush()
-}
-
-// An implementation of http.Handler with a Config. The ServeHTTP function calls
-// Callback assuming WebSocket HTTP negotiation is successful.
-type HTTPHandler struct {
- Config *Config
- Callback func(*WebSocket)
-}
-
-// Implements the http.Handler interface.
-func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- conn, bufrw, err := w.(http.Hijacker).Hijack()
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- defer conn.Close()
-
- // See RFC 6455 section 4.2.1 for this sequence of checks.
-
- // 1. An HTTP/1.1 or higher GET request, including a "Request-URI"...
- if req.Method != "GET" {
- httpError(w, bufrw, http.StatusMethodNotAllowed)
- return
- }
- if req.URL.Path != "/" {
- httpError(w, bufrw, http.StatusNotFound)
- return
- }
- // 2. A |Host| header field containing the server's authority.
- // We deliberately skip this test.
- // 3. An |Upgrade| header field containing the value "websocket",
- // treated as an ASCII case-insensitive value.
- if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") {
- httpError(w, bufrw, http.StatusBadRequest)
- return
- }
- // 4. A |Connection| header field that includes the token "Upgrade",
- // treated as an ASCII case-insensitive value.
- if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") {
- httpError(w, bufrw, http.StatusBadRequest)
- return
- }
- // 5. A |Sec-WebSocket-Key| header field with a base64-encoded value
- // that, when decoded, is 16 bytes in length.
- websocketKey := req.Header.Get("Sec-WebSocket-Key")
- key, err := base64.StdEncoding.DecodeString(websocketKey)
- if err != nil || len(key) != 16 {
- httpError(w, bufrw, http.StatusBadRequest)
- return
- }
- // 6. A |Sec-WebSocket-Version| header field, with a value of 13.
- // We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10.
- var knownVersions = []string{"8", "13"}
- websocketVersion := req.Header.Get("Sec-WebSocket-Version")
- if !containsCase(knownVersions, websocketVersion) {
- // "If this version does not match a version understood by the
- // server, the server MUST abort the WebSocket handshake
- // described in this section and instead send an appropriate
- // HTTP error code (such as 426 Upgrade Required) and a
- // |Sec-WebSocket-Version| header field indicating the
- // version(s) the server is capable of understanding."
- w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", "))
- httpError(w, bufrw, 426)
- return
- }
- // 7. Optionally, an |Origin| header field.
- // 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of
- // values indicating which protocols the client would like to speak, ordered
- // by preference.
- clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol"))
- // 9. Optionally, a |Sec-WebSocket-Extensions| header field...
- // 10. Optionally, other header fields...
-
- var ws WebSocket
- ws.Conn = conn
- ws.Bufrw = bufrw
- ws.IsClient = false
- ws.MaxMessageSize = handler.Config.MaxMessageSize
-
- // See RFC 6455 section 4.2.2, item 5 for these steps.
-
- // 1. A Status-Line with a 101 response code as per RFC 2616.
- bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)))
- // 2. An |Upgrade| header field with value "websocket" as per RFC 2616.
- w.Header().Set("Upgrade", "websocket")
- // 3. A |Connection| header field with value "Upgrade".
- w.Header().Set("Connection", "Upgrade")
- // 4. A |Sec-WebSocket-Accept| header field. The value of this header
- // field is constructed by concatenating /key/, defined above in step 4
- // in Section 4.2.2, with the string
- // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
- // concatenated value to obtain a 20-byte value and base64-encoding (see
- // Section 4 of [RFC4648]) this 20-byte hash.
- const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
- acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID))
- w.Header().Set("Sec-WebSocket-Accept", acceptKey)
- // 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value
- // /subprotocol/ as defined in step 4 in Section 4.2.2.
- for _, clientProto := range clientProtocols {
- for _, serverProto := range handler.Config.Subprotocols {
- if clientProto == serverProto {
- ws.Subprotocol = clientProto
- w.Header().Set("Sec-WebSocket-Protocol", clientProto)
- break
- }
- }
- }
- // 6. Optionally, a |Sec-WebSocket-Extensions| header field...
- w.Header().Write(bufrw)
- bufrw.WriteString("\r\n")
- bufrw.Flush()
-
- // Call the WebSocket-specific handler.
- handler.Callback(&ws)
-}
-
-// Return an http.Handler with the given callback function.
-func (config *Config) Handler(callback func(*WebSocket)) http.Handler {
- return &HTTPHandler{config, callback}
-}
diff --git a/websocket-client/websocket-client.go b/websocket-client/websocket-client.go
new file mode 100644
index 0000000..7f838bb
--- /dev/null
+++ b/websocket-client/websocket-client.go
@@ -0,0 +1,254 @@
+// Tor websocket client transport plugin.
+//
+// Usage:
+// ClientTransportPlugin websocket exec ./websocket-client
+
+package main
+
+import (
+ "code.google.com/p/go.net/websocket"
+ "flag"
+ "fmt"
+ "io"
+ "net"
+ "net/url"
+ "os"
+ "os/signal"
+ "sync"
+ "time"
+)
+
+import "../pt"
+import "../pt/socks"
+
+const (
+	// Transport name reported to tor through the pt package.
+	ptMethodName = "websocket"
+	// Deadline applied to a new SOCKS connection until its request arrives.
+	socksTimeout = 2 * time.Second
+	// Size of the buffer used to read from the local connection before
+	// forwarding the data as a WebSocket message.
+	bufSiz = 1500
+)
+
+var (
+	// Destination of Log output; main replaces it when --log is given.
+	logFile = os.Stderr
+	// Each connection handler sends +1 when it starts and -1 when it ends,
+	// letting main count the handlers that are still running.
+	handlerChan = make(chan int)
+	// Serializes writes to logFile.
+	logMutex sync.Mutex
+)
+
+// usage prints command-line help to standard output; it is installed as
+// flag.Usage by main.
+func usage() {
+	fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0])
+	fmt.Print("WebSocket client pluggable transport for Tor.\n" +
+		"Works only as a managed proxy.\n" +
+		"\n" +
+		" -h, --help show this help.\n" +
+		" --log FILE log messages to FILE (default stderr).\n" +
+		" --socks ADDR listen for SOCKS on ADDR.\n")
+}
+
+// Log writes one timestamped, newline-terminated line to logFile. format and
+// v are interpreted as by fmt.Sprintf. Writes are serialized with logMutex so
+// lines from concurrent handlers do not interleave.
+func Log(format string, v ...interface{}) {
+	stamp := time.Now().Format("2006-01-02 15:04:05")
+	line := fmt.Sprintf(format, v...)
+	logMutex.Lock()
+	fmt.Fprintf(logFile, "%s %s\n", stamp, line)
+	logMutex.Unlock()
+}
+
+// proxy shuttles data between the local TCP connection and the WebSocket
+// until both directions have stopped. Every successful local read is sent as
+// one WebSocket message, and every received WebSocket message is written to
+// the local connection in full. Errors other than io.EOF are logged.
+func proxy(local *net.TCPConn, ws *websocket.Conn) {
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	// Local-to-WebSocket direction.
+	go func() {
+		defer wg.Done()
+		buf := make([]byte, bufSiz)
+		var err error
+		for err == nil {
+			n, readErr := local.Read(buf)
+			if n > 0 {
+				// Forward whatever was read before looking at readErr.
+				if sendErr := websocket.Message.Send(ws, buf[:n]); sendErr != nil {
+					err = sendErr
+					break
+				}
+			}
+			err = readErr
+		}
+		if err != io.EOF {
+			Log("%s", err)
+		}
+		local.CloseRead()
+		ws.Close()
+	}()
+
+	// WebSocket-to-local direction.
+	go func() {
+		defer wg.Done()
+		var msg []byte
+		var err error
+		for {
+			if err = websocket.Message.Receive(ws, &msg); err != nil {
+				break
+			}
+			var n int
+			if n, err = local.Write(msg); err != nil {
+				break
+			}
+			if n != len(msg) {
+				err = io.ErrShortWrite
+				break
+			}
+		}
+		if err != io.EOF {
+			Log("%s", err)
+		}
+		local.CloseWrite()
+		ws.Close()
+	}()
+
+	wg.Wait()
+}
+
+// handleConnection serves one accepted SOCKS client. It waits up to
+// socksTimeout for a SOCKS4a CONNECT request, dials a WebSocket to the
+// requested destination, and then proxies data until both directions end.
+// Its lifetime is reported on handlerChan so main can count running handlers.
+func handleConnection(conn *net.TCPConn) error {
+	defer conn.Close()
+
+	handlerChan <- 1
+	defer func() {
+		handlerChan <- -1
+	}()
+
+	var ws *websocket.Conn
+
+	conn.SetDeadline(time.Now().Add(socksTimeout))
+	err := socks.AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) {
+		// Disable deadline.
+		conn.SetDeadline(time.Time{})
+		Log("SOCKS request for %s", dest)
+		destAddr, err := net.ResolveTCPAddr("tcp", dest)
+		if err != nil {
+			return nil, err
+		}
+		wsURL := url.URL{Scheme: "ws", Host: dest}
+		ws, err = websocket.Dial(wsURL.String(), "", wsURL.String())
+		if err != nil {
+			return nil, err
+		}
+		Log("WebSocket connection to %s", ws.Config().Location.String())
+		return destAddr, nil
+	})
+	// The callback may already have dialed the WebSocket when
+	// AwaitSocks4aConnect reports an error (presumably while replying to the
+	// SOCKS client; see the socks package). Close it on every path so it is
+	// not leaked.
+	if ws != nil {
+		defer ws.Close()
+	}
+	if err != nil {
+		return err
+	}
+	proxy(conn, ws)
+	return nil
+}
+
+// socksAcceptLoop accepts SOCKS clients on ln and serves each one in its own
+// goroutine. It runs until AcceptTCP fails (for example after main closes the
+// listener) and returns that error.
+func socksAcceptLoop(ln *net.TCPListener) error {
+	for {
+		// Named conn, not socks, so the socks package is not shadowed.
+		conn, err := ln.AcceptTCP()
+		if err != nil {
+			return err
+		}
+		go func() {
+			if err := handleConnection(conn); err != nil {
+				Log("SOCKS from %s: %s", conn.RemoteAddr(), err)
+			}
+		}()
+	}
+}
+
+// startListener resolves addrStr, binds a TCP listener to it, and serves SOCKS
+// clients on it from a background goroutine. The listener is returned so the
+// caller can report its address and close it later.
+func startListener(addrStr string) (*net.TCPListener, error) {
+	tcpAddr, err := net.ResolveTCPAddr("tcp", addrStr)
+	if err != nil {
+		return nil, err
+	}
+	listener, err := net.ListenTCP("tcp", tcpAddr)
+	if err != nil {
+		return nil, err
+	}
+	go func() {
+		if acceptErr := socksAcceptLoop(listener); acceptErr != nil {
+			Log("accept: %s", acceptErr)
+		}
+	}()
+	return listener, nil
+}
+
+// main parses flags and performs the pluggable-transport client handshake
+// with tor. It opens a SOCKS listener on each requested address
+// (127.0.0.1:0 unless --socks is given) and then counts running handlers
+// until SIGINT. After the first SIGINT it closes the listeners and waits for
+// the remaining handlers to finish, or for a second SIGINT.
+func main() {
+	var logFilename string
+	var socksAddrStrs = []string{"127.0.0.1:0"}
+	var socksArg string
+
+	flag.Usage = usage
+	flag.StringVar(&logFilename, "log", "", "log file to write to")
+	flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections")
+	flag.Parse()
+
+	if logFilename != "" {
+		f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error())
+			os.Exit(1)
+		}
+		logFile = f
+	}
+
+	if socksArg != "" {
+		socksAddrStrs = []string{socksArg}
+	}
+
+	Log("starting")
+	pt.ClientSetup([]string{ptMethodName})
+
+	listeners := make([]*net.TCPListener, 0)
+	for _, socksAddrStr := range socksAddrStrs {
+		ln, err := startListener(socksAddrStr)
+		if err != nil {
+			pt.CmethodError(ptMethodName, err.Error())
+			// ln is nil on error; skip this address instead of
+			// dereferencing it below.
+			continue
+		}
+		pt.Cmethod(ptMethodName, "socks4", ln.Addr())
+		Log("listening on %s", ln.Addr().String())
+		listeners = append(listeners, ln)
+	}
+	pt.CmethodsDone()
+
+	numHandlers := 0
+
+	signalChan := make(chan os.Signal, 1)
+	signal.Notify(signalChan, os.Interrupt)
+
+	// Until the first SIGINT, only keep count of running handlers.
+	sigint := false
+	for !sigint {
+		select {
+		case n := <-handlerChan:
+			numHandlers += n
+		case <-signalChan:
+			Log("SIGINT")
+			sigint = true
+		}
+	}
+
+	// Stop accepting new SOCKS clients.
+	for _, ln := range listeners {
+		ln.Close()
+	}
+
+	// Wait for running handlers to finish; a second SIGINT ends the wait.
+	sigint = false
+	for numHandlers != 0 && !sigint {
+		select {
+		case n := <-handlerChan:
+			numHandlers += n
+		case <-signalChan:
+			Log("SIGINT")
+			sigint = true
+		}
+	}
+}
diff --git a/websocket-server/websocket-server.go b/websocket-server/websocket-server.go
new file mode 100644
index 0000000..e5ed1c5
--- /dev/null
+++ b/websocket-server/websocket-server.go
@@ -0,0 +1,285 @@
+// Tor websocket server transport plugin.
+//
+// Usage:
+// ServerTransportPlugin websocket exec ./websocket-server --port 9901
+
+package main
+
+import (
+ "encoding/base64"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+ "time"
+)
+
+import "../pt"
+import "../websocket"
+
+const (
+	// Transport name negotiated with tor through the pt package.
+	ptMethodName = "websocket"
+	// ReadTimeout of the HTTP server while it reads a request;
+	// webSocketHandler clears the deadline after the handshake.
+	requestTimeout = 10 * time.Second
+	// Largest accepted WebSocket message: 64 KiB scaled by 4/3, plus one,
+	// to leave room for base64 encoding.
+	maxMessageSize = 64*1024*4/3 + 1
+)
+
+var (
+	// Destination of log output; main replaces it when --log is given.
+	logFile = os.Stderr
+	// Server configuration obtained from tor by pt.ServerSetup in main.
+	ptInfo pt.ServerInfo
+	// Each connection handler sends +1 when it starts and -1 when it ends,
+	// letting main count the handlers that are still running.
+	handlerChan = make(chan int)
+)
+
+// usage prints command-line help to standard output; it is installed as
+// flag.Usage by main.
+func usage() {
+	fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0])
+	fmt.Print("WebSocket server pluggable transport for Tor.\n" +
+		"Works only as a managed proxy.\n" +
+		"\n" +
+		" -h, --help show this help.\n" +
+		" --log FILE log messages to FILE (default stderr).\n" +
+		" --port PORT listen on PORT (overrides Tor's requested port).\n")
+}
+
+// logMutex serializes writes to logFile.
+var logMutex sync.Mutex
+
+// log writes one timestamped, newline-terminated line to logFile. format and
+// v are interpreted as by fmt.Sprintf. Concurrent calls are serialized with
+// logMutex so lines do not interleave.
+func log(format string, v ...interface{}) {
+	timestamp := time.Now().Format("2006-01-02 15:04:05")
+	text := fmt.Sprintf(format, v...)
+	logMutex.Lock()
+	defer logMutex.Unlock()
+	fmt.Fprintf(logFile, "%s %s\n", timestamp, text)
+}
+
+// webSocketConn adapts a *websocket.WebSocket to io.ReadWriteCloser. In base64
+// mode, data travels as base64-encoded text messages. Otherwise it travels as
+// raw binary messages.
+type webSocketConn struct {
+	// Underlying WebSocket connection.
+	Ws *websocket.WebSocket
+	// Set when the negotiated subprotocol is "base64".
+	Base64 bool
+	// Decoded bytes of the current message not yet returned by Read.
+	messageBuf []byte
+}
+
+// Read implements io.Reader. It returns buffered bytes of the current
+// message. When the buffer is empty, it first reads the next WebSocket
+// message and, in base64 mode, decodes it.
+//
+// A Close message (opcode 8) is reported as io.EOF. Ping (9) and Pong (10)
+// messages are skipped. A data message whose opcode does not match the
+// negotiated mode (text for base64, binary otherwise) is an error.
+func (conn *webSocketConn) Read(b []byte) (n int, err error) {
+	for len(conn.messageBuf) == 0 {
+		var m websocket.Message
+		m, err = conn.Ws.ReadMessage()
+		if err != nil {
+			return
+		}
+		switch m.Opcode {
+		case 8:
+			err = io.EOF
+			return
+		case 9, 10:
+			// Ping or Pong. ReadMessage returns control frames as separate
+			// messages, and rejecting them as bad data would tear down a
+			// healthy connection. RFC 6455 section 5.5.2 asks for a Pong in
+			// reply to a Ping. Writing one from here could race with the
+			// goroutine that concurrently calls Write (the WebSocket writer
+			// has no locking), so the Ping is only ignored.
+			continue
+		}
+		if conn.Base64 {
+			if m.Opcode != 1 {
+				err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode))
+				return
+			}
+			// Decode into a local buffer so that a decode error does not
+			// leave undecoded bytes in messageBuf for a later Read.
+			decoded := make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload)))
+			var num int
+			num, err = base64.StdEncoding.Decode(decoded, m.Payload)
+			if err != nil {
+				return
+			}
+			conn.messageBuf = decoded[:num]
+		} else {
+			if m.Opcode != 2 {
+				err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode))
+				return
+			}
+			conn.messageBuf = m.Payload
+		}
+	}
+
+	n = copy(b, conn.messageBuf)
+	conn.messageBuf = conn.messageBuf[n:]
+
+	return
+}
+
+// Write implements io.Writer by sending all of b as one WebSocket message. In
+// base64 mode it sends a base64-encoded text message (opcode 1); otherwise it
+// sends a binary message (opcode 2). It reports len(b) on success. On failure
+// it reports 0 alongside the error instead of claiming the whole buffer was
+// written.
+func (conn *webSocketConn) Write(b []byte) (n int, err error) {
+	opcode, payload := byte(2), b
+	if conn.Base64 {
+		payload = make([]byte, base64.StdEncoding.EncodedLen(len(b)))
+		base64.StdEncoding.Encode(payload, b)
+		opcode = 1
+	}
+	if err = conn.Ws.WriteMessage(opcode, payload); err != nil {
+		return 0, err
+	}
+	return len(b), nil
+}
+
+// Close implements io.Closer. It makes a best-effort attempt to send a Close
+// frame (opcode 8), then closes the underlying network connection and
+// returns that error.
+//
+// NOTE(review): proxy calls Close from both copy goroutines. The Close frame
+// may therefore be written while the other goroutine is still inside Write,
+// and the WebSocket writer is not synchronized. Consider guarding this with
+// sync.Once or a write mutex.
+func (conn *webSocketConn) Close() error {
+	ws := conn.Ws
+	// Failure to send the Close frame is deliberately ignored; the
+	// connection is being torn down either way.
+	_ = ws.WriteFrame(8, nil)
+	return ws.Conn.Close()
+}
+
+// Create a new webSocketConn.
+func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
+ var conn webSocketConn
+ conn.Ws = ws
+ conn.Base64 = (ws.Subprotocol == "base64")
+ return conn
+}
+
+// proxy copies data in both directions between the ORPort connection local
+// and the WebSocket conn, and returns once both directions have finished.
+// When a direction ends, the matching half of local is shut down and conn is
+// closed. Copy errors are logged together with their cause.
+func proxy(local *net.TCPConn, conn *webSocketConn) {
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		if _, err := io.Copy(conn, local); err != nil {
+			log("error copying ORPort to WebSocket: %s", err)
+		}
+		local.CloseRead()
+		conn.Close()
+	}()
+
+	go func() {
+		defer wg.Done()
+		if _, err := io.Copy(local, conn); err != nil {
+			log("error copying WebSocket to ORPort: %s", err)
+		}
+		local.CloseWrite()
+		conn.Close()
+	}()
+
+	wg.Wait()
+}
+
+// webSocketHandler runs for each completed WebSocket handshake. It clears the
+// deadline set during HTTP request handling, connects to tor's ORPort with
+// pt.ConnectOr, and proxies traffic until both directions end. Its lifetime
+// is reported on handlerChan.
+func webSocketHandler(ws *websocket.WebSocket) {
+	// Undo timeouts on HTTP request handling.
+	ws.Conn.SetDeadline(time.Time{})
+	conn := newWebSocketConn(ws)
+
+	handlerChan <- 1
+	defer func() {
+		handlerChan <- -1
+	}()
+
+	s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName)
+	if err != nil {
+		// Constant format string: the error text may contain '%'.
+		log("Failed to connect to ORPort: %s", err)
+		return
+	}
+	// proxy only half-closes s (CloseRead/CloseWrite). Close it fully here so
+	// the ORPort connection's file descriptor is released.
+	defer s.Close()
+
+	proxy(s, &conn)
+}
+
+// startListener listens on addr and serves WebSocket connections on it from a
+// background HTTP server. The server offers the "base64" subprotocol, caps
+// messages at maxMessageSize, and applies requestTimeout while reading
+// requests. The listener is returned so main can report its address and
+// close it later.
+func startListener(addr *net.TCPAddr) (*net.TCPListener, error) {
+	ln, err := net.ListenTCP("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+	go func() {
+		var config websocket.Config
+		config.Subprotocols = []string{"base64"}
+		config.MaxMessageSize = maxMessageSize
+		s := &http.Server{
+			Handler:     config.Handler(webSocketHandler),
+			ReadTimeout: requestTimeout,
+		}
+		// Use a goroutine-local error instead of writing to the enclosing
+		// function's err after that function has returned.
+		if err := s.Serve(ln); err != nil {
+			log("http.Serve: %s", err)
+		}
+	}()
+	return ln, nil
+}
+
+// main parses flags, performs the pluggable-transport server handshake with
+// tor, and opens one WebSocket listener per bind address tor requested. It
+// then counts running handlers until SIGINT or SIGTERM arrives and closes the
+// listeners. On SIGTERM it exits at once. On SIGINT it waits for running
+// handlers to finish, unless a second signal arrives.
+func main() {
+	var (
+		logFilename string
+		port        int
+	)
+	flag.Usage = usage
+	flag.StringVar(&logFilename, "log", "", "log file to write to")
+	flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor")
+	flag.Parse()
+
+	if logFilename != "" {
+		f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error())
+			os.Exit(1)
+		}
+		logFile = f
+	}
+
+	log("starting")
+	ptInfo = pt.ServerSetup([]string{ptMethodName})
+
+	var listeners []*net.TCPListener
+	for _, bindAddr := range ptInfo.BindAddrs {
+		// A nonzero --port overrides the port tor requested (tor requests 0
+		// when this transport has not been run before).
+		if port != 0 {
+			bindAddr.Addr.Port = port
+		}
+		ln, err := startListener(bindAddr.Addr)
+		if err != nil {
+			pt.SmethodError(bindAddr.MethodName, err.Error())
+			continue
+		}
+		boundAddr := ln.Addr()
+		pt.Smethod(bindAddr.MethodName, boundAddr)
+		log("listening on %s", boundAddr.String())
+		listeners = append(listeners, ln)
+	}
+	pt.SmethodsDone()
+
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
+	// Phase 1: track running handlers until the first signal.
+	numHandlers := 0
+	var sig os.Signal
+	for sig == nil {
+		select {
+		case delta := <-handlerChan:
+			numHandlers += delta
+		case sig = <-sigChan:
+		}
+	}
+	log("Got first signal %q with %d running handlers.", sig, numHandlers)
+	for _, ln := range listeners {
+		ln.Close()
+	}
+	if sig == syscall.SIGTERM {
+		log("Caught signal %q, exiting.", sig)
+		return
+	}
+
+	// Phase 2 (SIGINT only): drain running handlers, or stop on a second
+	// signal.
+	sig = nil
+	for numHandlers != 0 && sig == nil {
+		select {
+		case delta := <-handlerChan:
+			numHandlers += delta
+			log("%d remaining handlers.", numHandlers)
+		case sig = <-sigChan:
+		}
+	}
+	if sig != nil {
+		log("Got second signal %q with %d running handlers.", sig, numHandlers)
+	}
+}
diff --git a/websocket/websocket.go b/websocket/websocket.go
new file mode 100644
index 0000000..dc228d1
--- /dev/null
+++ b/websocket/websocket.go
@@ -0,0 +1,431 @@
+// Package websocket is a WebSocket library. Only the RFC 6455 variety of
+// WebSocket is supported.
+//
+// Reading and writing are strictly per-frame (or per-message). There is no way
+// to partially read a frame. Config.MaxMessageSize affords control of the
+// maximum buffering of messages.
+//
+// The reason for using this custom implementation instead of
+// code.google.com/p/go.net/websocket is that the latter has problems with long
+// messages and does not support server subprotocols.
+// "Denial of Service Protection in Go HTTP Servers"
+// https://code.google.com/p/go/issues/detail?id=2093
+// "go.websocket: Read/Copy fail with long frames"
+// https://code.google.com/p/go/issues/detail?id=2134
+// http://golang.org/pkg/net/textproto/#pkg-bugs
+// "To let callers manage exposure to denial of service attacks, Reader should
+// allow them to set and reset a limit on the number of bytes read from the
+// connection."
+// "websocket.Dial doesn't limit response header length as http.Get does"
+// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI
+//
+// Example usage:
+//
+//	func doSomething(ws *websocket.WebSocket) {
+//	}
+//	var config websocket.Config
+//	config.Subprotocols = []string{"base64"}
+//	config.MaxMessageSize = 2500
+//	http.Handle("/", config.Handler(doSomething))
+//	err := http.ListenAndServe(":8080", nil)
+package websocket
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/rand"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+)
+
+// Config holds settings for potential WebSocket connections.
+//
+// Subprotocols is a list of supported subprotocols as in RFC 6455 section
+// 1.9. When answering client requests, the first of the client's requested
+// subprotocols that is also in this list (if any) is meant to be used as the
+// subprotocol for the connection.
+//
+// MaxMessageSize is a limit, in bytes of payload, on buffering messages. It is
+// copied into each WebSocket created by Handler; zero selects the WebSocket's
+// built-in default.
+type Config struct {
+ Subprotocols []string
+ MaxMessageSize int
+}
+
+// Frame is the representation of a single WebSocket frame. Fin is the FIN bit
+// and Opcode the low four bits of the frame's first byte, as decoded by
+// ReadFrame. The Payload is always without masking.
+type Frame struct {
+ Fin bool
+ Opcode byte
+ Payload []byte
+}
+
+// IsControl reports whether the frame is a control frame, which is the case
+// exactly when the 0x08 bit of its opcode is set (opcodes 0x8 through 0xF for
+// a 4-bit opcode).
+func (frame *Frame) IsControl() bool {
+	return frame.Opcode&0x08 == 0x08
+}
+
+// Message is the representation of a WebSocket message as returned by
+// ReadMessage: either a single control frame, or a data message reassembled
+// from one or more fragmented frames. Opcode is the opcode of the (first)
+// frame. The Payload is always without masking.
+type Message struct {
+ Opcode byte
+ Payload []byte
+}
+
+// WebSocket is a WebSocket connection after hijacking it from HTTP.
+type WebSocket struct {
+ // Conn and ReadWriter from http.ResponseWriter.Hijack. All frame I/O goes
+ // through Bufrw.
+ Conn net.Conn
+ Bufrw *bufio.ReadWriter
+ // Whether we are a client or a server has implications for masking: a
+ // client sends masked frames and must receive unmasked ones; a server
+ // sends unmasked frames and must receive masked ones.
+ IsClient bool
+ // Set from a parent Config. Limit in bytes on frame and message payloads;
+ // zero selects the default returned by maxMessageSize.
+ MaxMessageSize int
+ // The single selected subprotocol after negotiation, or "".
+ Subprotocol string
+ // Buffer used by ReadMessage to reassemble fragmented message payloads,
+ // which may be interrupted by control messages.
+ messageBuf bytes.Buffer
+}
+
+// applyMask XORs payload in place with the 4-byte maskKey repeated over its
+// whole length (RFC 6455 section 5.3). Because XOR is its own inverse, the
+// same call both masks and unmasks.
+func applyMask(payload []byte, maskKey [4]byte) {
+	for idx := range payload {
+		payload[idx] ^= maskKey[idx&3]
+	}
+}
+
+// maxMessageSize returns the effective payload size limit for this
+// connection: ws.MaxMessageSize when it is set, or 64000 when it is zero.
+func (ws *WebSocket) maxMessageSize() int {
+	if limit := ws.MaxMessageSize; limit != 0 {
+		return limit
+	}
+	return 64000
+}
+
+// ReadFrame reads a single frame from the WebSocket.
+//
+// The returned Payload is always unmasked. An error is returned if:
+//   - the frame sets any reserved RSV1-RSV3 bit (no extensions are ever
+//     negotiated by this package);
+//   - its payload length exceeds the connection's maximum message size
+//     (checked before the payload buffer is allocated);
+//   - its masking does not match the connection's role. A server must
+//     receive masked frames and a client must receive unmasked ones.
+func (ws *WebSocket) ReadFrame() (frame Frame, err error) {
+	var b byte
+	b, err = ws.Bufrw.ReadByte()
+	if err != nil {
+		return
+	}
+	// RFC 6455 section 5.2: a nonzero RSV bit whose meaning no negotiated
+	// extension defines must fail the connection.
+	if rsv := b & 0x70; rsv != 0 {
+		err = fmt.Errorf("frame has reserved bits 0x%02x set", rsv)
+		return
+	}
+	frame.Fin = (b & 0x80) != 0
+	frame.Opcode = b & 0x0f
+
+	b, err = ws.Bufrw.ReadByte()
+	if err != nil {
+		return
+	}
+	masked := (b & 0x80) != 0
+
+	// A 7-bit length of 126 or 127 means the real length follows as a
+	// big-endian 16-bit or 64-bit integer, respectively.
+	payloadLen := uint64(b & 0x7f)
+	switch payloadLen {
+	case 126:
+		var short uint16
+		if err = binary.Read(ws.Bufrw, binary.BigEndian, &short); err != nil {
+			return
+		}
+		payloadLen = uint64(short)
+	case 127:
+		var long uint64
+		if err = binary.Read(ws.Bufrw, binary.BigEndian, &long); err != nil {
+			return
+		}
+		payloadLen = long
+	}
+	// Report the effective limit (which includes the default applied when
+	// MaxMessageSize is zero), not the raw field.
+	limit := ws.maxMessageSize()
+	if payloadLen > uint64(limit) {
+		err = fmt.Errorf("frame payload length of %d exceeds maximum of %d", payloadLen, limit)
+		return
+	}
+
+	var maskKey [4]byte
+	if masked {
+		if ws.IsClient {
+			err = errors.New("client got masked frame")
+			return
+		}
+		if _, err = io.ReadFull(ws.Bufrw, maskKey[:]); err != nil {
+			return
+		}
+	} else if !ws.IsClient {
+		err = errors.New("server got unmasked frame")
+		return
+	}
+
+	frame.Payload = make([]byte, payloadLen)
+	if _, err = io.ReadFull(ws.Bufrw, frame.Payload); err != nil {
+		return
+	}
+	if masked {
+		applyMask(frame.Payload, maskKey)
+	}
+
+	return frame, nil
+}
+
+// ReadMessage reads a single message from the WebSocket. Multiple fragmented
+// frames are combined into a single message before being returned. A control
+// frame that arrives between the fragments of a data message is returned
+// immediately as its own message.
+//
+// The total payload of a reassembled message is limited to the effective
+// maximum message size (ws.MaxMessageSize, or the default when it is zero).
+// The returned Payload of a data message is a fresh copy that stays valid
+// across later reads. On error, any partially assembled data is discarded.
+//
+// NOTE(review): the opcode of a data message that is interrupted by a control
+// frame is kept only in a local variable. The continuation frame read by the
+// next call is therefore rejected with "first frame has opcode 0". Fully
+// supporting interleaved control frames requires storing that opcode on
+// WebSocket.
+func (ws *WebSocket) ReadMessage() (message Message, err error) {
+	// fail discards any partially assembled message, so that stale fragments
+	// cannot leak into a later message, and returns e.
+	fail := func(e error) (Message, error) {
+		ws.messageBuf.Reset()
+		return Message{}, e
+	}
+
+	limit := ws.maxMessageSize()
+	var opcode byte
+	for {
+		frame, readErr := ws.ReadFrame()
+		if readErr != nil {
+			return fail(readErr)
+		}
+		if frame.IsControl() {
+			// RFC 6455 section 5.5: control frames must not be fragmented
+			// and must have a payload of at most 125 bytes.
+			if !frame.Fin {
+				return fail(errors.New("control frame has fin bit unset"))
+			}
+			if len(frame.Payload) > 125 {
+				return fail(fmt.Errorf("control frame payload length of %d exceeds maximum of 125", len(frame.Payload)))
+			}
+			// A partially assembled data message, if any, stays in
+			// messageBuf.
+			return Message{Opcode: frame.Opcode, Payload: frame.Payload}, nil
+		}
+
+		if opcode == 0 {
+			if frame.Opcode == 0 {
+				return fail(errors.New("first frame has opcode 0"))
+			}
+			opcode = frame.Opcode
+		} else if frame.Opcode != 0 {
+			return fail(fmt.Errorf("non-first frame has nonzero opcode %d", frame.Opcode))
+		}
+		// Use the effective limit; the raw MaxMessageSize field is zero
+		// when the default applies.
+		if total := ws.messageBuf.Len() + len(frame.Payload); total > limit {
+			return fail(fmt.Errorf("message payload length of %d exceeds maximum of %d", total, limit))
+		}
+		ws.messageBuf.Write(frame.Payload)
+		if frame.Fin {
+			break
+		}
+	}
+
+	// Copy out of messageBuf: Reset keeps its storage for reuse, so a slice
+	// of its bytes would be overwritten by the next message.
+	payload := make([]byte, ws.messageBuf.Len())
+	copy(payload, ws.messageBuf.Bytes())
+	ws.messageBuf.Reset()
+
+	return Message{Opcode: opcode, Payload: payload}, nil
+}
+
+// WriteFrame writes a single frame to the WebSocket stream and flushes it.
+// Frames are always unfragmented (FIN bit set). If ws.IsClient, payload is
+// destructively masked in place with a fresh random key. Any error from
+// generating the key, writing, or flushing is returned.
+func (ws *WebSocket) WriteFrame(opcode byte, payload []byte) (err error) {
+	if opcode >= 16 {
+		return fmt.Errorf("opcode %d is >= 16", opcode)
+	}
+
+	// Obtain the mask key before anything is written, so that a failure here
+	// cannot leave a partial frame header in the output buffer.
+	var maskBit byte
+	var maskKey [4]byte
+	if ws.IsClient {
+		if _, err = io.ReadFull(rand.Reader, maskKey[:]); err != nil {
+			return err
+		}
+		maskBit = 0x80
+	}
+
+	// Header: 2 fixed bytes, up to 8 extended-length bytes, up to 4 mask bytes.
+	var header [14]byte
+	header[0] = 0x80 | opcode
+	n := 2
+	switch length := len(payload); {
+	case length < 126:
+		header[1] = maskBit | byte(length)
+	case length <= 0xffff:
+		header[1] = maskBit | 126
+		binary.BigEndian.PutUint16(header[n:], uint16(length))
+		n += 2
+	default:
+		header[1] = maskBit | 127
+		binary.BigEndian.PutUint64(header[n:], uint64(length))
+		n += 8
+	}
+	if ws.IsClient {
+		n += copy(header[n:], maskKey[:])
+		applyMask(payload, maskKey)
+	}
+
+	if _, err = ws.Bufrw.Write(header[:n]); err != nil {
+		return err
+	}
+	if _, err = ws.Bufrw.Write(payload); err != nil {
+		return err
+	}
+	// bufio.Writer reports most I/O errors only when flushing, so this
+	// result must not be discarded.
+	return ws.Bufrw.Flush()
+}
+
+// WriteMessage writes one complete message to the WebSocket stream. Every
+// message is sent as a single unfragmented frame, so this simply delegates to
+// WriteFrame, which masks payload in place when ws.IsClient.
+func (ws *WebSocket) WriteMessage(opcode byte, payload []byte) (err error) {
+	err = ws.WriteFrame(opcode, payload)
+	return err
+}
+
+// Split a string on commas and trim whitespace.
+func commaSplit(s string) []string {
+ var result []string
+ if strings.TrimSpace(s) == "" {
+ return result
+ }
+ for _, e := range strings.Split(s, ",") {
+ result = append(result, strings.TrimSpace(e))
+ }
+ return result
+}
+
+// containsCase reports whether one of the strings in haystack equals needle,
+// ignoring case. It uses strings.EqualFold, which gives the same answer as
+// comparing lowercased strings for ASCII input. Unlike the previous
+// strings.ToLower approach, it neither allocates nor recomputes the lowered
+// needle on every iteration.
+func containsCase(haystack []string, needle string) bool {
+	for _, e := range haystack {
+		if strings.EqualFold(e, needle) {
+			return true
+		}
+	}
+	return false
+}
+
+// sha1Hash returns the 20-byte SHA-1 digest of the bytes of data.
+func sha1Hash(data string) []byte {
+	digest := sha1.Sum([]byte(data))
+	return digest[:]
+}
+
+// httpError writes a bodiless HTTP/1.0 error response with status code to
+// bufrw (the hijacked connection's buffered stream) and flushes it. The
+// response carries "Connection: close" plus whatever headers are already set
+// on w.
+func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) {
+	header := w.Header()
+	header.Set("Connection", "close")
+	fmt.Fprintf(bufrw, "HTTP/1.0 %d %s\r\n", code, http.StatusText(code))
+	header.Write(bufrw)
+	bufrw.WriteString("\r\n")
+	bufrw.Flush()
+}
+
+// HTTPHandler is an implementation of http.Handler with a Config. Its
+// ServeHTTP method performs the WebSocket opening handshake on the hijacked
+// connection and calls Callback with the new WebSocket if the handshake is
+// successful. The connection is closed when Callback returns.
+type HTTPHandler struct {
+ Config *Config
+ Callback func(*WebSocket)
+}
+
+// ServeHTTP implements the http.Handler interface. It hijacks the connection,
+// checks the client's opening handshake (RFC 6455 section 4.2.1), sends the
+// server's handshake response (section 4.2.2), and then calls
+// handler.Callback with the new WebSocket. The connection is closed when the
+// callback returns. If the handshake is invalid, an HTTP error response is
+// written instead and the callback is not called. A nil handler.Config is
+// treated like a zero Config.
+func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	// Not every ResponseWriter supports hijacking; a bare type assertion
+	// would panic in that case.
+	hijacker, ok := w.(http.Hijacker)
+	if !ok {
+		http.Error(w, "connection does not support hijacking", http.StatusInternalServerError)
+		return
+	}
+	conn, bufrw, err := hijacker.Hijack()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	defer conn.Close()
+
+	var config Config
+	if handler.Config != nil {
+		config = *handler.Config
+	}
+
+	// See RFC 6455 section 4.2.1 for this sequence of checks.
+
+	// 1. An HTTP/1.1 or higher GET request, including a "Request-URI"...
+	if req.Method != "GET" {
+		httpError(w, bufrw, http.StatusMethodNotAllowed)
+		return
+	}
+	if req.URL.Path != "/" {
+		httpError(w, bufrw, http.StatusNotFound)
+		return
+	}
+	// 2. A |Host| header field containing the server's authority.
+	// We deliberately skip this test.
+	// 3. An |Upgrade| header field containing the value "websocket",
+	// treated as an ASCII case-insensitive value.
+	if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") {
+		httpError(w, bufrw, http.StatusBadRequest)
+		return
+	}
+	// 4. A |Connection| header field that includes the token "Upgrade",
+	// treated as an ASCII case-insensitive value.
+	if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") {
+		httpError(w, bufrw, http.StatusBadRequest)
+		return
+	}
+	// 5. A |Sec-WebSocket-Key| header field with a base64-encoded value
+	// that, when decoded, is 16 bytes in length.
+	websocketKey := req.Header.Get("Sec-WebSocket-Key")
+	key, err := base64.StdEncoding.DecodeString(websocketKey)
+	if err != nil || len(key) != 16 {
+		httpError(w, bufrw, http.StatusBadRequest)
+		return
+	}
+	// 6. A |Sec-WebSocket-Version| header field, with a value of 13.
+	// We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10.
+	knownVersions := []string{"8", "13"}
+	websocketVersion := req.Header.Get("Sec-WebSocket-Version")
+	if !containsCase(knownVersions, websocketVersion) {
+		// "If this version does not match a version understood by the
+		// server, the server MUST abort the WebSocket handshake
+		// described in this section and instead send an appropriate
+		// HTTP error code (such as 426 Upgrade Required) and a
+		// |Sec-WebSocket-Version| header field indicating the
+		// version(s) the server is capable of understanding."
+		w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", "))
+		httpError(w, bufrw, 426)
+		return
+	}
+	// 7. Optionally, an |Origin| header field.
+	// 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of
+	// values indicating which protocols the client would like to speak, ordered
+	// by preference.
+	clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol"))
+	// 9. Optionally, a |Sec-WebSocket-Extensions| header field...
+	// 10. Optionally, other header fields...
+
+	var ws WebSocket
+	ws.Conn = conn
+	ws.Bufrw = bufrw
+	ws.IsClient = false
+	ws.MaxMessageSize = config.MaxMessageSize
+
+	// See RFC 6455 section 4.2.2, item 5 for these steps.
+
+	// 1. A Status-Line with a 101 response code as per RFC 2616. 101
+	// Switching Protocols is an HTTP/1.1 status, so the response announces
+	// HTTP/1.1.
+	fmt.Fprintf(bufrw, "HTTP/1.1 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
+	// 2. An |Upgrade| header field with value "websocket" as per RFC 2616.
+	w.Header().Set("Upgrade", "websocket")
+	// 3. A |Connection| header field with value "Upgrade".
+	w.Header().Set("Connection", "Upgrade")
+	// 4. A |Sec-WebSocket-Accept| header field. The value of this header
+	// field is constructed by concatenating /key/, defined above in step 4
+	// in Section 4.2.2, with the string
+	// "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
+	// concatenated value to obtain a 20-byte value and base64-encoding (see
+	// Section 4 of [RFC4648]) this 20-byte hash.
+	const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+	acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID))
+	w.Header().Set("Sec-WebSocket-Accept", acceptKey)
+	// 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value
+	// /subprotocol/ as defined in step 4 in Section 4.2.2.
+	if proto, found := selectSubprotocol(clientProtocols, config.Subprotocols); found {
+		ws.Subprotocol = proto
+		w.Header().Set("Sec-WebSocket-Protocol", proto)
+	}
+	// 6. Optionally, a |Sec-WebSocket-Extensions| header field...
+	w.Header().Write(bufrw)
+	bufrw.WriteString("\r\n")
+	// bufio.Writer reports write errors at Flush time. If the handshake
+	// response did not go out, there is no WebSocket to hand over.
+	if err = bufrw.Flush(); err != nil {
+		return
+	}
+
+	// Call the WebSocket-specific handler.
+	handler.Callback(&ws)
+}
+
+// selectSubprotocol returns the first entry of clientProtocols (listed in the
+// client's order of preference) that also appears in serverProtocols. found
+// is false if there is no common entry.
+func selectSubprotocol(clientProtocols, serverProtocols []string) (protocol string, found bool) {
+	for _, clientProto := range clientProtocols {
+		for _, serverProto := range serverProtocols {
+			if clientProto == serverProto {
+				return clientProto, true
+			}
+		}
+	}
+	return "", false
+}
+
+// Handler returns an http.Handler that performs the WebSocket handshake
+// according to config and then passes each new connection to callback.
+func (config *Config) Handler(callback func(*WebSocket)) http.Handler {
+	handler := &HTTPHandler{Config: config, Callback: callback}
+	return handler
+}
More information about the tor-commits
mailing list