[tor-commits] [snowflake/master] Import Turbo Tunnel support code.

dcf at torproject.org dcf at torproject.org
Thu Apr 23 22:43:24 UTC 2020


commit 222ab3d85a4113088db3e3b742411806922c028c
Author: David Fifield <david at bamsoftware.com>
Date:   Tue Jan 28 02:29:34 2020 -0700

    Import Turbo Tunnel support code.
    
    Copied and slightly modified from
    https://gitweb.torproject.org/pluggable-transports/meek.git/log/?h=turbotunnel&id=7eb94209f857fc71c2155907b0462cc587fc76cc
    https://github.com/net4people/bbs/issues/21
    
    RedialPacketConn is adapted from clientPacketConn in
    https://dip.torproject.org/dcf/obfs4/blob/c64a61c6da3bf1c2f98221bb1e1af8a358f22b87/obfs4proxy/turbotunnel_client.go
    https://github.com/net4people/bbs/issues/14#issuecomment-544747519
---
 common/encapsulation/encapsulation.go      | 194 +++++++++++++++++
 common/encapsulation/encapsulation_test.go | 330 +++++++++++++++++++++++++++++
 common/turbotunnel/clientid.go             |  28 +++
 common/turbotunnel/clientmap.go            | 144 +++++++++++++
 common/turbotunnel/consts.go               |  13 ++
 common/turbotunnel/queuepacketconn.go      | 137 ++++++++++++
 common/turbotunnel/redialpacketconn.go     | 204 ++++++++++++++++++
 7 files changed, 1050 insertions(+)

diff --git a/common/encapsulation/encapsulation.go b/common/encapsulation/encapsulation.go
new file mode 100644
index 0000000..bfe9b5b
--- /dev/null
+++ b/common/encapsulation/encapsulation.go
@@ -0,0 +1,194 @@
+// Package encapsulation implements a way of encoding variable-size chunks of
+// data and padding into a byte stream.
+//
+// Each chunk of data or padding starts with a variable-size length prefix. One
+// bit ("d") in the first byte of the prefix indicates whether the chunk
+// represents data or padding (1=data, 0=padding). Another bit ("c" for
+// "continuation") indicates whether there are more bytes in the length
+// prefix. The remaining 6 bits ("x") encode part of the length value.
+// 	dcxxxxxx
+// If the continuation bit is set, then the next byte is also part of the length
+// prefix. It lacks the "d" bit, has its own "c" bit, and 7 value-carrying bits
+// ("y").
+// 	cyyyyyyy
+// The length is decoded by concatenating the value-carrying bits of all
+// prefix bytes, from left to right, up to and including the first byte whose
+// "c" bit is 0. Although in principle this encoding would allow for length
+// prefixes of any size, length prefixes are arbitrarily limited to 3 bytes and
+// any attempt to read or write a longer one is an error. These are therefore
+// the only valid formats:
+// 	00xxxxxx			xxxxxx₂ bytes of padding
+// 	10xxxxxx			xxxxxx₂ bytes of data
+// 	01xxxxxx 0yyyyyyy		xxxxxxyyyyyyy₂ bytes of padding
+// 	11xxxxxx 0yyyyyyy		xxxxxxyyyyyyy₂ bytes of data
+// 	01xxxxxx 1yyyyyyy 0zzzzzzz	xxxxxxyyyyyyyzzzzzzz₂ bytes of padding
+// 	11xxxxxx 1yyyyyyy 0zzzzzzz	xxxxxxyyyyyyyzzzzzzz₂ bytes of data
+// The maximum encodable length is 11111111111111111111₂ = 0xfffff = 1048575.
+// There is no requirement to use a length prefix of minimum size; i.e. 00000100
+// and 01000000 00000100 are both valid encodings of the value 4.
+//
+// After the length prefix follow that many bytes of padding or data. There are
+// no restrictions on the value of bytes comprising padding.
+//
+// The idea for this encapsulation is sketched here:
+// https://github.com/net4people/bbs/issues/9#issuecomment-524095186
+package encapsulation
+
+import (
+	"errors"
+	"io"
+	"io/ioutil"
+)
+
+// ErrTooLong is the error returned when an encoded length prefix is longer than
+// 3 bytes (by ReadData), or when WriteData receives an input whose length is
+// too large to encode in a 3-byte length prefix.
+var ErrTooLong = errors.New("length prefix is too long")
+
+// ReadData returns a new slice with the contents of the next available data
+// chunk, skipping over any padding chunks that may come first. The returned
+// error value is nil if and only if a data chunk was present and was read in
+// its entirety. The returned error is io.EOF only if r ended before the first
+// byte of a length prefix. If r ended in the middle of a length prefix or
+// data/padding, the returned error is io.ErrUnexpectedEOF. A length prefix
+// longer than 3 bytes yields ErrTooLong.
+func ReadData(r io.Reader) ([]byte, error) {
+	for {
+		// Single bytes are read with io.ReadFull rather than r.Read: a
+		// Reader may legally return the final byte together with io.EOF,
+		// or return (0, nil). io.ReadFull copes with both, and reports
+		// io.EOF only when no byte at all could be read.
+		var b [1]byte
+		_, err := io.ReadFull(r, b[:])
+		if err != nil {
+			// This is the only place we may return a real io.EOF.
+			return nil, err
+		}
+		isData := (b[0] & 0x80) != 0
+		moreLength := (b[0] & 0x40) != 0
+		n := int(b[0] & 0x3f)
+		for i := 0; moreLength; i++ {
+			// At most two continuation bytes (3-byte prefix in total).
+			if i >= 2 {
+				return nil, ErrTooLong
+			}
+			_, err = io.ReadFull(r, b[:])
+			if err == io.EOF {
+				err = io.ErrUnexpectedEOF
+			}
+			if err != nil {
+				return nil, err
+			}
+			moreLength = (b[0] & 0x80) != 0
+			n = (n << 7) | int(b[0]&0x7f)
+		}
+		if !isData {
+			// Padding: discard its contents and look for the next chunk.
+			_, err = io.CopyN(ioutil.Discard, r, int64(n))
+			if err == io.EOF {
+				err = io.ErrUnexpectedEOF
+			}
+			if err != nil {
+				return nil, err
+			}
+			continue
+		}
+		p := make([]byte, n)
+		_, err = io.ReadFull(r, p)
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
+		if err != nil {
+			return nil, err
+		}
+		return p, nil
+	}
+}
+
+// dataPrefixForLength builds the shortest length prefix (1, 2, or 3 bytes)
+// that encodes n, with the "d" bit set to mark a data chunk. A negative n, or
+// one that needs more than 20 value bits, yields ErrTooLong.
+func dataPrefixForLength(n int) ([]byte, error) {
+	switch {
+	case n>>6 == 0:
+		// 10xxxxxx
+		return []byte{0x80 | byte(n)}, nil
+	case n>>13 == 0:
+		// 11xxxxxx 0yyyyyyy
+		return []byte{0xc0 | byte(n>>7), byte(n & 0x7f)}, nil
+	case n>>20 == 0:
+		// 11xxxxxx 1yyyyyyy 0zzzzzzz
+		return []byte{0xc0 | byte(n>>14), 0x80 | byte((n>>7)&0x7f), byte(n & 0x7f)}, nil
+	}
+	return nil, ErrTooLong
+}
+
+// WriteData encodes a data chunk into w. It returns the total number of bytes
+// written; i.e., including the length prefix. The error is ErrTooLong if the
+// length of data cannot fit into a length prefix.
+func WriteData(w io.Writer, data []byte) (int, error) {
+	prefix, err := dataPrefixForLength(len(data))
+	if err != nil {
+		return 0, err
+	}
+	total := 0
+	n, err := w.Write(prefix)
+	total += n
+	if err != nil {
+		return total, err
+	}
+	n, err = w.Write(data)
+	total += n
+	return total, err
+}
+
+// paddingBuffer supplies the padding bytes written by WritePadding. Its length
+// (1 KiB) is also the maximum size, prefix included, of each padding chunk.
+var paddingBuffer = make([]byte, 1<<10)
+
+// WritePadding encodes padding chunks, whose total size (including their own
+// length prefixes) is n. Returns the total number of bytes written to w, which
+// will be exactly n unless there was an error. The error cannot be ErrTooLong
+// because this function will write multiple padding chunks if necessary to
+// reach the requested size. Panics if n is negative.
+func WritePadding(w io.Writer, n int) (int, error) {
+	if n < 0 {
+		panic("negative length")
+	}
+	total := 0
+	for n > 0 {
+		// Each chunk, prefix included, is at most len(paddingBuffer)
+		// bytes long.
+		p := len(paddingBuffer)
+		if p > n {
+			p = n
+		}
+		n -= p
+		// Pick the shortest prefix that can encode what remains of p
+		// after subtracting the prefix's own size.
+		var prefix []byte
+		switch {
+		case ((p-1)>>0)&0x3f == ((p - 1) >> 0):
+			p = p - 1
+			prefix = []byte{byte((p >> 0) & 0x3f)}
+		case ((p-2)>>7)&0x3f == ((p - 2) >> 7):
+			p = p - 2
+			prefix = []byte{0x40 | byte((p>>7)&0x3f), byte((p >> 0) & 0x7f)}
+		case ((p-3)>>14)&0x3f == ((p - 3) >> 14):
+			p = p - 3
+			// The middle byte carries 7 value bits, so it must be masked
+			// with 0x7f (not 0x3f), or bit 13 of the length is lost.
+			prefix = []byte{0x40 | byte((p>>14)&0x3f), 0x80 | byte((p>>7)&0x7f), byte((p >> 0) & 0x7f)}
+		}
+		nn, err := w.Write(prefix)
+		total += nn
+		if err != nil {
+			return total, err
+		}
+		nn, err = w.Write(paddingBuffer[:p])
+		total += nn
+		if err != nil {
+			return total, err
+		}
+	}
+	return total, nil
+}
+
+// MaxDataForSize returns a data length that can be passed to WriteData such
+// that the total encoded size (including length prefix) is no larger than n.
+// Call this to find out if a chunk of data will fit into a length budget.
+//
+// The result is conservative: it is n minus the size of the prefix that would
+// encode n itself, so just past a prefix-size boundary (e.g. n == 0x40) it can
+// be one byte smaller than the true maximum. Values of n too large to encode
+// in a 3-byte prefix yield 0xffffc.
+//
+// Panics if n <= 0.
+func MaxDataForSize(n int) int {
+	if n == 0 {
+		panic("zero length")
+	}
+	if n < 0 {
+		// Without this check dataPrefixForLength reports ErrTooLong for
+		// negative n and we would return the maximum size instead.
+		panic("negative length")
+	}
+	prefix, err := dataPrefixForLength(n)
+	switch {
+	case err == ErrTooLong:
+		// The largest 20-bit length, minus a 3-byte prefix.
+		return (1 << (6 + 7 + 7)) - 1 - 3
+	case err != nil:
+		panic(err)
+	}
+	return n - len(prefix)
+}
diff --git a/common/encapsulation/encapsulation_test.go b/common/encapsulation/encapsulation_test.go
new file mode 100644
index 0000000..333abb4
--- /dev/null
+++ b/common/encapsulation/encapsulation_test.go
@@ -0,0 +1,330 @@
+package encapsulation
+
+import (
+	"bytes"
+	"io"
+	"math/rand"
+	"testing"
+)
+
+// pseudorandomBuffer returns an n-byte slice filled from a math/rand source
+// seeded with 0, giving non-trivial but reproducible contents.
+func pseudorandomBuffer(n int) []byte {
+	src := rand.NewSource(0)
+	buf := make([]byte, n)
+	for i := range buf {
+		// Keep only the low 8 bits of each 63-bit value.
+		buf[i] = byte(src.Int63())
+	}
+	return buf
+}
+
+// mustWriteData calls WriteData and panics if it fails; otherwise it returns
+// the number of bytes written.
+func mustWriteData(w io.Writer, p []byte) int {
+	written, err := WriteData(w, p)
+	if err == nil {
+		return written
+	}
+	panic(err)
+}
+
+// mustWritePadding calls WritePadding and panics if it fails; otherwise it
+// returns the number of bytes written.
+func mustWritePadding(w io.Writer, n int) int {
+	written, err := WritePadding(w, n)
+	if err == nil {
+		return written
+	}
+	panic(err)
+}
+
+// TestRoundtrip checks that ReadData recovers exactly what WriteData encoded,
+// for sizes on either side of each length-prefix boundary.
+func TestRoundtrip(t *testing.T) {
+	sizes := []int{
+		0x00, 0x01,
+		0x3e, 0x3f, 0x40, 0x41,
+		0xfe, 0xff, 0x100, 0x101,
+		0x1ffe, 0x1fff, 0x2000, 0x2001,
+		0xfffe, 0xffff, 0x10000, 0x10001,
+		0xffffe, 0xfffff,
+	}
+	for _, size := range sizes {
+		want := pseudorandomBuffer(size)
+		var buf bytes.Buffer
+		written, err := WriteData(&buf, want)
+		if err != nil {
+			t.Fatalf("size %d, WriteData returned error %v", size, err)
+		}
+		if buffered := buf.Len(); buffered != written {
+			t.Fatalf("size %d, returned length was %d, written length was %d",
+				size, written, buffered)
+		}
+		decoded, err := ReadData(&buf)
+		if err != nil {
+			t.Fatalf("size %d, ReadData returned error %v", size, err)
+		}
+		if !bytes.Equal(decoded, want) {
+			t.Fatalf("size %d, got <%x>, expected <%x>", size, decoded, want)
+		}
+	}
+}
+
+// TestPaddingLength checks that WritePadding reports, and actually writes,
+// exactly the number of bytes requested.
+func TestPaddingLength(t *testing.T) {
+	// Sizes straddle each prefix-size boundary. The last group is above
+	// 0xfffff, the largest length a single prefix can encode.
+	sizes := []int{
+		0x00, 0x01,
+		0x3f, 0x40, 0x41, 0x42,
+		0xff, 0x100, 0x101, 0x102,
+		0x2000, 0x2001, 0x2002, 0x2003,
+		0x10000, 0x10001, 0x10002, 0x10003,
+		0x100001, 0x100002, 0x100003, 0x100004,
+	}
+	for _, want := range sizes {
+		var enc bytes.Buffer
+		got, err := WritePadding(&enc, want)
+		switch {
+		case err != nil:
+			t.Fatalf("size %d, WritePadding returned error %v", want, err)
+		case got != want:
+			t.Fatalf("requested %d bytes, returned %d", want, got)
+		case enc.Len() != got:
+			t.Fatalf("requested %d bytes, wrote %d bytes", want, enc.Len())
+		}
+	}
+}
+
+// TestSkipPadding interleaves padding with data chunks (including empty ones)
+// and checks that ReadData returns only the data, then io.EOF.
+func TestSkipPadding(t *testing.T) {
+	var data = [][]byte{{}, {}, []byte("hello"), {}, []byte("world")}
+	var enc bytes.Buffer
+	pad := func(sizes ...int) {
+		for _, n := range sizes {
+			mustWritePadding(&enc, n)
+		}
+	}
+	put := func(chunks ...[]byte) {
+		for _, c := range chunks {
+			mustWriteData(&enc, c)
+		}
+	}
+	pad(10, 100)
+	put(data[0], data[1])
+	pad(10)
+	put(data[2], data[3])
+	pad(10)
+	put(data[4])
+	pad(10, 10)
+
+	for idx, want := range data {
+		got, err := ReadData(&enc)
+		if err != nil {
+			t.Fatalf("slice %d, got error %v, expected %v", idx, err, nil)
+		}
+		if !bytes.Equal(got, want) {
+			t.Fatalf("slice %d, got <%x>, expected <%x>", idx, got, want)
+		}
+	}
+	// Only trailing padding remains, so the next read must hit a clean EOF.
+	if p, err := ReadData(&enc); p != nil || err != io.EOF {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
+	}
+}
+
+// TestEOF checks that input ending before any length prefix yields io.EOF
+// and a nil slice.
+func TestEOF(t *testing.T) {
+	empty := bytes.NewReader(nil)
+	if p, err := ReadData(empty); p != nil || err != io.EOF {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
+	}
+}
+
+// TestUnexpectedEOF checks that input truncated inside a length prefix, or
+// inside the data/padding it announces, yields io.ErrUnexpectedEOF.
+func TestUnexpectedEOF(t *testing.T) {
+	truncated := [][]byte{
+		{0x40},                  // expecting a second length byte
+		{0xc0},                  // expecting a second length byte
+		{0x41, 0x80},            // expecting a third length byte
+		{0xc1, 0x80},            // expecting a third length byte
+		{0x02},                  // expecting 2 bytes of padding
+		{0x82},                  // expecting 2 bytes of data
+		{0x02, 'X'},             // expecting 1 byte of padding
+		{0x82, 'X'},             // expecting 1 byte of data
+		{0x41, 0x00},            // expecting 128 bytes of padding
+		{0xc1, 0x00},            // expecting 128 bytes of data
+		{0x41, 0x00, 'X'},       // expecting 127 bytes of padding
+		{0xc1, 0x00, 'X'},       // expecting 127 bytes of data
+		{0x41, 0x80, 0x00},      // expecting 32768 bytes of padding
+		{0xc1, 0x80, 0x00},      // expecting 32768 bytes of data
+		{0x41, 0x80, 0x00, 'X'}, // expecting 32767 bytes of padding
+		{0xc1, 0x80, 0x00, 'X'}, // expecting 32767 bytes of data
+	}
+	for _, input := range truncated {
+		p, err := ReadData(bytes.NewReader(input))
+		if p == nil && err == io.ErrUnexpectedEOF {
+			continue
+		}
+		t.Fatalf("<%x> got (<%x>, %v), expected (%v, %v)", input, p, err, nil, io.ErrUnexpectedEOF)
+	}
+}
+
+// TestNonMinimalLengthEncoding checks that a length prefix longer than
+// necessary is still accepted and decoded correctly.
+func TestNonMinimalLengthEncoding(t *testing.T) {
+	want := []byte("X")
+	// The same one-byte payload behind 1-, 2-, and 3-byte length prefixes.
+	for _, enc := range [][]byte{
+		{0x81, 'X'},
+		{0xc0, 0x01, 'X'},
+		{0xc0, 0x80, 0x01, 'X'},
+	} {
+		p, err := ReadData(bytes.NewReader(enc))
+		if err != nil {
+			t.Fatalf("<%x> got error %v, expected %v", enc, err, nil)
+		}
+		if !bytes.Equal(p, want) {
+			t.Fatalf("<%x> got <%x>, expected <%x>", enc, p, want)
+		}
+	}
+}
+
+// Test that ReadData only reads up to 3 bytes of length prefix: the largest
+// 3-byte length (0xfffff) works for both data and padding, while any prefix
+// whose third byte still has its continuation bit set yields ErrTooLong.
+func TestReadLimits(t *testing.T) {
+	// Test the maximum length that's possible with 3 bytes of length
+	// prefix.
+	maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f
+	data := bytes.Repeat([]byte{'X'}, maxLength)
+	prefix := []byte{0xff, 0xff, 0x7f} // encodes 0xfffff
+	p, err := ReadData(bytes.NewReader(append(prefix, data...)))
+	if err != nil {
+		t.Fatalf("got error %v, expected %v", err, nil)
+	}
+	if !bytes.Equal(p, data) {
+		t.Fatalf("got %d bytes unequal to %d bytes", len(p), len(data))
+	}
+	// Test a 4-byte prefix.
+	prefix = []byte{0xc0, 0xc0, 0x80, 0x80} // encodes 0x100000
+	data = bytes.Repeat([]byte{'X'}, maxLength+1)
+	p, err = ReadData(bytes.NewReader(append(prefix, data...)))
+	if p != nil || err != ErrTooLong {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
+	}
+	// Test that 4 bytes don't work, even when they encode an integer that
+	// would fit in 3 bytes.
+	prefix = []byte{0xc0, 0x80, 0x80, 0x80} // encodes 0x0
+	data = []byte{}
+	p, err = ReadData(bytes.NewReader(append(prefix, data...)))
+	if p != nil || err != ErrTooLong {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
+	}
+
+	// Do the same tests with padding lengths. Each padding chunk is
+	// followed by a real data chunk, so a successful skip of the padding
+	// is observable as ReadData returning that data.
+	data = []byte("hello")
+	prefix = []byte{0x7f, 0xff, 0x7f} // encodes 0xfffff
+	padding := bytes.Repeat([]byte{'X'}, maxLength)
+	enc := bytes.NewBuffer(append(prefix, padding...))
+	mustWriteData(enc, data)
+	p, err = ReadData(enc)
+	if err != nil {
+		t.Fatalf("got error %v, expected %v", err, nil)
+	}
+	if !bytes.Equal(p, data) {
+		t.Fatalf("got <%x>, expected <%x>", p, data)
+	}
+	prefix = []byte{0x40, 0xc0, 0x80, 0x80} // encodes 0x100000
+	padding = bytes.Repeat([]byte{'X'}, maxLength+1)
+	enc = bytes.NewBuffer(append(prefix, padding...))
+	mustWriteData(enc, data)
+	p, err = ReadData(enc)
+	if p != nil || err != ErrTooLong {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
+	}
+	prefix = []byte{0x40, 0x80, 0x80, 0x80} // encodes 0x0
+	padding = []byte{}
+	enc = bytes.NewBuffer(append(prefix, padding...))
+	mustWriteData(enc, data)
+	p, err = ReadData(enc)
+	if p != nil || err != ErrTooLong {
+		t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
+	}
+}
+
+// Test that WriteData and WritePadding only accept lengths that can be encoded
+// in up to 3 bytes of length prefix.
+func TestWriteLimits(t *testing.T) {
+	maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f
+	var enc bytes.Buffer
+	// The largest encodable data chunk is written together with its 3-byte
+	// length prefix.
+	n, err := WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength))
+	if n != maxLength+3 || err != nil {
+		t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+3, nil)
+	}
+	enc.Reset()
+	n, err = WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength+1))
+	if n != 0 || err != ErrTooLong {
+		t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, 0, ErrTooLong)
+	}
+	// A rejected chunk must not leave a partial encoding behind.
+	if enc.Len() != 0 {
+		t.Fatalf("ErrTooLong, but %d bytes were written", enc.Len())
+	}
+
+	// Padding gets an extra 3 bytes because the prefix is counted as part
+	// of the length.
+	enc.Reset()
+	n, err = WritePadding(&enc, maxLength+3)
+	if n != maxLength+3 || err != nil {
+		t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+3, nil)
+	}
+	// Writing a too-long padding is okay because WritePadding will break it
+	// into smaller chunks.
+	enc.Reset()
+	n, err = WritePadding(&enc, maxLength+4)
+	if n != maxLength+4 || err != nil {
+		t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+4, nil)
+	}
+}
+
+// Test that WritePadding panics when given a negative length.
+func TestNegativeLength(t *testing.T) {
+	// The most negative int, computed without importing math. (The old
+	// test list used ^0, which is just -1 again.)
+	minInt := -int(^uint(0)>>1) - 1
+	for _, n := range []int{-1, -1 << 30, minInt} {
+		var enc bytes.Buffer
+		panicked, nn, err := testNegativeLengthSub(t, &enc, n)
+		if !panicked {
+			t.Fatalf("WritePadding(%d) returned (%d, %v) instead of panicking", n, nn, err)
+		}
+	}
+}
+
+// testNegativeLengthSub calls WritePadding(w, n) and reports whether the call
+// panicked, together with WritePadding's own return values when it did not.
+func testNegativeLengthSub(t *testing.T, w io.Writer, n int) (panicked bool, nn int, err error) {
+	t.Helper()
+	defer func() {
+		if r := recover(); r != nil {
+			panicked = true
+		}
+	}()
+	nn, err = WritePadding(w, n)
+	// Report the count WritePadding returned, not the requested n.
+	return false, nn, err
+}
+
+// TestMaxDataForSizeZero checks that a zero budget makes MaxDataForSize panic.
+func TestMaxDataForSizeZero(t *testing.T) {
+	panicked := func() (p bool) {
+		defer func() { p = recover() != nil }()
+		MaxDataForSize(0)
+		return false
+	}()
+	if !panicked {
+		t.Fatal("didn't panic")
+	}
+}
+
+// Test thresholds of available sizes for MaxDataForSize.
+func TestMaxDataForSize(t *testing.T) {
+	for _, test := range []struct {
+		size     int
+		expected int
+	}{
+		{0x01, 0x00},
+		{0x02, 0x01},
+		{0x3f, 0x3e},
+		{0x40, 0x3e},
+		{0x41, 0x3f},
+		{0x1fff, 0x1ffd},
+		{0x2000, 0x1ffd},
+		{0x2001, 0x1ffe},
+		{0xfffff, 0xffffc},
+		{0x100000, 0xffffc},
+		{0x100001, 0xffffc},
+		{0x7fffffff, 0xffffc},
+	} {
+		max := MaxDataForSize(test.size)
+		if max != test.expected {
+			t.Fatalf("size %d, got %d, expected %d", test.size, max, test.expected)
+		}
+	}
+}
diff --git a/common/turbotunnel/clientid.go b/common/turbotunnel/clientid.go
new file mode 100644
index 0000000..17257e1
--- /dev/null
+++ b/common/turbotunnel/clientid.go
@@ -0,0 +1,28 @@
+package turbotunnel
+
+import (
+	"crypto/rand"
+	"encoding/hex"
+)
+
+// ClientID is an abstract identifier that binds together all the communications
+// belonging to a single client session, even though those communications may
+// arrive from multiple IP addresses or over multiple lower-level connections.
+// It plays the same role that an (IP address, port number) tuple plays in a
+// net.UDPConn: it's the return address pertaining to a long-lived abstract
+// client session. The client attaches its ClientID to each of its
+// communications, enabling the server to disambiguate requests among its many
+// clients. ClientID implements the net.Addr interface.
+type ClientID [8]byte
+
+// NewClientID returns a ClientID filled with bytes from crypto/rand. It panics
+// if the random source returns an error.
+func NewClientID() ClientID {
+	var id ClientID
+	if _, err := rand.Read(id[:]); err != nil {
+		panic(err)
+	}
+	return id
+}
+
+// Network returns "clientid", the network name of every ClientID.
+func (id ClientID) Network() string {
+	return "clientid"
+}
+
+// String returns the ID as 16 lowercase hexadecimal digits.
+func (id ClientID) String() string {
+	return hex.EncodeToString(id[:])
+}
diff --git a/common/turbotunnel/clientmap.go b/common/turbotunnel/clientmap.go
new file mode 100644
index 0000000..fa12915
--- /dev/null
+++ b/common/turbotunnel/clientmap.go
@@ -0,0 +1,144 @@
+package turbotunnel
+
+import (
+	"container/heap"
+	"net"
+	"sync"
+	"time"
+)
+
+// clientRecord is a record of a recently seen client: its address (the key
+// used in clientMapInner.byAddr), the time it was last seen (the ordering key
+// of the clientMapInner.byAge heap), and its queue of outgoing packets.
+type clientRecord struct {
+	Addr      net.Addr
+	LastSeen  time.Time
+	SendQueue chan []byte
+}
+
+// ClientMap manages a mapping of live clients (keyed by address, which will be
+// a ClientID) to their respective send queues. ClientMap's methods are safe
+// to call from multiple goroutines.
+type ClientMap struct {
+	// We use an inner structure to avoid exposing public heap.Interface
+	// methods to users of ClientMap.
+	inner clientMapInner
+	// Synchronizes access to inner.
+	lock sync.Mutex
+}
+
+// NewClientMap creates a ClientMap that expires clients after a timeout.
+//
+// The timeout does not have to be kept in sync with QUIC's internal idle
+// timeout. If a client is removed from the client map while the QUIC session is
+// still live, the worst that can happen is a loss of whatever packets were in
+// the send queue at the time. If QUIC later decides to send more packets to the
+// same client, we'll instantiate a new send queue, and if the client ever
+// connects again with the proper client ID, we'll deliver them.
+//
+// Expiry is done by a background goroutine that wakes every timeout/2. It
+// loops forever; nothing ever stops it. NOTE(review): a timeout of 1ns or
+// less makes that goroutine spin without sleeping, so callers presumably
+// pass a positive, non-trivial timeout.
+func NewClientMap(timeout time.Duration) *ClientMap {
+	m := new(ClientMap)
+	m.inner.byAge = make([]*clientRecord, 0)
+	m.inner.byAddr = make(map[net.Addr]int)
+	go func() {
+		interval := timeout / 2
+		for {
+			time.Sleep(interval)
+			expireAt := time.Now()
+			m.lock.Lock()
+			m.inner.removeExpired(expireAt, timeout)
+			m.lock.Unlock()
+		}
+	}()
+	return m
+}
+
+// SendQueue returns the send queue for addr, creating a new client record if
+// none exists, and marks addr as seen now (which postpones its expiry).
+func (m *ClientMap) SendQueue(addr net.Addr) chan []byte {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+	now := time.Now()
+	return m.inner.SendQueue(addr, now)
+}
+
+// clientMapInner is the inner type of ClientMap, implementing heap.Interface.
+// byAge is the backing store, a heap ordered by LastSeen time, to facilitate
+// expiring old client records. byAddr is a map from addresses (i.e., ClientIDs)
+// to heap indices, to allow looking up by address. Unlike ClientMap,
+// clientMapInner requires external synchronization.
+type clientMapInner struct {
+	byAge  []*clientRecord
+	byAddr map[net.Addr]int
+}
+
+// removeExpired pops every client record whose LastSeen is at least timeout
+// before now. Because byAge is a min-heap on LastSeen, the oldest record is
+// always at index 0, so popping stops at the first record that is too young.
+func (inner *clientMapInner) removeExpired(now time.Time, timeout time.Duration) {
+	for len(inner.byAge) > 0 {
+		if now.Sub(inner.byAge[0].LastSeen) < timeout {
+			break
+		}
+		heap.Pop(inner)
+	}
+}
+
+// SendQueue finds the existing client record corresponding to addr, or creates
+// a new one if none exists yet. It updates the client record's LastSeen time
+// and returns its SendQueue.
+func (inner *clientMapInner) SendQueue(addr net.Addr, now time.Time) chan []byte {
+	var record *clientRecord
+	i, ok := inner.byAddr[addr]
+	if ok {
+		// Found one, update its LastSeen.
+		record = inner.byAge[i]
+		record.LastSeen = now
+		heap.Fix(inner, i)
+	} else {
+		// Not found, create a new one.
+		record = &clientRecord{
+			Addr:      addr,
+			LastSeen:  now,
+			SendQueue: make(chan []byte, queueSize),
+		}
+		heap.Push(inner, record)
+	}
+	return record.SendQueue
+}
+
+// The methods below implement heap.Interface for clientMapInner.
+
+// Len returns the number of records. It panics if byAge and byAddr have
+// drifted out of sync.
+func (inner *clientMapInner) Len() int {
+	n := len(inner.byAge)
+	if n != len(inner.byAddr) {
+		panic("inconsistent clientMap")
+	}
+	return n
+}
+
+// Less orders records by LastSeen, oldest first.
+func (inner *clientMapInner) Less(i, j int) bool {
+	a, b := inner.byAge[i], inner.byAge[j]
+	return a.LastSeen.Before(b.LastSeen)
+}
+
+// Swap exchanges two heap slots and updates both indices in byAddr.
+func (inner *clientMapInner) Swap(i, j int) {
+	toI, toJ := inner.byAge[j], inner.byAge[i]
+	inner.byAge[i], inner.byAge[j] = toI, toJ
+	inner.byAddr[toI.Addr] = i
+	inner.byAddr[toJ.Addr] = j
+}
+
+// Push appends a *clientRecord to the end of byAge and records its index in
+// byAddr. It panics if the record's address is already present.
+func (inner *clientMapInner) Push(x interface{}) {
+	record := x.(*clientRecord)
+	if _, dup := inner.byAddr[record.Addr]; dup {
+		panic("duplicate address in clientMap")
+	}
+	idx := len(inner.byAge)
+	inner.byAge = append(inner.byAge, record)
+	inner.byAddr[record.Addr] = idx
+}
+
+// Pop removes and returns the last element of byAge, the slot into which
+// heap.Pop moves the minimum before calling this method, and deletes that
+// record's byAddr entry.
+func (inner *clientMapInner) Pop() interface{} {
+	// Index by the slice actually being shrunk. len(byAddr) equals it by
+	// invariant, but len(byAge) is what bounds the accesses below.
+	n := len(inner.byAge)
+	record := inner.byAge[n-1]
+	// Clear the vacated slot so the backing array doesn't keep the record
+	// reachable.
+	inner.byAge[n-1] = nil
+	inner.byAge = inner.byAge[:n-1]
+	delete(inner.byAddr, record.Addr)
+	return record
+}
diff --git a/common/turbotunnel/consts.go b/common/turbotunnel/consts.go
new file mode 100644
index 0000000..4699d1d
--- /dev/null
+++ b/common/turbotunnel/consts.go
@@ -0,0 +1,13 @@
+// Package turbotunnel provides support for overlaying a virtual net.PacketConn
+// on some other network carrier.
+//
+// https://github.com/net4people/bbs/issues/9
+package turbotunnel
+
+import "errors"
+
+// queueSize is the capacity of the buffered receive and send queues.
+const queueSize = 32
+
+var (
+	// errClosedPacketConn is the default error reported by operations on a
+	// closed connection.
+	errClosedPacketConn = errors.New("operation on closed connection")
+	// errNotImplemented is returned by unsupported methods.
+	errNotImplemented = errors.New("not implemented")
+)
diff --git a/common/turbotunnel/queuepacketconn.go b/common/turbotunnel/queuepacketconn.go
new file mode 100644
index 0000000..14a9833
--- /dev/null
+++ b/common/turbotunnel/queuepacketconn.go
@@ -0,0 +1,137 @@
+package turbotunnel
+
+import (
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the
+// return type of PacketConn.ReadFrom. P is the packet payload and Addr is the
+// address it came from. Field order matters: QueueIncoming builds values with
+// a positional composite literal.
+type taggedPacket struct {
+	P    []byte
+	Addr net.Addr
+}
+
+// QueuePacketConn implements net.PacketConn by storing queues of packets. There
+// is one incoming queue (where packets are additionally tagged by the source
+// address of the client that sent them). There are many outgoing queues, one
+// for each client address that has been recently seen. The QueueIncoming method
+// inserts a packet into the incoming queue, to eventually be returned by
+// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
+// which can later be accessed through the OutgoingQueue method. When a queue
+// is full, new packets are dropped rather than blocking the caller.
+type QueuePacketConn struct {
+	clients   *ClientMap
+	localAddr net.Addr
+	recvQueue chan taggedPacket
+	closeOnce sync.Once
+	closed    chan struct{}
+	// What error to return when the QueuePacketConn is closed.
+	err atomic.Value
+}
+
+// NewQueuePacketConn returns an open QueuePacketConn whose LocalAddr is
+// localAddr and whose client map forgets clients not seen for timeout.
+func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
+	c := &QueuePacketConn{localAddr: localAddr}
+	c.clients = NewClientMap(timeout)
+	c.recvQueue = make(chan taggedPacket, queueSize)
+	c.closed = make(chan struct{})
+	return c
+}
+
+// QueueIncoming queues an incoming packet and its source address, to be
+// returned in a future call to ReadFrom. p is copied, so the caller may reuse
+// it. The packet is silently dropped if the QueuePacketConn is closed or the
+// receive queue is full.
+func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
+	select {
+	case <-c.closed:
+		// If we're closed, silently drop it.
+		return
+	default:
+	}
+	// Copy the slice so that the caller may reuse it.
+	buf := make([]byte, len(p))
+	copy(buf, p)
+	select {
+	case c.recvQueue <- taggedPacket{buf, addr}:
+	default:
+		// Drop the incoming packet if the receive queue is full.
+	}
+}
+
+// OutgoingQueue returns the receive-only view of the outgoing packet queue for
+// addr, creating the queue if necessary. It holds the packets later written
+// to addr with WriteTo. Calling it also refreshes addr's last-seen time in
+// the client map.
+func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
+	queue := c.clients.SendQueue(addr)
+	return queue
+}
+
+// ReadFrom blocks until a packet stored by QueueIncoming is available, copies
+// it into p (truncating if p is too short), and returns the copied length and
+// the packet's source address. Once the QueuePacketConn is closed it returns a
+// *net.OpError, even if packets are still queued.
+func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
+	// Check for closure first so that it takes priority over queued packets.
+	select {
+	case <-c.closed:
+	default:
+		select {
+		case packet := <-c.recvQueue:
+			return copy(p, packet.P), packet.Addr, nil
+		case <-c.closed:
+		}
+	}
+	return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
+}
+
+// WriteTo copies p onto the outgoing queue for addr, creating that queue if
+// necessary; the queue can later be retrieved with OutgoingQueue. If the
+// queue is full the packet is dropped, but WriteTo still reports len(p) bytes
+// and a nil error. On a closed QueuePacketConn it returns a *net.OpError.
+func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
+	select {
+	case <-c.closed:
+		return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
+	default:
+	}
+	// The caller may reuse p after we return, so queue a private copy.
+	pkt := make([]byte, len(p))
+	copy(pkt, p)
+	select {
+	case c.clients.SendQueue(addr) <- pkt:
+	default:
+		// Send queue is full: drop the packet.
+	}
+	return len(pkt), nil
+}
+
+// closeWithError unblocks pending operations and makes future operations fail
+// with the given error. If err is nil, it becomes errClosedPacketConn.
+func (c *QueuePacketConn) closeWithError(err error) error {
+	var newlyClosed bool
+	c.closeOnce.Do(func() {
+		newlyClosed = true
+		// Store the error to be returned by future PacketConn
+		// operations.
+		if err == nil {
+			err = errClosedPacketConn
+		}
+		c.err.Store(err)
+		close(c.closed)
+	})
+	if !newlyClosed {
+		return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
+	}
+	return nil
+}
+
+// Close unblocks pending operations and makes future operations fail with a
+// "closed connection" error. Closing an already-closed QueuePacketConn
+// returns a non-nil *net.OpError.
+func (c *QueuePacketConn) Close() error {
+	return c.closeWithError(nil)
+}
+
+// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
+func (c *QueuePacketConn) LocalAddr() net.Addr {
+	return c.localAddr
+}
+
+// SetDeadline is not supported; it always returns errNotImplemented.
+func (c *QueuePacketConn) SetDeadline(t time.Time) error {
+	return errNotImplemented
+}
+
+// SetReadDeadline is not supported; it always returns errNotImplemented.
+func (c *QueuePacketConn) SetReadDeadline(t time.Time) error {
+	return errNotImplemented
+}
+
+// SetWriteDeadline is not supported; it always returns errNotImplemented.
+func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error {
+	return errNotImplemented
+}
diff --git a/common/turbotunnel/redialpacketconn.go b/common/turbotunnel/redialpacketconn.go
new file mode 100644
index 0000000..cf6a8c9
--- /dev/null
+++ b/common/turbotunnel/redialpacketconn.go
@@ -0,0 +1,204 @@
+package turbotunnel
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of
+// other, transient net.PacketConns. RedialPacketConn creates a new
+// net.PacketConn by calling a provided dialContext function. Whenever the
+// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn
+// calls the dialContext function again and starts sending and receiving packets
+// on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo
+// methods return an error only once the RedialPacketConn has been closed,
+// either explicitly by Close or because the dialContext function returned an
+// error.
+//
+// Incoming and outgoing packets pass through bounded queues; when a queue is
+// full, further packets are dropped instead of blocking.
+//
+// RedialPacketConn uses static local and remote addresses that are independent
+// of those of any dialed net.PacketConn.
+type RedialPacketConn struct {
+	localAddr   net.Addr
+	remoteAddr  net.Addr
+	dialContext func(context.Context) (net.PacketConn, error)
+	recvQueue   chan []byte
+	sendQueue   chan []byte
+	closed      chan struct{}
+	closeOnce   sync.Once
+	// The error (stored as an error value) that closed the
+	// RedialPacketConn: either the first dial error, or the generic error
+	// used by Close. It is returned from future read/write operations.
+	// Compare to the rerr and werr in io.Pipe.
+	err atomic.Value
+}
+
+// NewRedialPacketConn makes a new RedialPacketConn, with the given static local
+// and remote addresses, and dialContext function. It starts a background
+// goroutine that calls dialContext to obtain transient net.PacketConns.
+// Dialing stops once the RedialPacketConn is closed, and an error from
+// dialContext closes the RedialPacketConn with that error.
+func NewRedialPacketConn(
+	localAddr, remoteAddr net.Addr,
+	dialContext func(context.Context) (net.PacketConn, error),
+) *RedialPacketConn {
+	c := &RedialPacketConn{
+		localAddr:   localAddr,
+		remoteAddr:  remoteAddr,
+		dialContext: dialContext,
+		recvQueue:   make(chan []byte, queueSize),
+		sendQueue:   make(chan []byte, queueSize),
+		closed:      make(chan struct{}),
+		// closeOnce and err are ready to use at their zero values.
+	}
+	go c.dialLoop()
+	return c
+}
+
+// dialLoop repeatedly calls c.dialContext and passes the resulting
+// net.PacketConn to c.exchange, closing each net.PacketConn once exchange
+// returns. It returns when c is closed or when dialContext returns an error;
+// in the latter case c is closed with that error.
+//
+// The context given to dialContext is canceled as soon as c is closed, so a
+// dial that is in progress at Close time can be aborted instead of keeping
+// this goroutine alive (provided dialContext honors its context).
+func (c *RedialPacketConn) dialLoop() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Cancel ctx when c is closed. This goroutine exits either when c is
+	// closed or when dialLoop returns (the deferred cancel closes ctx.Done).
+	go func() {
+		select {
+		case <-c.closed:
+			cancel()
+		case <-ctx.Done():
+		}
+	}()
+
+	for {
+		select {
+		case <-c.closed:
+			return
+		default:
+		}
+		conn, err := c.dialContext(ctx)
+		if err != nil {
+			// If c was already closed (e.g. the dial was canceled by
+			// Close), closeWithError has no effect and the original close
+			// error is kept.
+			c.closeWithError(err)
+			return
+		}
+		c.exchange(conn)
+		conn.Close()
+	}
+}
+
+// exchange shuttles packets between c and conn, making conn the currently
+// active net.PacketConn. One goroutine calls conn.ReadFrom and places the
+// resulting packets in the receive queue, dropping them if the queue is full.
+// Another takes packets from the send queue and calls conn.WriteTo on them,
+// addressed to c.remoteAddr. exchange returns as soon as either goroutine
+// stops, which happens on a ReadFrom/WriteTo error or when c is closed. The
+// caller must then close conn; that is what unblocks a reading goroutine that
+// is still waiting in conn.ReadFrom.
+//
+// The goroutines signal each other by closing a shared done channel, exactly
+// once, instead of sending on channels. This way neither goroutine can block
+// forever on a send after exchange has already returned.
+func (c *RedialPacketConn) exchange(conn net.PacketConn) {
+	done := make(chan struct{})
+	var doneOnce sync.Once
+	stop := func() { doneOnce.Do(func() { close(done) }) }
+
+	go func() {
+		defer stop()
+		// Each packet is copied out of buf before the next read, so a
+		// single buffer serves the whole loop. NOTE(review): packets
+		// longer than 1500 bytes are truncated or rejected depending on
+		// the underlying net.PacketConn; confirm this bound suits callers.
+		var buf [1500]byte
+		for {
+			select {
+			case <-c.closed:
+				return
+			case <-done:
+				return
+			default:
+			}
+			n, _, err := conn.ReadFrom(buf[:])
+			if err != nil {
+				return
+			}
+			p := make([]byte, n)
+			copy(p, buf[:n])
+			select {
+			case c.recvQueue <- p:
+			default: // OK to drop packets.
+			}
+		}
+	}()
+
+	go func() {
+		defer stop()
+		for {
+			select {
+			case <-c.closed:
+				return
+			case <-done:
+				return
+			case p := <-c.sendQueue:
+				if _, err := conn.WriteTo(p, c.remoteAddr); err != nil {
+					return
+				}
+			}
+		}
+	}()
+
+	<-done
+}
+
+// ReadFrom copies the next received packet into p and returns its length.
+// Packets arrive from whichever net.PacketConn is currently active, but the
+// returned address is always the RedialPacketConn's static remote address,
+// not the packet's original source. A packet longer than p is truncated.
+// Once c is closed, ReadFrom fails with a "read" *net.OpError.
+func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
+	closedErr := func() error {
+		return &net.OpError{
+			Op:     "read",
+			Net:    c.LocalAddr().Network(),
+			Source: c.LocalAddr(),
+			Addr:   c.remoteAddr,
+			Err:    c.err.Load().(error),
+		}
+	}
+	// Give closure priority over packets still waiting in the receive queue.
+	select {
+	case <-c.closed:
+		return 0, nil, closedErr()
+	default:
+	}
+	select {
+	case <-c.closed:
+		return 0, nil, closedErr()
+	case pkt := <-c.recvQueue:
+		return copy(p, pkt), c.remoteAddr, nil
+	}
+}
+
+// WriteTo enqueues a copy of p for transmission on whichever net.PacketConn is
+// currently active. The addr argument is ignored: every packet is sent to the
+// RedialPacketConn's own remote address. If the send queue is full the packet
+// is dropped, but WriteTo still reports len(p) bytes written. Once c is
+// closed, WriteTo fails with a "write" *net.OpError.
+func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
+	select {
+	case <-c.closed:
+		return 0, &net.OpError{
+			Op:     "write",
+			Net:    c.LocalAddr().Network(),
+			Source: c.LocalAddr(),
+			Addr:   c.remoteAddr,
+			Err:    c.err.Load().(error),
+		}
+	default:
+	}
+	// The caller may reuse p once WriteTo returns, so queue a private copy.
+	pkt := make([]byte, len(p))
+	copy(pkt, p)
+	select {
+	case c.sendQueue <- pkt:
+	default:
+		// Send queue full; drop the packet.
+	}
+	return len(pkt), nil
+}
+
+// closeWithError unblocks pending operations and makes future operations fail
+// with the given error. If err is nil, a newly created "operation on closed
+// connection" error is used instead. Only the first call has any effect and
+// returns nil. Every later call returns a "close" *net.OpError wrapping the
+// error recorded by the first call.
+func (c *RedialPacketConn) closeWithError(err error) error {
+	var newlyClosed bool
+	c.closeOnce.Do(func() {
+		newlyClosed = true
+		if err == nil {
+			err = errors.New("operation on closed connection")
+		}
+		// Store the error before closing c.closed, so that any operation
+		// that observes the closed channel can load it.
+		c.err.Store(err)
+		close(c.closed)
+	})
+	if !newlyClosed {
+		return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
+	}
+	return nil
+}
+
+// Close shuts down c: pending operations are released, every later
+// operation fails with a "closed connection" error, and no further
+// net.PacketConns are dialed. Closing an already-closed RedialPacketConn
+// returns an error.
+func (c *RedialPacketConn) Close() error {
+	// A nil error selects the default "closed connection" error.
+	err := c.closeWithError(nil)
+	return err
+}
+
+// LocalAddr reports the static local address that was given to
+// NewRedialPacketConn.
+func (c *RedialPacketConn) LocalAddr() net.Addr {
+	return c.localAddr
+}
+
+// SetDeadline is not supported by RedialPacketConn; it always returns
+// errNotImplemented.
+func (c *RedialPacketConn) SetDeadline(t time.Time) error {
+	return errNotImplemented
+}
+
+// SetReadDeadline is not supported by RedialPacketConn; it always returns
+// errNotImplemented.
+func (c *RedialPacketConn) SetReadDeadline(t time.Time) error {
+	return errNotImplemented
+}
+
+// SetWriteDeadline is not supported by RedialPacketConn; it always returns
+// errNotImplemented.
+func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error {
+	return errNotImplemented
+}





More information about the tor-commits mailing list