Finalmask: Add Sudoku (TCP & UDP) (#5685)

https://github.com/SUDOKU-ASCII/sudoku/issues/23#issuecomment-3859972396
This commit is contained in:
saba-futai
2026-03-08 02:21:35 +08:00
committed by GitHub
parent a204873d79
commit acb06e831b
12 changed files with 3434 additions and 21 deletions

View File

@@ -30,6 +30,7 @@ import (
"github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original"
"github.com/xtls/xray-core/transport/internet/finalmask/noise"
"github.com/xtls/xray-core/transport/internet/finalmask/salamander"
finalsudoku "github.com/xtls/xray-core/transport/internet/finalmask/sudoku"
"github.com/xtls/xray-core/transport/internet/finalmask/xdns"
"github.com/xtls/xray-core/transport/internet/finalmask/xicmp"
"github.com/xtls/xray-core/transport/internet/httpupgrade"
@@ -1314,6 +1315,7 @@ var (
tcpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{
"header-custom": func() interface{} { return new(HeaderCustomTCP) },
"fragment": func() interface{} { return new(FragmentMask) },
"sudoku": func() interface{} { return new(Sudoku) },
}, "type", "settings")
udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{
@@ -1328,6 +1330,7 @@ var (
"mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) },
"noise": func() interface{} { return new(NoiseMask) },
"salamander": func() interface{} { return new(Salamander) },
"sudoku": func() interface{} { return new(Sudoku) },
"xdns": func() interface{} { return new(Xdns) },
"xicmp": func() interface{} { return new(Xicmp) },
}, "type", "settings")
@@ -1636,6 +1639,50 @@ func (c *Salamander) Build() (proto.Message, error) {
return config, nil
}
type Sudoku struct {
Password string `json:"password"`
ASCII string `json:"ascii"`
CustomTable string `json:"customTable"`
LegacyCustomTable string `json:"custom_table"`
CustomTables []string `json:"customTables"`
LegacyCustomSets []string `json:"custom_tables"`
PaddingMin uint32 `json:"paddingMin"`
LegacyPaddingMin uint32 `json:"padding_min"`
PaddingMax uint32 `json:"paddingMax"`
LegacyPaddingMax uint32 `json:"padding_max"`
}
func (c *Sudoku) Build() (proto.Message, error) {
customTable := c.CustomTable
if customTable == "" {
customTable = c.LegacyCustomTable
}
customTables := c.CustomTables
if len(customTables) == 0 {
customTables = c.LegacyCustomSets
}
paddingMin := c.PaddingMin
if paddingMin == 0 {
paddingMin = c.LegacyPaddingMin
}
paddingMax := c.PaddingMax
if paddingMax == 0 {
paddingMax = c.LegacyPaddingMax
}
return &finalsudoku.Config{
Password: c.Password,
Ascii: c.ASCII,
CustomTable: customTable,
CustomTables: customTables,
PaddingMin: paddingMin,
PaddingMax: paddingMax,
}, nil
}
type Xdns struct {
Domain string `json:"domain"`
}

View File

@@ -660,7 +660,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
}
}
// UnwrapRawConn support unwrap encryption, stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
// UnwrapRawConn support unwrap encryption, stats, mask wrappers, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
var readCounter, writerCounter stats.Counter
if conn != nil {
@@ -677,6 +677,7 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
readCounter = statConn.ReadCounter
writerCounter = statConn.WriteCounter
}
if !isEncryption { // avoids double penetration
if xc, ok := conn.(*tls.Conn); ok {
conn = xc.NetConn()

View File

@@ -0,0 +1,163 @@
package sudoku
import (
"fmt"
"math/rand"
)
var perm4 = [24][4]byte{
{0, 1, 2, 3},
{0, 1, 3, 2},
{0, 2, 1, 3},
{0, 2, 3, 1},
{0, 3, 1, 2},
{0, 3, 2, 1},
{1, 0, 2, 3},
{1, 0, 3, 2},
{1, 2, 0, 3},
{1, 2, 3, 0},
{1, 3, 0, 2},
{1, 3, 2, 0},
{2, 0, 1, 3},
{2, 0, 3, 1},
{2, 1, 0, 3},
{2, 1, 3, 0},
{2, 3, 0, 1},
{2, 3, 1, 0},
{3, 0, 1, 2},
{3, 0, 2, 1},
{3, 1, 0, 2},
{3, 1, 2, 0},
{3, 2, 0, 1},
{3, 2, 1, 0},
}
type codec struct {
tables []*table
rng *rand.Rand
paddingChance int
tableIndex int
}
func newCodec(tables []*table, pMin, pMax int) *codec {
if len(tables) == 0 {
tables = nil
}
rng := newSeededRand()
return &codec{
tables: tables,
rng: rng,
paddingChance: pickPaddingChance(rng, pMin, pMax),
}
}
func pickPaddingChance(rng *rand.Rand, pMin, pMax int) int {
if pMin < 0 {
pMin = 0
}
if pMax < pMin {
pMax = pMin
}
if pMin > 100 {
pMin = 100
}
if pMax > 100 {
pMax = 100
}
if pMax == pMin {
return pMin
}
return pMin + rng.Intn(pMax-pMin+1)
}
func (c *codec) shouldPad() bool {
if c.paddingChance <= 0 {
return false
}
if c.paddingChance >= 100 {
return true
}
return c.rng.Intn(100) < c.paddingChance
}
func (c *codec) currentTable() *table {
if len(c.tables) == 0 {
return nil
}
return c.tables[c.tableIndex%len(c.tables)]
}
func (c *codec) randomPadding(t *table) byte {
pool := t.layout.paddingPool
return pool[c.rng.Intn(len(pool))]
}
func (c *codec) encode(in []byte) ([]byte, error) {
if len(in) == 0 {
return nil, nil
}
out := make([]byte, 0, len(in)*6+8)
for _, b := range in {
t := c.currentTable()
if t == nil {
return nil, fmt.Errorf("sudoku table set missing")
}
if c.shouldPad() {
out = append(out, c.randomPadding(t))
}
enc := t.encode[b]
if len(enc) == 0 {
return nil, fmt.Errorf("sudoku encode table missing for byte %d", b)
}
hints := enc[c.rng.Intn(len(enc))]
perm := perm4[c.rng.Intn(len(perm4))]
for _, idx := range perm {
if c.shouldPad() {
out = append(out, c.randomPadding(t))
}
out = append(out, hints[idx])
}
c.tableIndex++
}
if c.shouldPad() {
if t := c.currentTable(); t != nil {
out = append(out, c.randomPadding(t))
}
}
return out, nil
}
func decodeBytes(tables []*table, tableIndex *int, in []byte, hintBuf []byte, out []byte) ([]byte, []byte, error) {
if len(tables) == 0 {
return hintBuf, out, fmt.Errorf("sudoku table set missing")
}
for _, b := range in {
t := tables[*tableIndex%len(tables)]
if !t.layout.isHint(b) {
continue
}
hintBuf = append(hintBuf, b)
if len(hintBuf) < 4 {
continue
}
keyBytes := sort4([4]byte{hintBuf[0], hintBuf[1], hintBuf[2], hintBuf[3]})
key := packKey(keyBytes)
decoded, ok := t.decode[key]
if !ok {
return hintBuf[:0], out, fmt.Errorf("invalid sudoku hint tuple")
}
out = append(out, decoded)
hintBuf = hintBuf[:0]
*tableIndex++
}
return hintBuf, out, nil
}

View File

@@ -0,0 +1,57 @@
package sudoku
import (
"net"
"github.com/xtls/xray-core/common/errors"
)
func (c *Config) TCP() {
}
func (c *Config) UDP() {
}
// Sudoku in finalmask mode is a pure appearance transform with no standalone handshake.
// TCP always keeps classic sudoku on uplink and uses packed downlink optimization on server writes.
func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) {
return newPackedDirectionalConn(raw, c, true)
}
func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) {
return newPackedDirectionalConn(raw, c, false)
}
func newPackedDirectionalConn(raw net.Conn, config *Config, readPacked bool) (net.Conn, error) {
pureReader, pureWriter, err := newPureReaderWriter(raw, config)
if err != nil {
return nil, err
}
packedReader, packedWriter, err := newPackedReaderWriter(raw, config)
if err != nil {
return nil, err
}
reader, writer := pureReader, pureWriter
if readPacked {
reader = packedReader
} else {
writer = packedWriter
}
return newWrappedConn(raw, reader, writer), nil
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
if level != levelCount {
return nil, errors.New("sudoku udp mask must be the innermost mask in chain")
}
return NewUDPConn(raw, c)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
if level != levelCount {
return nil, errors.New("sudoku udp mask must be the innermost mask in chain")
}
return NewUDPConn(raw, c)
}

View File

@@ -0,0 +1,170 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v6.33.5
// source: transport/internet/finalmask/sudoku/config.proto
package sudoku
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"`
Ascii string `protobuf:"bytes,2,opt,name=ascii,proto3" json:"ascii,omitempty"`
CustomTable string `protobuf:"bytes,3,opt,name=custom_table,json=customTable,proto3" json:"custom_table,omitempty"`
PaddingMin uint32 `protobuf:"varint,4,opt,name=padding_min,json=paddingMin,proto3" json:"padding_min,omitempty"`
PaddingMax uint32 `protobuf:"varint,5,opt,name=padding_max,json=paddingMax,proto3" json:"padding_max,omitempty"`
CustomTables []string `protobuf:"bytes,7,rep,name=custom_tables,json=customTables,proto3" json:"custom_tables,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *Config) GetAscii() string {
if x != nil {
return x.Ascii
}
return ""
}
func (x *Config) GetCustomTable() string {
if x != nil {
return x.CustomTable
}
return ""
}
func (x *Config) GetPaddingMin() uint32 {
if x != nil {
return x.PaddingMin
}
return 0
}
func (x *Config) GetPaddingMax() uint32 {
if x != nil {
return x.PaddingMax
}
return 0
}
func (x *Config) GetCustomTables() []string {
if x != nil {
return x.CustomTables
}
return nil
}
var File_transport_internet_finalmask_sudoku_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_sudoku_config_proto_rawDesc = "" +
"\n" +
"0transport/internet/finalmask/sudoku/config.proto\x12(xray.transport.internet.finalmask.sudoku\"\xc4\x01\n" +
"\x06Config\x12\x1a\n" +
"\bpassword\x18\x01 \x01(\tR\bpassword\x12\x14\n" +
"\x05ascii\x18\x02 \x01(\tR\x05ascii\x12!\n" +
"\fcustom_table\x18\x03 \x01(\tR\vcustomTable\x12\x1f\n" +
"\vpadding_min\x18\x04 \x01(\rR\n" +
"paddingMin\x12\x1f\n" +
"\vpadding_max\x18\x05 \x01(\rR\n" +
"paddingMax\x12#\n" +
"\rcustom_tables\x18\a \x03(\tR\fcustomTablesB\x9a\x01\n" +
",com.xray.transport.internet.finalmask.sudokuP\x01Z=github.com/xtls/xray-core/transport/internet/finalmask/sudoku\xaa\x02(Xray.Transport.Internet.Finalmask.Sudokub\x06proto3"
var (
file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_sudoku_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_sudoku_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_sudoku_config_proto_rawDescData
}
var file_transport_internet_finalmask_sudoku_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_sudoku_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.sudoku.Config
}
var file_transport_internet_finalmask_sudoku_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_sudoku_config_proto_init() }
func file_transport_internet_finalmask_sudoku_config_proto_init() {
if File_transport_internet_finalmask_sudoku_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_sudoku_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_sudoku_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_sudoku_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_sudoku_config_proto = out.File
file_transport_internet_finalmask_sudoku_config_proto_goTypes = nil
file_transport_internet_finalmask_sudoku_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,16 @@
syntax = "proto3";
package xray.transport.internet.finalmask.sudoku;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Sudoku";
option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/sudoku";
option java_package = "com.xray.transport.internet.finalmask.sudoku";
option java_multiple_files = true;
message Config {
string password = 1;
string ascii = 2;
string custom_table = 3;
uint32 padding_min = 4;
uint32 padding_max = 5;
repeated string custom_tables = 7;
}

View File

@@ -0,0 +1,212 @@
package sudoku
import (
"bufio"
"io"
"net"
"sync"
"github.com/xtls/xray-core/transport/internet/finalmask"
)
const ioBufferSize = 32 * 1024
var _ finalmask.TcpMaskConn = (*wrappedConn)(nil)
type streamDecoder interface {
decodeChunk(in []byte, pending []byte) ([]byte, error)
reset()
}
type streamReader struct {
reader *bufio.Reader
rawBuf []byte
pending []byte
decode streamDecoder
mu sync.Mutex
}
func newStreamReader(raw net.Conn, decode streamDecoder) io.Reader {
return &streamReader{
reader: bufio.NewReaderSize(raw, ioBufferSize),
rawBuf: make([]byte, ioBufferSize),
pending: make([]byte, 0, 4096),
decode: decode,
}
}
func (r *streamReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if n, ok := drainPending(p, &r.pending); ok {
return n, nil
}
for len(r.pending) == 0 {
nr, rErr := r.reader.Read(r.rawBuf)
if nr > 0 {
var dErr error
r.pending, dErr = r.decode.decodeChunk(r.rawBuf[:nr], r.pending)
if dErr != nil {
return 0, dErr
}
}
if rErr != nil {
if rErr == io.EOF {
r.decode.reset()
if len(r.pending) > 0 {
break
}
}
return 0, rErr
}
}
n, _ := drainPending(p, &r.pending)
return n, nil
}
type streamWriter struct {
conn net.Conn
encode func([]byte) ([]byte, error)
mu sync.Mutex
}
func newStreamWriter(raw net.Conn, encode func([]byte) ([]byte, error)) io.Writer {
return &streamWriter{
conn: raw,
encode: encode,
}
}
func (w *streamWriter) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
w.mu.Lock()
defer w.mu.Unlock()
encoded, err := w.encode(p)
if err != nil {
return 0, err
}
if err := writeAll(w.conn, encoded); err != nil {
return 0, err
}
return len(p), nil
}
type wrappedConn struct {
net.Conn
reader io.Reader
writer io.Writer
}
type closeWriteConn interface {
CloseWrite() error
}
func newWrappedConn(raw net.Conn, reader io.Reader, writer io.Writer) net.Conn {
return &wrappedConn{
Conn: raw,
reader: reader,
writer: writer,
}
}
func (c *wrappedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
func (c *wrappedConn) Write(p []byte) (int, error) {
return c.writer.Write(p)
}
func (c *wrappedConn) TcpMaskConn() {}
func (c *wrappedConn) RawConn() net.Conn {
return c.Conn
}
func (c *wrappedConn) Splice() bool {
// Sudoku transforms the entire stream; bypassing it would disable masking.
return false
}
func (c *wrappedConn) CloseWrite() error {
if raw, ok := c.Conn.(closeWriteConn); ok {
return raw.CloseWrite()
}
return net.ErrClosed
}
func NewTCPConn(raw net.Conn, config *Config) (net.Conn, error) {
reader, writer, err := newPureReaderWriter(raw, config)
if err != nil {
return nil, err
}
return newWrappedConn(raw, reader, writer), nil
}
func newPureReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) {
tables, err := getTables(config)
if err != nil {
return nil, nil, err
}
pMin, pMax := normalizedPadding(config)
c := newCodec(tables, pMin, pMax)
return newStreamReader(raw, newHintStreamDecoder(tables)), newStreamWriter(raw, c.encode), nil
}
type hintStreamDecoder struct {
tables []*table
tableIndex int
hintBuf []byte
}
func newHintStreamDecoder(tables []*table) *hintStreamDecoder {
return &hintStreamDecoder{
tables: tables,
hintBuf: make([]byte, 0, 4),
}
}
func (d *hintStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) {
var err error
d.hintBuf, pending, err = decodeBytes(d.tables, &d.tableIndex, in, d.hintBuf, pending)
return pending, err
}
func (d *hintStreamDecoder) reset() {}
func drainPending(p []byte, pending *[]byte) (int, bool) {
if len(*pending) == 0 {
return 0, false
}
n := copy(p, *pending)
if n >= len(*pending) {
*pending = (*pending)[:0]
return n, true
}
remaining := len(*pending) - n
copy(*pending, (*pending)[n:])
*pending = (*pending)[:remaining]
return n, true
}
func writeAll(conn net.Conn, b []byte) error {
for len(b) > 0 {
n, err := conn.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}

View File

@@ -0,0 +1,182 @@
package sudoku
import (
"fmt"
"io"
"net"
)
type packedEncoder struct {
layouts []*byteLayout
codec *codec
groupIndex int
}
func newPackedEncoder(tables []*table, pMin, pMax int) *packedEncoder {
layouts := make([]*byteLayout, 0, len(tables))
for _, t := range tables {
layouts = append(layouts, t.layout)
}
if len(layouts) == 0 {
layouts = append(layouts, entropyLayout())
}
return &packedEncoder{
layouts: layouts,
codec: newCodec(nil, pMin, pMax),
}
}
func (e *packedEncoder) encode(p []byte) ([]byte, error) {
out := make([]byte, 0, len(p)*2+8)
var bitBuf uint64
var bitCount uint8
for _, b := range p {
bitBuf = (bitBuf << 8) | uint64(b)
bitCount += 8
for bitCount >= 6 {
bitCount -= 6
layout := e.layouts[e.groupIndex%len(e.layouts)]
group := byte(bitBuf >> bitCount)
out = e.maybePad(out, layout)
out = append(out, layout.encodeGroup(group&0x3f))
e.groupIndex++
if bitCount > 0 {
bitBuf &= (uint64(1) << bitCount) - 1
} else {
bitBuf = 0
}
}
}
if bitCount > 0 {
layout := e.layouts[e.groupIndex%len(e.layouts)]
group := byte(bitBuf << (6 - bitCount))
out = e.maybePad(out, layout)
out = append(out, layout.encodeGroup(group&0x3f))
e.groupIndex++
nextLayout := e.layouts[e.groupIndex%len(e.layouts)]
out = append(out, nextLayout.padMarker)
}
out = e.maybePad(out, e.layouts[e.groupIndex%len(e.layouts)])
return out, nil
}
func (e *packedEncoder) maybePad(out []byte, layout *byteLayout) []byte {
if !e.codec.shouldPad() {
return out
}
if len(layout.paddingPool) == 1 {
return append(out, layout.paddingPool[0])
}
for {
b := layout.paddingPool[e.codec.rng.Intn(len(layout.paddingPool))]
if b != layout.padMarker {
return append(out, b)
}
}
}
type packedStreamDecoder struct {
layouts []*byteLayout
groupIndex int
bitBuf uint64
bitCount int
}
func (d *packedStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) {
var err error
d.bitBuf, d.bitCount, d.groupIndex, pending, err = decodePackedBytes(
d.layouts,
in,
d.bitBuf,
d.bitCount,
d.groupIndex,
pending,
)
return pending, err
}
func (d *packedStreamDecoder) reset() {
d.bitBuf = 0
d.bitCount = 0
}
func NewPackedTCPConn(raw net.Conn, config *Config) (net.Conn, error) {
reader, writer, err := newPackedReaderWriter(raw, config)
if err != nil {
return nil, err
}
return newWrappedConn(raw, reader, writer), nil
}
func newPackedReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) {
tables, err := getTables(config)
if err != nil {
return nil, nil, err
}
pMin, pMax := normalizedPadding(config)
encoder := newPackedEncoder(tables, pMin, pMax)
decoder := &packedStreamDecoder{
layouts: tablesToLayouts(tables),
}
return newStreamReader(raw, decoder), newStreamWriter(raw, encoder.encode), nil
}
func tablesToLayouts(tables []*table) []*byteLayout {
layouts := make([]*byteLayout, 0, len(tables))
for _, t := range tables {
layouts = append(layouts, t.layout)
}
if len(layouts) == 0 {
layouts = append(layouts, entropyLayout())
}
return layouts
}
func decodePackedBytes(
layouts []*byteLayout,
in []byte,
bitBuf uint64,
bitCount int,
groupIndex int,
out []byte,
) (uint64, int, int, []byte, error) {
if len(layouts) == 0 {
return bitBuf, bitCount, groupIndex, out, fmt.Errorf("sudoku layout set missing")
}
for _, b := range in {
layout := layouts[groupIndex%len(layouts)]
if !layout.isHint(b) {
if b == layout.padMarker {
bitBuf = 0
bitCount = 0
}
continue
}
group, ok := layout.decodeGroup(b)
if !ok {
return bitBuf, bitCount, groupIndex, out, fmt.Errorf("invalid packed sudoku byte: %d", b)
}
groupIndex++
bitBuf = (bitBuf << 6) | uint64(group)
bitCount += 6
for bitCount >= 8 {
bitCount -= 8
out = append(out, byte(bitBuf>>bitCount))
if bitCount > 0 {
bitBuf &= (uint64(1) << bitCount) - 1
} else {
bitBuf = 0
}
}
}
return bitBuf, bitCount, groupIndex, out, nil
}

View File

@@ -0,0 +1,106 @@
package sudoku
import (
"io"
"net"
"sync"
"time"
)
type udpConn struct {
conn net.PacketConn
tables []*table
pMin int
pMax int
readBuf []byte
readMu sync.Mutex
writeMu sync.Mutex
}
func NewUDPConn(raw net.PacketConn, config *Config) (net.PacketConn, error) {
tables, err := getTables(config)
if err != nil {
return nil, err
}
pMin, pMax := normalizedPadding(config)
return &udpConn{
conn: raw,
tables: tables,
pMin: pMin,
pMax: pMax,
readBuf: make([]byte, 65535),
}, nil
}
func (c *udpConn) Size() int32 {
return 0
}
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
c.readMu.Lock()
defer c.readMu.Unlock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
return n, addr, err
}
decoded := make([]byte, 0, n/4+1)
hints := make([]byte, 0, 4)
tableIndex := 0
hints, decoded, err = decodeBytes(c.tables, &tableIndex, c.readBuf[:n], hints, decoded)
if err != nil {
return 0, addr, err
}
if len(hints) != 0 {
return 0, addr, io.ErrUnexpectedEOF
}
if len(p) < len(decoded) {
return 0, addr, io.ErrShortBuffer
}
copy(p, decoded)
return len(decoded), addr, nil
}
func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
// UDP decoding restarts at table 0 for every datagram, so encoding must do the same.
encoded, err := newCodec(c.tables, c.pMin, c.pMax).encode(p)
if err != nil {
return 0, err
}
nn, err := c.conn.WriteTo(encoded, addr)
if err != nil {
return 0, err
}
if nn != len(encoded) {
return 0, io.ErrShortWrite
}
return len(p), nil
}
func (c *udpConn) Close() error {
return c.conn.Close()
}
func (c *udpConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *udpConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *udpConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *udpConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,580 @@
package sudoku
import (
crypto_rand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt"
"math/bits"
"math/rand"
"sort"
"strings"
"sync"
"time"
)
type table struct {
encode [256][][4]byte
decode map[uint32]byte
layout *byteLayout
}
type tableCacheKey struct {
password string
ascii string
customTable string
}
var (
tableCache sync.Map
tableSetCache sync.Map
basePatternsOnce sync.Once
basePatterns [][][4]byte
basePatternsErr error
)
type byteLayout struct {
hintMask byte
hintValue byte
padMarker byte
paddingPool []byte
encodeHint func(group byte) byte
encodeGroup func(group byte) byte
decodeGroup func(b byte) (byte, bool)
}
func (l *byteLayout) isHint(b byte) bool {
if (b & l.hintMask) == l.hintValue {
return true
}
// ASCII layout maps 0x7f to '\n' to avoid DEL on the wire.
return l.hintMask == 0x40 && b == '\n'
}
func getTable(config *Config) (*table, error) {
tables, err := getTables(config)
if err != nil {
return nil, err
}
if len(tables) == 0 {
return nil, fmt.Errorf("empty sudoku table set")
}
return tables[0], nil
}
func getTables(config *Config) ([]*table, error) {
if config == nil {
return nil, fmt.Errorf("nil sudoku config")
}
mode, err := normalizeASCII(config.GetAscii())
if err != nil {
return nil, err
}
patterns, err := normalizedCustomPatterns(config, mode)
if err != nil {
return nil, err
}
cacheKey := tableCacheKey{
password: config.GetPassword(),
ascii: mode,
customTable: strings.Join(patterns, "\x00"),
}
if cached, ok := tableSetCache.Load(cacheKey); ok {
return cached.([]*table), nil
}
tables := make([]*table, 0, len(patterns))
for _, pattern := range patterns {
layout, err := resolveLayout(mode, pattern)
if err != nil {
return nil, err
}
t, err := buildTable(config.GetPassword(), layout)
if err != nil {
return nil, err
}
tables = append(tables, t)
}
actual, _ := tableSetCache.LoadOrStore(cacheKey, tables)
return actual.([]*table), nil
}
func normalizedCustomPatterns(config *Config, mode string) ([]string, error) {
if config == nil {
return []string{""}, nil
}
if mode == "prefer_ascii" {
return []string{""}, nil
}
rawPatterns := config.GetCustomTables()
if len(rawPatterns) == 0 {
rawPatterns = []string{config.GetCustomTable()}
}
patterns := make([]string, 0, len(rawPatterns))
seen := make(map[string]struct{}, len(rawPatterns))
for _, raw := range rawPatterns {
pattern := strings.TrimSpace(raw)
if pattern != "" {
var err error
pattern, err = normalizeCustomTable(pattern)
if err != nil {
return nil, err
}
}
if _, ok := seen[pattern]; ok {
continue
}
seen[pattern] = struct{}{}
patterns = append(patterns, pattern)
}
if len(patterns) == 0 {
return []string{""}, nil
}
return patterns, nil
}
func normalizedPadding(config *Config) (int, int) {
if config == nil {
return 0, 0
}
pMin := int(config.GetPaddingMin())
pMax := int(config.GetPaddingMax())
if pMin > 100 {
pMin = 100
}
if pMax > 100 {
pMax = 100
}
if pMax < pMin {
pMax = pMin
}
return pMin, pMax
}
func normalizeASCII(mode string) (string, error) {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "entropy", "prefer_entropy":
return "prefer_entropy", nil
case "ascii", "prefer_ascii":
return "prefer_ascii", nil
default:
return "", fmt.Errorf("invalid sudoku ascii mode: %s", mode)
}
}
func normalizeCustomTable(pattern string) (string, error) {
cleaned := strings.ToLower(strings.TrimSpace(pattern))
cleaned = strings.ReplaceAll(cleaned, " ", "")
if len(cleaned) != 8 {
return "", fmt.Errorf("customTable must be 8 chars, got %d", len(cleaned))
}
var xCount, pCount, vCount int
for _, ch := range cleaned {
switch ch {
case 'x':
xCount++
case 'p':
pCount++
case 'v':
vCount++
default:
return "", fmt.Errorf("customTable has invalid char %q", ch)
}
}
if xCount != 2 || pCount != 2 || vCount != 4 {
return "", fmt.Errorf("customTable must contain exactly 2 x, 2 p and 4 v")
}
return cleaned, nil
}
func resolveLayout(mode, customTable string) (*byteLayout, error) {
if mode == "prefer_ascii" {
return asciiLayout(), nil
}
if customTable != "" {
return customLayout(customTable)
}
return entropyLayout(), nil
}
func asciiLayout() *byteLayout {
padding := make([]byte, 0, 32)
for i := 0; i < 32; i++ {
padding = append(padding, byte(0x20+i))
}
encodeGroup := func(group byte) byte {
b := byte(0x40 | (group & 0x3f))
if b == 0x7f {
return '\n'
}
return b
}
return &byteLayout{
hintMask: 0x40,
hintValue: 0x40,
padMarker: 0x3f,
paddingPool: padding,
encodeHint: encodeGroup,
encodeGroup: encodeGroup,
decodeGroup: func(b byte) (byte, bool) {
if b == '\n' {
return 0x3f, true
}
if (b & 0x40) == 0 {
return 0, false
}
return b & 0x3f, true
},
}
}
func entropyLayout() *byteLayout {
padding := make([]byte, 0, 16)
for i := 0; i < 8; i++ {
padding = append(padding, byte(0x80+i), byte(0x10+i))
}
encodeGroup := func(group byte) byte {
v := group & 0x3f
return ((v & 0x30) << 1) | (v & 0x0f)
}
return &byteLayout{
hintMask: 0x90,
hintValue: 0x00,
padMarker: 0x80,
paddingPool: padding,
encodeHint: encodeGroup,
encodeGroup: encodeGroup,
decodeGroup: func(b byte) (byte, bool) {
if (b & 0x90) != 0 {
return 0, false
}
return ((b >> 1) & 0x30) | (b & 0x0f), true
},
}
}
func customLayout(pattern string) (*byteLayout, error) {
pattern, err := normalizeCustomTable(pattern)
if err != nil {
return nil, err
}
var xBits, pBits, vBits []uint8
for i, c := range pattern {
bit := uint8(7 - i)
switch c {
case 'x':
xBits = append(xBits, bit)
case 'p':
pBits = append(pBits, bit)
case 'v':
vBits = append(vBits, bit)
}
}
xMask := byte(0)
for _, bit := range xBits {
xMask |= 1 << bit
}
encodeGroupWithDropX := func(group byte, dropX int) byte {
out := xMask
if dropX >= 0 {
out &^= 1 << xBits[dropX]
}
val := (group >> 4) & 0x03
pos := group & 0x0f
if (val & 0x02) != 0 {
out |= 1 << pBits[0]
}
if (val & 0x01) != 0 {
out |= 1 << pBits[1]
}
for i, bit := range vBits {
if (pos>>(3-uint8(i)))&0x01 == 1 {
out |= 1 << bit
}
}
return out
}
paddingSet := make(map[byte]struct{}, 64)
padding := make([]byte, 0, 64)
for drop := range xBits {
for val := byte(0); val < 4; val++ {
for pos := byte(0); pos < 16; pos++ {
group := (val << 4) | pos
b := encodeGroupWithDropX(group, drop)
if bits.OnesCount8(b) >= 5 {
if _, exists := paddingSet[b]; !exists {
paddingSet[b] = struct{}{}
padding = append(padding, b)
}
}
}
}
}
sort.Slice(padding, func(i, j int) bool { return padding[i] < padding[j] })
if len(padding) == 0 {
return nil, fmt.Errorf("customTable produced empty padding pool")
}
decodeGroup := func(b byte) (byte, bool) {
if (b & xMask) != xMask {
return 0, false
}
var val, pos byte
if b&(1<<pBits[0]) != 0 {
val |= 0x02
}
if b&(1<<pBits[1]) != 0 {
val |= 0x01
}
for i, bit := range vBits {
if b&(1<<bit) != 0 {
pos |= 1 << (3 - uint8(i))
}
}
return ((val & 0x03) << 4) | (pos & 0x0f), true
}
encodeGroup := func(group byte) byte {
return encodeGroupWithDropX(group, -1)
}
return &byteLayout{
hintMask: xMask,
hintValue: xMask,
padMarker: padding[0],
paddingPool: padding,
encodeHint: encodeGroup,
encodeGroup: encodeGroup,
decodeGroup: decodeGroup,
}, nil
}
func buildTable(password string, layout *byteLayout) (*table, error) {
patterns, err := getBasePatterns()
if err != nil {
return nil, err
}
if len(patterns) < 256 {
return nil, fmt.Errorf("not enough sudoku grids: %d", len(patterns))
}
order := make([]int, len(patterns))
for i := range order {
order[i] = i
}
hash := sha256.Sum256([]byte(password))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
rng.Shuffle(len(order), func(i, j int) {
order[i], order[j] = order[j], order[i]
})
t := &table{
decode: make(map[uint32]byte, 1<<16),
layout: layout,
}
for b := 0; b < 256; b++ {
patList := patterns[order[b]]
if len(patList) == 0 {
return nil, fmt.Errorf("grid %d has no valid clue set", order[b])
}
enc := make([][4]byte, 0, len(patList))
for _, groups := range patList {
hints := [4]byte{
layout.encodeHint(groups[0]),
layout.encodeHint(groups[1]),
layout.encodeHint(groups[2]),
layout.encodeHint(groups[3]),
}
sortedHints := sort4(hints)
key := packKey(sortedHints)
if old, exists := t.decode[key]; exists && old != byte(b) {
return nil, fmt.Errorf("decode key collision for byte %d and %d", old, b)
}
t.decode[key] = byte(b)
enc = append(enc, hints)
}
t.encode[b] = enc
}
return t, nil
}
func getBasePatterns() ([][][4]byte, error) {
basePatternsOnce.Do(func() {
basePatterns, basePatternsErr = buildBasePatterns()
})
return basePatterns, basePatternsErr
}
type grid [16]byte
func buildBasePatterns() ([][][4]byte, error) {
grids := generateAllGrids()
positions := hintPositions()
patterns := make([][][4]byte, len(grids))
for _, ps := range positions {
counts := make(map[uint32]uint16, len(grids))
keys := make([]uint32, len(grids))
groupsByGrid := make([][4]byte, len(grids))
for gi, g := range grids {
groups := [4]byte{
clueGroup(g, ps[0]),
clueGroup(g, ps[1]),
clueGroup(g, ps[2]),
clueGroup(g, ps[3]),
}
groups = sort4(groups)
key := packKey(groups)
keys[gi] = key
groupsByGrid[gi] = groups
counts[key]++
}
for gi, key := range keys {
if counts[key] == 1 {
patterns[gi] = append(patterns[gi], groupsByGrid[gi])
}
}
}
for gi, list := range patterns {
if len(list) == 0 {
return nil, fmt.Errorf("grid %d has no uniquely decodable clue set", gi)
}
}
return patterns, nil
}
func clueGroup(g grid, pos byte) byte {
// 2 bits of value + 4 bits of position.
return ((g[pos] - 1) << 4) | (pos & 0x0f)
}
func generateAllGrids() []grid {
grids := make([]grid, 0, 288)
var g grid
var dfs func(idx int)
dfs = func(idx int) {
if idx == 16 {
grids = append(grids, g)
return
}
row := idx / 4
col := idx % 4
boxRow := (row / 2) * 2
boxCol := (col / 2) * 2
for num := byte(1); num <= 4; num++ {
valid := true
for i := 0; i < 4; i++ {
if g[row*4+i] == num || g[i*4+col] == num {
valid = false
break
}
}
if !valid {
continue
}
for r := 0; r < 2 && valid; r++ {
for c := 0; c < 2; c++ {
if g[(boxRow+r)*4+(boxCol+c)] == num {
valid = false
break
}
}
}
if !valid {
continue
}
g[idx] = num
dfs(idx + 1)
g[idx] = 0
}
}
dfs(0)
return grids
}
func hintPositions() [][4]byte {
// C(16, 4) = 1820.
positions := make([][4]byte, 0, 1820)
for a := 0; a < 13; a++ {
for b := a + 1; b < 14; b++ {
for c := b + 1; c < 15; c++ {
for d := c + 1; d < 16; d++ {
positions = append(positions, [4]byte{byte(a), byte(b), byte(c), byte(d)})
}
}
}
}
return positions
}
func packKey(in [4]byte) uint32 {
return uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
}
func sort4(in [4]byte) [4]byte {
if in[0] > in[1] {
in[0], in[1] = in[1], in[0]
}
if in[2] > in[3] {
in[2], in[3] = in[3], in[2]
}
if in[0] > in[2] {
in[0], in[2] = in[2], in[0]
}
if in[1] > in[3] {
in[1], in[3] = in[3], in[1]
}
if in[1] > in[2] {
in[1], in[2] = in[2], in[1]
}
return in
}
func newSeededRand() *rand.Rand {
seed := time.Now().UnixNano()
var seedBytes [8]byte
if _, err := crypto_rand.Read(seedBytes[:]); err == nil {
seed = int64(binary.BigEndian.Uint64(seedBytes[:]))
}
return rand.New(rand.NewSource(seed))
}

View File

@@ -2,10 +2,13 @@ package finalmask_test
import (
"bytes"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport/internet/finalmask"
"github.com/xtls/xray-core/transport/internet/finalmask/header/custom"
"github.com/xtls/xray-core/transport/internet/finalmask/header/dns"
@@ -16,6 +19,7 @@ import (
"github.com/xtls/xray-core/transport/internet/finalmask/mkcp/aes128gcm"
"github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original"
"github.com/xtls/xray-core/transport/internet/finalmask/salamander"
"github.com/xtls/xray-core/transport/internet/finalmask/sudoku"
)
func mustSendRecv(
@@ -51,6 +55,22 @@ func mustSendRecv(
type layerMask struct {
name string
mask finalmask.Udpmask
layers int
}
type countingConn struct {
net.Conn
written atomic.Int64
}
func (c *countingConn) Write(p []byte) (int, error) {
n, err := c.Conn.Write(p)
c.written.Add(int64(n))
return n, err
}
func (c *countingConn) Written() int64 {
return c.written.Load()
}
func TestPacketConnReadWrite(t *testing.T) {
@@ -58,30 +78,68 @@ func TestPacketConnReadWrite(t *testing.T) {
{
name: "aes128gcm",
mask: &aes128gcm.Config{Password: "123"},
layers: 2,
},
{
name: "original",
mask: &original.Config{},
layers: 2,
},
{
name: "dns",
mask: &dns.Config{Domain: "www.baidu.com"},
layers: 2,
},
{
name: "srtp",
mask: &srtp.Config{},
layers: 2,
},
{
name: "utp",
mask: &utp.Config{},
layers: 2,
},
{
name: "wechat",
mask: &wechat.Config{},
layers: 2,
},
{
name: "wireguard",
mask: &wireguard.Config{},
layers: 2,
},
{
name: "salamander",
mask: &salamander.Config{Password: "1234"},
layers: 2,
},
{
name: "sudoku-prefer-ascii",
mask: &sudoku.Config{
Password: "sudoku-mask",
Ascii: "prefer_ascii",
},
layers: 1,
},
{
name: "sudoku-custom-table",
mask: &sudoku.Config{
Password: "sudoku-mask",
Ascii: "prefer_entropy",
CustomTable: "xpxvvpvv",
},
layers: 1,
},
{
name: "sudoku-custom-tables",
mask: &sudoku.Config{
Password: "sudoku-mask",
Ascii: "prefer_entropy",
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
},
layers: 1,
},
{
name: "custom",
@@ -103,18 +161,27 @@ func TestPacketConnReadWrite(t *testing.T) {
},
},
},
layers: 1,
},
{
name: "salamander",
name: "salamander-single",
mask: &salamander.Config{Password: "1234"},
layers: 1,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
mask := c.mask
maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{mask, mask})
layers := c.layers
if layers <= 0 {
layers = 1
}
masks := make([]finalmask.Udpmask, 0, layers)
for i := 0; i < layers; i++ {
masks = append(masks, mask)
}
maskManager := finalmask.NewUdpmaskManager(masks)
client, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
@@ -147,3 +214,419 @@ func TestPacketConnReadWrite(t *testing.T) {
})
}
}
func TestSudokuBDD(t *testing.T) {
t.Run("GivenSudokuTCPMask_WhenRoundTripWithAsciiPreference_ThenPayloadMatches", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-tcp",
Ascii: "prefer_ascii",
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := cfg.WrapConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
send := bytes.Repeat([]byte("client->server"), 1024)
recv := make([]byte, len(send))
writeErr := make(chan error, 1)
go func() {
_, wErr := clientConn.Write(send)
writeErr <- wErr
}()
if _, err := io.ReadFull(serverConn, recv); err != nil {
t.Fatal(err)
}
if err := <-writeErr; err != nil {
t.Fatal(err)
}
if !bytes.Equal(send, recv) {
t.Fatal("tcp sudoku payload mismatch")
}
})
t.Run("GivenSudokuTCPMask_WhenRoundTrip_ThenBothDirectionsMatch", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-packed",
Ascii: "prefer_ascii",
PaddingMin: 0,
PaddingMax: 0,
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := cfg.WrapConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
clientToServer := bytes.Repeat([]byte("client-packed->server"), 257)
serverToClient := bytes.Repeat([]byte("server-packed->client"), 263)
c2sRecv := make([]byte, len(clientToServer))
c2sErr := make(chan error, 1)
go func() {
_, err := clientConn.Write(clientToServer)
c2sErr <- err
}()
if _, err := io.ReadFull(serverConn, c2sRecv); err != nil {
t.Fatal(err)
}
if err := <-c2sErr; err != nil {
t.Fatal(err)
}
if !bytes.Equal(clientToServer, c2sRecv) {
t.Fatal("tcp client->server payload mismatch")
}
s2cRecv := make([]byte, len(serverToClient))
s2cErr := make(chan error, 1)
go func() {
_, err := serverConn.Write(serverToClient)
s2cErr <- err
}()
if _, err := io.ReadFull(clientConn, s2cRecv); err != nil {
t.Fatal(err)
}
if err := <-s2cErr; err != nil {
t.Fatal(err)
}
if !bytes.Equal(serverToClient, s2cRecv) {
t.Fatal("tcp server->client payload mismatch")
}
})
t.Run("GivenSudokuTCPMask_WhenServerWritesDownlink_ThenWireBytesAreReduced", func(t *testing.T) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 192) // 3072 bytes, divisible by 3.
countWireBytes := func(wrapServer func(net.Conn, *sudoku.Config) (net.Conn, error), cfg *sudoku.Config) int64 {
t.Helper()
clientRaw, serverRaw := net.Pipe()
watchedServerRaw := &countingConn{Conn: serverRaw}
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := wrapServer(watchedServerRaw, cfg)
if err != nil {
t.Fatal(err)
}
readErr := make(chan error, 1)
go func() {
_, err := io.CopyN(io.Discard, clientConn, int64(len(payload)))
readErr <- err
}()
if _, err := serverConn.Write(payload); err != nil {
t.Fatal(err)
}
if err := <-readErr; err != nil {
t.Fatal(err)
}
_ = clientConn.Close()
_ = serverConn.Close()
return watchedServerRaw.Written()
}
pureUplinkPackedDownlink := &sudoku.Config{
Password: "sudoku-bandwidth",
Ascii: "prefer_entropy",
PaddingMin: 0,
PaddingMax: 0,
}
packedDownlinkBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) {
return cfg.WrapConnServer(raw)
}, pureUplinkPackedDownlink)
legacyPureBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) {
return sudoku.NewTCPConn(raw, cfg)
}, pureUplinkPackedDownlink)
if packedDownlinkBytes >= legacyPureBytes {
t.Fatalf("expected default packed downlink bytes < legacy pure bytes, got packed=%d pure=%d", packedDownlinkBytes, legacyPureBytes)
}
})
t.Run("GivenSudokuMultiTableTCPMask_WhenRoundTrip_ThenPayloadMatches", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-multi-tcp",
Ascii: "prefer_entropy",
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := cfg.WrapConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
send := bytes.Repeat([]byte("rotate-table"), 513)
recv := make([]byte, len(send))
writeErr := make(chan error, 1)
go func() {
_, wErr := clientConn.Write(send)
writeErr <- wErr
}()
if _, err := io.ReadFull(serverConn, recv); err != nil {
t.Fatal(err)
}
if err := <-writeErr; err != nil {
t.Fatal(err)
}
if !bytes.Equal(send, recv) {
t.Fatal("multi-table tcp sudoku payload mismatch")
}
})
t.Run("GivenSudokuMultiTableTCPMask_WhenPackedDownlink_ThenPayloadMatches", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-multi-packed",
Ascii: "prefer_entropy",
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
PaddingMin: 0,
PaddingMax: 0,
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := cfg.WrapConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
send := bytes.Repeat([]byte("packed-rotate"), 257)
recv := make([]byte, len(send))
writeErr := make(chan error, 1)
go func() {
_, wErr := clientConn.Write(send)
writeErr <- wErr
}()
if _, err := io.ReadFull(serverConn, recv); err != nil {
t.Fatal(err)
}
if err := <-writeErr; err != nil {
t.Fatal(err)
}
if !bytes.Equal(send, recv) {
t.Fatal("multi-table tcp sudoku payload mismatch")
}
})
t.Run("GivenSudokuUDPMask_WhenNotInnermost_ThenWrapFails", func(t *testing.T) {
cfg := &sudoku.Config{Password: "sudoku-udp"}
raw, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer raw.Close()
if _, err := cfg.WrapPacketConnClient(raw, 0, 1); err == nil {
t.Fatal("expected innermost check failure")
}
})
t.Run("GivenSudokuMultiTableUDPMask_WhenClientSendsMultipleDatagrams_ThenPayloadMatches", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-udp-multi",
Ascii: "prefer_entropy",
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
PaddingMin: 0,
PaddingMax: 0,
}
maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{cfg})
clientRaw, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer clientRaw.Close()
serverRaw, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer serverRaw.Close()
client, err := maskManager.WrapPacketConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
server, err := maskManager.WrapPacketConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
_ = server.SetDeadline(time.Now().Add(2 * time.Second))
mustSendRecv(t, client, server, []byte("first-datagram"))
mustSendRecv(t, client, server, []byte("second-datagram"))
mustSendRecv(t, client, server, []byte("third-datagram"))
})
t.Run("GivenSudokuTCPMask_WhenCloseWriteIsCalled_ThenEOFPropagates", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-closewrite",
Ascii: "prefer_ascii",
PaddingMin: 0,
PaddingMax: 0,
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
acceptCh := make(chan net.Conn, 1)
errCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- err
return
}
acceptCh <- conn
}()
clientRaw, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer clientRaw.Close()
var serverRaw net.Conn
select {
case serverRaw = <-acceptCh:
case err := <-errCh:
t.Fatal(err)
case <-time.After(2 * time.Second):
t.Fatal("accept timeout")
}
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
serverConn, err := cfg.WrapConnServer(serverRaw)
if err != nil {
t.Fatal(err)
}
closeWriter, ok := clientConn.(interface{ CloseWrite() error })
if !ok {
t.Fatalf("wrapped conn does not expose CloseWrite: %T", clientConn)
}
writeErr := make(chan error, 1)
go func() {
if _, err := clientConn.Write([]byte("closewrite")); err != nil {
writeErr <- err
return
}
writeErr <- closeWriter.CloseWrite()
}()
buf := make([]byte, len("closewrite"))
if _, err := io.ReadFull(serverConn, buf); err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, []byte("closewrite")) {
t.Fatal("unexpected payload before closewrite")
}
if err := <-writeErr; err != nil {
t.Fatal(err)
}
one := make([]byte, 1)
n, err := serverConn.Read(one)
if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after CloseWrite, got n=%d err=%v", n, err)
}
})
t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_ThenMaskConnIsRetained", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-unwrap",
Ascii: "prefer_entropy",
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn)
if readCounter != nil || writeCounter != nil {
t.Fatal("unexpected stat counters while unwrapping sudoku conn")
}
if unwrapped != clientConn {
t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped)
}
})
t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_AfterDownlinkOptimization_ThenMaskConnIsRetained", func(t *testing.T) {
cfg := &sudoku.Config{
Password: "sudoku-packed-unwrap",
Ascii: "prefer_entropy",
}
clientRaw, serverRaw := net.Pipe()
defer clientRaw.Close()
defer serverRaw.Close()
clientConn, err := cfg.WrapConnClient(clientRaw)
if err != nil {
t.Fatal(err)
}
unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn)
if readCounter != nil || writeCounter != nil {
t.Fatal("unexpected stat counters while unwrapping sudoku conn")
}
if unwrapped != clientConn {
t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped)
}
})
}