diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index d7e26f51..bf6d7a43 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -30,6 +30,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" "github.com/xtls/xray-core/transport/internet/finalmask/noise" "github.com/xtls/xray-core/transport/internet/finalmask/salamander" + finalsudoku "github.com/xtls/xray-core/transport/internet/finalmask/sudoku" "github.com/xtls/xray-core/transport/internet/finalmask/xdns" "github.com/xtls/xray-core/transport/internet/finalmask/xicmp" "github.com/xtls/xray-core/transport/internet/httpupgrade" @@ -1314,6 +1315,7 @@ var ( tcpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ "header-custom": func() interface{} { return new(HeaderCustomTCP) }, "fragment": func() interface{} { return new(FragmentMask) }, + "sudoku": func() interface{} { return new(Sudoku) }, }, "type", "settings") udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ @@ -1328,6 +1330,7 @@ var ( "mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) }, "noise": func() interface{} { return new(NoiseMask) }, "salamander": func() interface{} { return new(Salamander) }, + "sudoku": func() interface{} { return new(Sudoku) }, "xdns": func() interface{} { return new(Xdns) }, "xicmp": func() interface{} { return new(Xicmp) }, }, "type", "settings") @@ -1636,6 +1639,50 @@ func (c *Salamander) Build() (proto.Message, error) { return config, nil } +type Sudoku struct { + Password string `json:"password"` + ASCII string `json:"ascii"` + + CustomTable string `json:"customTable"` + LegacyCustomTable string `json:"custom_table"` + CustomTables []string `json:"customTables"` + LegacyCustomSets []string `json:"custom_tables"` + + PaddingMin uint32 `json:"paddingMin"` + LegacyPaddingMin uint32 `json:"padding_min"` + PaddingMax uint32 `json:"paddingMax"` + LegacyPaddingMax uint32 `json:"padding_max"` +} + +func (c *Sudoku) Build() (proto.Message, error) { + customTable := c.CustomTable + if customTable == "" { + customTable = c.LegacyCustomTable + } + customTables := c.CustomTables + if len(customTables) == 0 { + customTables = c.LegacyCustomSets + } + + paddingMin := c.PaddingMin + if paddingMin == 0 { + paddingMin = c.LegacyPaddingMin + } + paddingMax := c.PaddingMax + if paddingMax == 0 { + paddingMax = c.LegacyPaddingMax + } + + return &finalsudoku.Config{ + Password: c.Password, + Ascii: c.ASCII, + CustomTable: customTable, + CustomTables: customTables, + PaddingMin: paddingMin, + PaddingMax: paddingMax, + }, nil +} + type Xdns struct { Domain string `json:"domain"` } diff --git a/proxy/proxy.go b/proxy/proxy.go index acda52d9..ea8c3b06 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -660,7 +660,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte } } -// UnwrapRawConn support unwrap encryption, stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it +// UnwrapRawConn support unwrap encryption, stats, mask wrappers, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { var readCounter, writerCounter stats.Counter if conn != nil { @@ -677,6 +677,7 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { readCounter = statConn.ReadCounter writerCounter = statConn.WriteCounter } + if !isEncryption { // avoids double penetration if xc, ok := conn.(*tls.Conn); ok { conn = xc.NetConn() diff --git a/transport/internet/finalmask/sudoku/codec.go b/transport/internet/finalmask/sudoku/codec.go new file mode 100644 index 00000000..e748b8b7 --- /dev/null +++ b/transport/internet/finalmask/sudoku/codec.go @@ -0,0 +1,163 @@ +package sudoku + +import ( + "fmt" + "math/rand" +) + +var perm4 = [24][4]byte{ + {0, 1, 2, 3}, + {0, 1, 3, 2}, + {0, 2, 1, 3}, + {0, 2, 3, 1}, + {0, 3, 1, 2}, + {0, 3, 2, 1}, + {1, 0, 2, 3}, + {1, 0, 3, 2}, + {1, 2, 0, 3}, + {1, 2, 3, 0}, + {1, 3, 0, 2}, + {1, 3, 2, 0}, + {2, 0, 1, 3}, + {2, 0, 3, 1}, + {2, 1, 0, 3}, + {2, 1, 3, 0}, + {2, 3, 0, 1}, + {2, 3, 1, 0}, + {3, 0, 1, 2}, + {3, 0, 2, 1}, + {3, 1, 0, 2}, + {3, 1, 2, 0}, + {3, 2, 0, 1}, + {3, 2, 1, 0}, +} + +type codec struct { + tables []*table + rng *rand.Rand + paddingChance int + tableIndex int +} + +func newCodec(tables []*table, pMin, pMax int) *codec { + if len(tables) == 0 { + tables = nil + } + rng := newSeededRand() + return &codec{ + tables: tables, + rng: rng, + paddingChance: pickPaddingChance(rng, pMin, pMax), + } +} + +func pickPaddingChance(rng *rand.Rand, pMin, pMax int) int { + if pMin < 0 { + pMin = 0 + } + if pMax < pMin { + pMax = pMin + } + if pMin > 100 { + pMin = 100 + } + if pMax > 100 { + pMax = 100 + } + if pMax == pMin { + return pMin + } + return pMin + rng.Intn(pMax-pMin+1) +} + +func (c *codec) shouldPad() bool { + if c.paddingChance <= 0 { + return false + } + if c.paddingChance >= 100 { + return true + } + return c.rng.Intn(100) < c.paddingChance +} + +func (c *codec) currentTable() *table { + if len(c.tables) == 0 { + return nil + } + return c.tables[c.tableIndex%len(c.tables)] +} + +func (c *codec) randomPadding(t *table) byte { + pool := t.layout.paddingPool + return pool[c.rng.Intn(len(pool))] +} + +func (c *codec) encode(in []byte) ([]byte, error) { + if len(in) == 0 { + return nil, nil + } + + out := make([]byte, 0, len(in)*6+8) + for _, b := range in { + t := c.currentTable() + if t == nil { + return nil, fmt.Errorf("sudoku table set missing") + } + if c.shouldPad() { + out = append(out, c.randomPadding(t)) + } + + enc := t.encode[b] + if len(enc) == 0 { + return nil, fmt.Errorf("sudoku encode table missing for byte %d", b) + } + + hints := enc[c.rng.Intn(len(enc))] + perm := perm4[c.rng.Intn(len(perm4))] + for _, idx := range perm { + if c.shouldPad() { + out = append(out, c.randomPadding(t)) + } + out = append(out, hints[idx]) + } + c.tableIndex++ + } + + if c.shouldPad() { + if t := c.currentTable(); t != nil { + out = append(out, c.randomPadding(t)) + } + } + + return out, nil +} + +func decodeBytes(tables []*table, tableIndex *int, in []byte, hintBuf []byte, out []byte) ([]byte, []byte, error) { + if len(tables) == 0 { + return hintBuf, out, fmt.Errorf("sudoku table set missing") + } + for _, b := range in { + t := tables[*tableIndex%len(tables)] + if !t.layout.isHint(b) { + continue + } + + hintBuf = append(hintBuf, b) + if len(hintBuf) < 4 { + continue + } + + keyBytes := sort4([4]byte{hintBuf[0], hintBuf[1], hintBuf[2], hintBuf[3]}) + key := packKey(keyBytes) + decoded, ok := t.decode[key] + if !ok { + return hintBuf[:0], out, fmt.Errorf("invalid sudoku hint tuple") + } + + out = append(out, decoded) + hintBuf = hintBuf[:0] + *tableIndex++ + } + + return hintBuf, out, nil +} diff --git a/transport/internet/finalmask/sudoku/config.go b/transport/internet/finalmask/sudoku/config.go new file mode 100644 index 00000000..58a4562f --- /dev/null +++ b/transport/internet/finalmask/sudoku/config.go @@ -0,0 +1,57 @@ +package sudoku + +import ( + "net" + + "github.com/xtls/xray-core/common/errors" +) + +func (c *Config) TCP() { +} + +func (c *Config) UDP() { +} + +// Sudoku in finalmask mode is a pure appearance transform with no standalone handshake. +// TCP always keeps classic sudoku on uplink and uses packed downlink optimization on server writes. +func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) { + return newPackedDirectionalConn(raw, c, true) +} + +func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) { + return newPackedDirectionalConn(raw, c, false) +} + +func newPackedDirectionalConn(raw net.Conn, config *Config, readPacked bool) (net.Conn, error) { + pureReader, pureWriter, err := newPureReaderWriter(raw, config) + if err != nil { + return nil, err + } + packedReader, packedWriter, err := newPackedReaderWriter(raw, config) + if err != nil { + return nil, err + } + + reader, writer := pureReader, pureWriter + if readPacked { + reader = packedReader + } else { + writer = packedWriter + } + + return newWrappedConn(raw, reader, writer), nil +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if level != levelCount { + return nil, errors.New("sudoku udp mask must be the innermost mask in chain") + } + return NewUDPConn(raw, c) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if level != levelCount { + return nil, errors.New("sudoku udp mask must be the innermost mask in chain") + } + return NewUDPConn(raw, c) +} diff --git a/transport/internet/finalmask/sudoku/config.pb.go b/transport/internet/finalmask/sudoku/config.pb.go new file mode 100644 index 00000000..56b544eb --- /dev/null +++ b/transport/internet/finalmask/sudoku/config.pb.go @@ -0,0 +1,170 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/finalmask/sudoku/config.proto + +package sudoku + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + Ascii string `protobuf:"bytes,2,opt,name=ascii,proto3" json:"ascii,omitempty"` + CustomTable string `protobuf:"bytes,3,opt,name=custom_table,json=customTable,proto3" json:"custom_table,omitempty"` + PaddingMin uint32 `protobuf:"varint,4,opt,name=padding_min,json=paddingMin,proto3" json:"padding_min,omitempty"` + PaddingMax uint32 `protobuf:"varint,5,opt,name=padding_max,json=paddingMax,proto3" json:"padding_max,omitempty"` + CustomTables []string `protobuf:"bytes,7,rep,name=custom_tables,json=customTables,proto3" json:"custom_tables,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *Config) GetAscii() string { + if x != nil { + return x.Ascii + } + return "" +} + +func (x *Config) GetCustomTable() string { + if x != nil { + return x.CustomTable + } + return "" +} + +func (x *Config) GetPaddingMin() uint32 { + if x != nil { + return x.PaddingMin + } + return 0 +} + +func (x *Config) GetPaddingMax() uint32 { + if x != nil { + return x.PaddingMax + } + return 0 +} + +func (x *Config) GetCustomTables() []string { + if x != nil { + return x.CustomTables + } + return nil +} + +var File_transport_internet_finalmask_sudoku_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_sudoku_config_proto_rawDesc = "" + + "\n" + + "0transport/internet/finalmask/sudoku/config.proto\x12(xray.transport.internet.finalmask.sudoku\"\xc4\x01\n" + + "\x06Config\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpassword\x12\x14\n" + + "\x05ascii\x18\x02 \x01(\tR\x05ascii\x12!\n" + + "\fcustom_table\x18\x03 \x01(\tR\vcustomTable\x12\x1f\n" + + "\vpadding_min\x18\x04 \x01(\rR\n" + + "paddingMin\x12\x1f\n" + + "\vpadding_max\x18\x05 \x01(\rR\n" + + "paddingMax\x12#\n" + + "\rcustom_tables\x18\a \x03(\tR\fcustomTablesB\x9a\x01\n" + + ",com.xray.transport.internet.finalmask.sudokuP\x01Z=github.com/xtls/xray-core/transport/internet/finalmask/sudoku\xaa\x02(Xray.Transport.Internet.Finalmask.Sudokub\x06proto3" + +var ( + file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_sudoku_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_sudoku_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_sudoku_config_proto_rawDescData +} + +var file_transport_internet_finalmask_sudoku_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_sudoku_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.sudoku.Config +} +var file_transport_internet_finalmask_sudoku_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_sudoku_config_proto_init() } +func file_transport_internet_finalmask_sudoku_config_proto_init() { + if File_transport_internet_finalmask_sudoku_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_sudoku_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_sudoku_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_sudoku_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_sudoku_config_proto = out.File + file_transport_internet_finalmask_sudoku_config_proto_goTypes = nil + file_transport_internet_finalmask_sudoku_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/sudoku/config.proto b/transport/internet/finalmask/sudoku/config.proto new file mode 100644 index 00000000..7089e0dd --- /dev/null +++ b/transport/internet/finalmask/sudoku/config.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.sudoku; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Sudoku"; +option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/sudoku"; +option java_package = "com.xray.transport.internet.finalmask.sudoku"; +option java_multiple_files = true; + +message Config { + string password = 1; + string ascii = 2; + string custom_table = 3; + uint32 padding_min = 4; + uint32 padding_max = 5; + repeated string custom_tables = 7; +} diff --git a/transport/internet/finalmask/sudoku/conn_tcp.go b/transport/internet/finalmask/sudoku/conn_tcp.go new file mode 100644 index 00000000..75abb98e --- /dev/null +++ b/transport/internet/finalmask/sudoku/conn_tcp.go @@ -0,0 +1,212 @@ +package sudoku + +import ( + "bufio" + "io" + "net" + "sync" + + "github.com/xtls/xray-core/transport/internet/finalmask" +) + +const ioBufferSize = 32 * 1024 + +var _ finalmask.TcpMaskConn = (*wrappedConn)(nil) + +type streamDecoder interface { + decodeChunk(in []byte, pending []byte) ([]byte, error) + reset() +} + +type streamReader struct { + reader *bufio.Reader + rawBuf []byte + pending []byte + decode streamDecoder + mu sync.Mutex +} + +func newStreamReader(raw net.Conn, decode streamDecoder) io.Reader { + return &streamReader{ + reader: bufio.NewReaderSize(raw, ioBufferSize), + rawBuf: make([]byte, ioBufferSize), + pending: make([]byte, 0, 4096), + decode: decode, + } +} + +func (r *streamReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if n, ok := drainPending(p, &r.pending); ok { + return n, nil + } + + for len(r.pending) == 0 { + nr, rErr := r.reader.Read(r.rawBuf) + if nr > 0 { + var dErr error + r.pending, dErr = r.decode.decodeChunk(r.rawBuf[:nr], r.pending) + if dErr != nil { + return 0, dErr + } + } + + if rErr != nil { + if rErr == io.EOF { + r.decode.reset() + if len(r.pending) > 0 { + break + } + } + return 0, rErr + } + } + + n, _ := drainPending(p, &r.pending) + return n, nil +} + +type streamWriter struct { + conn net.Conn + encode func([]byte) ([]byte, error) + mu sync.Mutex +} + +func newStreamWriter(raw net.Conn, encode func([]byte) ([]byte, error)) io.Writer { + return &streamWriter{ + conn: raw, + encode: encode, + } +} + +func (w *streamWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + w.mu.Lock() + defer w.mu.Unlock() + + encoded, err := w.encode(p) + if err != nil { + return 0, err + } + if err := writeAll(w.conn, encoded); err != nil { + return 0, err + } + return len(p), nil +} + +type wrappedConn struct { + net.Conn + reader io.Reader + writer io.Writer +} + +type closeWriteConn interface { + CloseWrite() error +} + +func newWrappedConn(raw net.Conn, reader io.Reader, writer io.Writer) net.Conn { + return &wrappedConn{ + Conn: raw, + reader: reader, + writer: writer, + } +} + +func (c *wrappedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func (c *wrappedConn) Write(p []byte) (int, error) { + return c.writer.Write(p) +} + +func (c *wrappedConn) TcpMaskConn() {} + +func (c *wrappedConn) RawConn() net.Conn { + return c.Conn +} + +func (c *wrappedConn) Splice() bool { + // Sudoku transforms the entire stream; bypassing it would disable masking. + return false +} + +func (c *wrappedConn) CloseWrite() error { + if raw, ok := c.Conn.(closeWriteConn); ok { + return raw.CloseWrite() + } + return net.ErrClosed +} + +func NewTCPConn(raw net.Conn, config *Config) (net.Conn, error) { + reader, writer, err := newPureReaderWriter(raw, config) + if err != nil { + return nil, err + } + return newWrappedConn(raw, reader, writer), nil +} + +func newPureReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) { + tables, err := getTables(config) + if err != nil { + return nil, nil, err + } + + pMin, pMax := normalizedPadding(config) + c := newCodec(tables, pMin, pMax) + return newStreamReader(raw, newHintStreamDecoder(tables)), newStreamWriter(raw, c.encode), nil +} + +type hintStreamDecoder struct { + tables []*table + tableIndex int + hintBuf []byte +} + +func newHintStreamDecoder(tables []*table) *hintStreamDecoder { + return &hintStreamDecoder{ + tables: tables, + hintBuf: make([]byte, 0, 4), + } +} + +func (d *hintStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) { + var err error + d.hintBuf, pending, err = decodeBytes(d.tables, &d.tableIndex, in, d.hintBuf, pending) + return pending, err +} + +func (d *hintStreamDecoder) reset() {} + +func drainPending(p []byte, pending *[]byte) (int, bool) { + if len(*pending) == 0 { + return 0, false + } + + n := copy(p, *pending) + if n >= len(*pending) { + *pending = (*pending)[:0] + return n, true + } + + remaining := len(*pending) - n + copy(*pending, (*pending)[n:]) + *pending = (*pending)[:remaining] + return n, true +} + +func writeAll(conn net.Conn, b []byte) error { + for len(b) > 0 { + n, err := conn.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} diff --git a/transport/internet/finalmask/sudoku/conn_tcp_packed.go b/transport/internet/finalmask/sudoku/conn_tcp_packed.go new file mode 100644 index 00000000..fa3c4c86 --- /dev/null +++ b/transport/internet/finalmask/sudoku/conn_tcp_packed.go @@ -0,0 +1,182 @@ +package sudoku + +import ( + "fmt" + "io" + "net" +) + +type packedEncoder struct { + layouts []*byteLayout + codec *codec + groupIndex int +} + +func newPackedEncoder(tables []*table, pMin, pMax int) *packedEncoder { + layouts := make([]*byteLayout, 0, len(tables)) + for _, t := range tables { + layouts = append(layouts, t.layout) + } + if len(layouts) == 0 { + layouts = append(layouts, entropyLayout()) + } + return &packedEncoder{ + layouts: layouts, + codec: newCodec(nil, pMin, pMax), + } +} + +func (e *packedEncoder) encode(p []byte) ([]byte, error) { + out := make([]byte, 0, len(p)*2+8) + var bitBuf uint64 + var bitCount uint8 + + for _, b := range p { + bitBuf = (bitBuf << 8) | uint64(b) + bitCount += 8 + + for bitCount >= 6 { + bitCount -= 6 + layout := e.layouts[e.groupIndex%len(e.layouts)] + group := byte(bitBuf >> bitCount) + out = e.maybePad(out, layout) + out = append(out, layout.encodeGroup(group&0x3f)) + e.groupIndex++ + if bitCount > 0 { + bitBuf &= (uint64(1) << bitCount) - 1 + } else { + bitBuf = 0 + } + } + } + + if bitCount > 0 { + layout := e.layouts[e.groupIndex%len(e.layouts)] + group := byte(bitBuf << (6 - bitCount)) + out = e.maybePad(out, layout) + out = append(out, layout.encodeGroup(group&0x3f)) + e.groupIndex++ + nextLayout := e.layouts[e.groupIndex%len(e.layouts)] + out = append(out, nextLayout.padMarker) + } + + out = e.maybePad(out, e.layouts[e.groupIndex%len(e.layouts)]) + return out, nil +} + +func (e *packedEncoder) maybePad(out []byte, layout *byteLayout) []byte { + if !e.codec.shouldPad() { + return out + } + if len(layout.paddingPool) == 1 { + return append(out, layout.paddingPool[0]) + } + for { + b := layout.paddingPool[e.codec.rng.Intn(len(layout.paddingPool))] + if b != layout.padMarker { + return append(out, b) + } + } +} + +type packedStreamDecoder struct { + layouts []*byteLayout + groupIndex int + bitBuf uint64 + bitCount int +} + +func (d *packedStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) { + var err error + d.bitBuf, d.bitCount, d.groupIndex, pending, err = decodePackedBytes( + d.layouts, + in, + d.bitBuf, + d.bitCount, + d.groupIndex, + pending, + ) + return pending, err +} + +func (d *packedStreamDecoder) reset() { + d.bitBuf = 0 + d.bitCount = 0 +} + +func NewPackedTCPConn(raw net.Conn, config *Config) (net.Conn, error) { + reader, writer, err := newPackedReaderWriter(raw, config) + if err != nil { + return nil, err + } + return newWrappedConn(raw, reader, writer), nil +} + +func newPackedReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) { + tables, err := getTables(config) + if err != nil { + return nil, nil, err + } + + pMin, pMax := normalizedPadding(config) + encoder := newPackedEncoder(tables, pMin, pMax) + decoder := &packedStreamDecoder{ + layouts: tablesToLayouts(tables), + } + return newStreamReader(raw, decoder), newStreamWriter(raw, encoder.encode), nil +} + +func tablesToLayouts(tables []*table) []*byteLayout { + layouts := make([]*byteLayout, 0, len(tables)) + for _, t := range tables { + layouts = append(layouts, t.layout) + } + if len(layouts) == 0 { + layouts = append(layouts, entropyLayout()) + } + return layouts +} + +func decodePackedBytes( + layouts []*byteLayout, + in []byte, + bitBuf uint64, + bitCount int, + groupIndex int, + out []byte, +) (uint64, int, int, []byte, error) { + if len(layouts) == 0 { + return bitBuf, bitCount, groupIndex, out, fmt.Errorf("sudoku layout set missing") + } + for _, b := range in { + layout := layouts[groupIndex%len(layouts)] + if !layout.isHint(b) { + if b == layout.padMarker { + bitBuf = 0 + bitCount = 0 + } + continue + } + + group, ok := layout.decodeGroup(b) + if !ok { + return bitBuf, bitCount, groupIndex, out, fmt.Errorf("invalid packed sudoku byte: %d", b) + } + groupIndex++ + + bitBuf = (bitBuf << 6) | uint64(group) + bitCount += 6 + + for bitCount >= 8 { + bitCount -= 8 + out = append(out, byte(bitBuf>>bitCount)) + if bitCount > 0 { + bitBuf &= (uint64(1) << bitCount) - 1 + } else { + bitBuf = 0 + } + } + } + + return bitBuf, bitCount, groupIndex, out, nil +} diff --git a/transport/internet/finalmask/sudoku/conn_udp.go b/transport/internet/finalmask/sudoku/conn_udp.go new file mode 100644 index 00000000..0a774682 --- /dev/null +++ b/transport/internet/finalmask/sudoku/conn_udp.go @@ -0,0 +1,106 @@ +package sudoku + +import ( + "io" + "net" + "sync" + "time" +) + +type udpConn struct { + conn net.PacketConn + tables []*table + pMin int + pMax int + + readBuf []byte + + readMu sync.Mutex + writeMu sync.Mutex +} + +func NewUDPConn(raw net.PacketConn, config *Config) (net.PacketConn, error) { + tables, err := getTables(config) + if err != nil { + return nil, err + } + + pMin, pMax := normalizedPadding(config) + return &udpConn{ + conn: raw, + tables: tables, + pMin: pMin, + pMax: pMax, + readBuf: make([]byte, 65535), + }, nil +} + +func (c *udpConn) Size() int32 { + return 0 +} + +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + return n, addr, err + } + + decoded := make([]byte, 0, n/4+1) + hints := make([]byte, 0, 4) + tableIndex := 0 + hints, decoded, err = decodeBytes(c.tables, &tableIndex, c.readBuf[:n], hints, decoded) + if err != nil { + return 0, addr, err + } + if len(hints) != 0 { + return 0, addr, io.ErrUnexpectedEOF + } + if len(p) < len(decoded) { + return 0, addr, io.ErrShortBuffer + } + copy(p, decoded) + return len(decoded), addr, nil +} + +func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + // UDP decoding restarts at table 0 for every datagram, so encoding must do the same. + encoded, err := newCodec(c.tables, c.pMin, c.pMax).encode(p) + if err != nil { + return 0, err + } + + nn, err := c.conn.WriteTo(encoded, addr) + if err != nil { + return 0, err + } + if nn != len(encoded) { + return 0, io.ErrShortWrite + } + return len(p), nil +} + +func (c *udpConn) Close() error { + return c.conn.Close() +} + +func (c *udpConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *udpConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *udpConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *udpConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/sudoku/sudoku_test.go b/transport/internet/finalmask/sudoku/sudoku_test.go new file mode 100644 index 00000000..4c22aba1 --- /dev/null +++ b/transport/internet/finalmask/sudoku/sudoku_test.go @@ -0,0 +1,1396 @@ +package sudoku + +import ( + "bytes" + "crypto/ecdh" + "crypto/rand" + cryptotls "crypto/tls" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + stdnet "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "sync" + "syscall" + "testing" + "time" + + "github.com/xtls/xray-core/app/dispatcher" + "github.com/xtls/xray-core/app/log" + "github.com/xtls/xray-core/app/proxyman" + clog "github.com/xtls/xray-core/common/log" + xnet "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/protocol/tls/cert" + "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/uuid" + core "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/proxy/dokodemo" + "github.com/xtls/xray-core/proxy/freedom" + hyproxy "github.com/xtls/xray-core/proxy/hysteria" + hyaccount "github.com/xtls/xray-core/proxy/hysteria/account" + "github.com/xtls/xray-core/proxy/vless" + vin "github.com/xtls/xray-core/proxy/vless/inbound" + vout "github.com/xtls/xray-core/proxy/vless/outbound" + testingtcp "github.com/xtls/xray-core/testing/servers/tcp" + "github.com/xtls/xray-core/transport/internet" + hytransport "github.com/xtls/xray-core/transport/internet/hysteria" + "github.com/xtls/xray-core/transport/internet/reality" + splithttp "github.com/xtls/xray-core/transport/internet/splithttp" + transtcp "github.com/xtls/xray-core/transport/internet/tcp" + xtls "github.com/xtls/xray-core/transport/internet/tls" + "google.golang.org/protobuf/proto" +) + +var ( + e2eBinaryOnce sync.Once + e2eBinaryPath string + e2eBinaryErr error +) + +type trafficMode struct { + name string + config *Config +} + +type protocolCase struct { + name string + transport string + run func(t *testing.T, bin string, mode trafficMode) caseResult +} + +type caseResult struct { + Protocol string + Mode string + TotalBytes int + ASCIIBytes int + ASCIIRatio float64 + AvgHammingOnes float64 + RotationSeen int + RotationExpected int + DecodedUnits int + ClientToServer directionResult + ServerToClient directionResult +} + +type directionResult struct { + RawBytes int + ASCIIBytes int + ASCIIRatio float64 + AvgHammingOnes float64 + RotationSeen int + DecodedUnits int +} + +type tcpRelay struct { + listener stdnet.Listener + target string + + mu sync.Mutex + captures []*tcpCapture + wg sync.WaitGroup + stopCh chan struct{} +} + +type tcpCapture struct { + mu sync.Mutex + c2s []byte + s2c []byte +} + +type udpRelay struct { + conn stdnet.PacketConn + target *stdnet.UDPAddr + clientMu sync.Mutex + client *stdnet.UDPAddr + stopCh chan struct{} + wg sync.WaitGroup + captureMu sync.Mutex + c2s [][]byte + s2c [][]byte +} + +type tlsDecoy struct { + ln stdnet.Listener + done chan struct{} + wg sync.WaitGroup +} + +func TestSudokuE2ETemp(t *testing.T) { + if testing.Short() { + t.Skip("skipping sudoku e2e harness in short mode") + } + + bin := buildE2EBinary(t) + payloadSize := 192 * 1024 + modes := []trafficMode{ + { + name: "prefer_ascii", + config: &Config{ + Password: "sudoku-e2e-shared-secret", + Ascii: "prefer_ascii", + }, + }, + { + name: "prefer_entropy", + config: &Config{ + Password: "sudoku-e2e-shared-secret", + Ascii: "prefer_entropy", + CustomTables: []string{ + "xpxvvpvv", + "vxpvxvvp", + "pxvvxvvp", + "vpxvxvpv", + "xvpvvxpv", + "vvxpxpvv", + }, + }, + }, + } + + cases := []protocolCase{ + {name: "vless-reality", transport: "tcp", run: func(t *testing.T, bin string, mode trafficMode) caseResult { + return runVLESSRealityCase(t, bin, mode, payloadSize) + }}, + {name: "hysteria2", transport: "udp", run: func(t *testing.T, bin string, mode trafficMode) caseResult { + return runHysteria2Case(t, bin, mode, payloadSize) + }}, + {name: "vless-enc", transport: "tcp", run: func(t *testing.T, bin string, mode trafficMode) caseResult { + return runVLesseEncCase(t, bin, mode, payloadSize) + }}, + {name: "vless-xhttp", transport: "tcp", run: func(t *testing.T, bin string, mode trafficMode) caseResult { + return runVLESSXHTTPCase(t, bin, mode, payloadSize) + }}, + } + + results := make([]caseResult, 0, len(cases)*len(modes)) + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for _, mode := range modes { + mode := mode + t.Run(mode.name, func(t *testing.T) { + result := tc.run(t, bin, mode) + if mode.name == "prefer_ascii" && result.ASCIIRatio < 0.97 { + t.Fatalf("%s %s ascii ratio %.4f < 0.97", tc.name, mode.name, result.ASCIIRatio) + } + if mode.name == "prefer_entropy" { + if result.RotationSeen != result.RotationExpected { + t.Fatalf("%s %s saw %d/%d rotation tables", tc.name, mode.name, result.RotationSeen, result.RotationExpected) + } + if diff := result.AvgHammingOnes - 5.0; diff < -0.3 || diff > 0.3 { + t.Fatalf("%s %s average ones %.4f too far from 5", tc.name, mode.name, result.AvgHammingOnes) + } + } + t.Logf( + "%s %s total=%d ascii=%.4f avg_ones=%.4f rotation=%d/%d c2s_ascii=%.4f s2c_ascii=%.4f", + tc.name, + mode.name, + result.TotalBytes, + result.ASCIIRatio, + result.AvgHammingOnes, + result.RotationSeen, + result.RotationExpected, + result.ClientToServer.ASCIIRatio, + result.ServerToClient.ASCIIRatio, + ) + results = append(results, result) + }) + } + }) + } + + for _, result := range results { + t.Logf( + "summary protocol=%s mode=%s bytes=%d ascii=%.4f avg_ones=%.4f rotation=%d/%d decoded=%d", + result.Protocol, + result.Mode, + result.TotalBytes, + result.ASCIIRatio, + result.AvgHammingOnes, + result.RotationSeen, + result.RotationExpected, + result.DecodedUnits, + ) + } +} + +func runVLESSRealityCase(t *testing.T, bin string, mode trafficMode, payloadSize int) caseResult { + backend := startXOREchoServer(t) + defer backend.Close() + + decoyCert, _ := cert.MustGenerate(nil, cert.CommonName("localhost"), cert.DNSNames("localhost")) + decoy := startTLSEchoDecoy(t, decoyCert) + defer decoy.Close() + + serverPort := testingtcp.PickPort() + relayPort := testingtcp.PickPort() + clientPort := testingtcp.PickPort() + + relay := startTCPRelay(t, int(relayPort), fmt.Sprintf("127.0.0.1:%d", serverPort)) + defer relay.Close() + + userID := protocol.NewID(uuid.New()) + realityPriv, realityPub := mustX25519Keypair(t) + shortID := mustDecodeHex(t, "0123456789abcdef") + + serverConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(serverPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + StreamSettings: &internet.StreamConfig{ + ProtocolName: "tcp", + TransportSettings: []*internet.TransportConfig{ + { + ProtocolName: "tcp", + Settings: serial.ToTypedMessage(&transtcp.Config{}), + }, + }, + SecurityType: serial.GetMessageType(&reality.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&reality.Config{ + Show: true, + Dest: fmt.Sprintf("localhost:%d", decoy.Port()), + ServerNames: []string{"localhost"}, + PrivateKey: realityPriv, + ShortIds: [][]byte{shortID}, + Type: "tcp", + }), + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + ProxySettings: serial.ToTypedMessage(&vin.Config{ + Clients: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + {ProxySettings: serial.ToTypedMessage(&freedom.Config{})}, + }, + }) + + clientConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(clientPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: xnet.NewIPOrDomain(backend.Address()), + Port: uint32(backend.Port()), + Networks: []xnet.Network{xnet.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&vout.Config{ + Vnext: &protocol.ServerEndpoint{ + Address: xnet.NewIPOrDomain(xnet.LocalHostIP), + Port: uint32(relayPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + StreamSettings: &internet.StreamConfig{ + ProtocolName: "tcp", + TransportSettings: []*internet.TransportConfig{ + { + ProtocolName: "tcp", + Settings: serial.ToTypedMessage(&transtcp.Config{}), + }, + }, + SecurityType: serial.GetMessageType(&reality.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&reality.Config{ + Show: true, + Fingerprint: "chrome", + ServerName: "localhost", + PublicKey: realityPub, + ShortId: shortID, + SpiderX: "/", + }), + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + }, + }, + }) + + serverCmd, clientCmd := runXrayPair(t, bin, serverConfig, clientConfig) + defer stopCmd(clientCmd) + defer stopCmd(serverCmd) + exerciseTCPClient(t, int(clientPort), payloadSize) + + return analyzeTCPRelay(t, "vless-reality", mode, relay.Snapshots()) +} + +func runHysteria2Case(t *testing.T, bin string, mode trafficMode, payloadSize int) caseResult { + backend := startXOREchoServer(t) + defer backend.Close() + + serverPort := testingtcp.PickPort() + relayPort := testingtcp.PickPort() + clientPort := testingtcp.PickPort() + + relay := startUDPRelay(t, int(relayPort), int(serverPort)) + defer relay.Close() + + ct, ctHash := cert.MustGenerate(nil, cert.CommonName("localhost"), cert.DNSNames("localhost")) + auth := "hy2-auth-secret" + + serverConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(serverPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + StreamSettings: &internet.StreamConfig{ + ProtocolName: "hysteria", + TransportSettings: []*internet.TransportConfig{ + { + ProtocolName: "hysteria", + Settings: serial.ToTypedMessage(&hytransport.Config{ + Version: 2, + Auth: auth, + Congestion: "bbr", + Up: 10 * 1024 * 1024, + Down: 10 * 1024 * 1024, + UdpIdleTimeout: 60, + }), + }, + }, + SecurityType: serial.GetMessageType(&xtls.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&xtls.Config{ + Certificate: []*xtls.Certificate{xtls.ParseCertificate(ct)}, + NextProtocol: []string{"h3"}, + }), + }, + Udpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + ProxySettings: serial.ToTypedMessage(&hyproxy.ServerConfig{ + Users: []*protocol.User{ + { + Account: serial.ToTypedMessage(&hyaccount.Account{Auth: auth}), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + {ProxySettings: serial.ToTypedMessage(&freedom.Config{})}, + }, + }) + + clientConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(clientPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: xnet.NewIPOrDomain(backend.Address()), + Port: uint32(backend.Port()), + Networks: []xnet.Network{xnet.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&hyproxy.ClientConfig{ + Version: 2, + Server: &protocol.ServerEndpoint{ + Address: xnet.NewIPOrDomain(xnet.LocalHostIP), + Port: uint32(relayPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&hyaccount.Account{Auth: auth}), + }, + }, + }), + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + StreamSettings: &internet.StreamConfig{ + ProtocolName: "hysteria", + TransportSettings: []*internet.TransportConfig{ + { + ProtocolName: "hysteria", + Settings: serial.ToTypedMessage(&hytransport.Config{ + Version: 2, + Auth: auth, + Congestion: "bbr", + Up: 10 * 1024 * 1024, + Down: 10 * 1024 * 1024, + UdpIdleTimeout: 60, + }), + }, + }, + SecurityType: serial.GetMessageType(&xtls.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&xtls.Config{ + ServerName: "localhost", + PinnedPeerCertSha256: [][]byte{ctHash[:]}, + NextProtocol: []string{"h3"}, + }), + }, + Udpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + }, + }, + }) + + serverCmd, clientCmd := runXrayPair(t, bin, serverConfig, clientConfig) + defer stopCmd(clientCmd) + defer stopCmd(serverCmd) + if err := exerciseTCPClientErr(t, int(clientPort), payloadSize); err != nil { + c2s, s2c := relay.Snapshots() + t.Fatalf("hy2 traffic failed: %v (udp packets c2s=%d s2c=%d first_c2s=%d first_s2c=%d)", err, len(c2s), len(s2c), firstChunkLen(c2s), firstChunkLen(s2c)) + } + + c2s, s2c := relay.Snapshots() + return analyzeUDPRelay(t, "hysteria2", mode, c2s, s2c) +} + +func runVLesseEncCase(t *testing.T, bin string, mode trafficMode, payloadSize int) caseResult { + backend := startXOREchoServer(t) + defer backend.Close() + + serverPort := testingtcp.PickPort() + relayPort := testingtcp.PickPort() + clientPort := testingtcp.PickPort() + + relay := startTCPRelay(t, int(relayPort), fmt.Sprintf("127.0.0.1:%d", serverPort)) + defer relay.Close() + + userID := protocol.NewID(uuid.New()) + priv, pub := mustX25519Keypair(t) + pubB64 := base64.RawURLEncoding.EncodeToString(pub) + privB64 := base64.RawURLEncoding.EncodeToString(priv) + + serverConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(serverPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + StreamSettings: &internet.StreamConfig{ + ProtocolName: "tcp", + TransportSettings: []*internet.TransportConfig{ + {ProtocolName: "tcp", Settings: serial.ToTypedMessage(&transtcp.Config{})}, + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + ProxySettings: serial.ToTypedMessage(&vin.Config{ + Clients: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + Decryption: privB64, + XorMode: 1, + SecondsFrom: 0, + SecondsTo: 0, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + {ProxySettings: serial.ToTypedMessage(&freedom.Config{})}, + }, + }) + + clientConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(clientPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: xnet.NewIPOrDomain(backend.Address()), + Port: uint32(backend.Port()), + Networks: []xnet.Network{xnet.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&vout.Config{ + Vnext: &protocol.ServerEndpoint{ + Address: xnet.NewIPOrDomain(xnet.LocalHostIP), + Port: uint32(relayPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + Encryption: pubB64, + XorMode: 1, + }), + }, + }, + }), + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + StreamSettings: &internet.StreamConfig{ + ProtocolName: "tcp", + TransportSettings: []*internet.TransportConfig{ + {ProtocolName: "tcp", Settings: serial.ToTypedMessage(&transtcp.Config{})}, + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + }, + }, + }) + + serverCmd, clientCmd := runXrayPair(t, bin, serverConfig, clientConfig) + defer stopCmd(clientCmd) + defer stopCmd(serverCmd) + exerciseTCPClient(t, int(clientPort), payloadSize) + + return analyzeTCPRelay(t, "vless-enc", mode, relay.Snapshots()) +} + +func runVLESSXHTTPCase(t *testing.T, bin string, mode trafficMode, payloadSize int) caseResult { + backend := startXOREchoServer(t) + defer backend.Close() + + serverPort := testingtcp.PickPort() + relayPort := testingtcp.PickPort() + clientPort := testingtcp.PickPort() + + relay := startTCPRelay(t, int(relayPort), fmt.Sprintf("127.0.0.1:%d", serverPort)) + defer relay.Close() + + userID := protocol.NewID(uuid.New()) + xhttpConfig := &splithttp.Config{ + Host: "localhost", + Path: "/sudoku", + Mode: "auto", + } + + serverConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(serverPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + StreamSettings: &internet.StreamConfig{ + ProtocolName: "splithttp", + TransportSettings: []*internet.TransportConfig{ + {ProtocolName: "splithttp", Settings: serial.ToTypedMessage(xhttpConfig)}, + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + ProxySettings: serial.ToTypedMessage(&vin.Config{ + Clients: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + {ProxySettings: serial.ToTypedMessage(&freedom.Config{})}, + }, + }) + + clientConfig := defaultApps(&core.Config{ + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &xnet.PortList{Range: []*xnet.PortRange{xnet.SinglePortRange(clientPort)}}, + Listen: xnet.NewIPOrDomain(xnet.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: xnet.NewIPOrDomain(backend.Address()), + Port: uint32(backend.Port()), + Networks: []xnet.Network{xnet.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&vout.Config{ + Vnext: &protocol.ServerEndpoint{ + Address: xnet.NewIPOrDomain(xnet.LocalHostIP), + Port: uint32(relayPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + StreamSettings: &internet.StreamConfig{ + ProtocolName: "splithttp", + TransportSettings: []*internet.TransportConfig{ + {ProtocolName: "splithttp", Settings: serial.ToTypedMessage(xhttpConfig)}, + }, + Tcpmasks: []*serial.TypedMessage{serial.ToTypedMessage(cloneConfig(mode.config))}, + }, + }), + }, + }, + }) + + serverCmd, clientCmd := runXrayPair(t, bin, serverConfig, clientConfig) + defer stopCmd(clientCmd) + defer stopCmd(serverCmd) + exerciseTCPClient(t, int(clientPort), payloadSize) + + return analyzeTCPRelay(t, "vless-xhttp", mode, relay.Snapshots()) +} + +func analyzeTCPRelay(t *testing.T, protocol string, mode trafficMode, captures []*tcpCapture) caseResult { + tables, err := getTables(mode.config) + if err != nil { + t.Fatal(err) + } + + allC2S := make([][]byte, 0, len(captures)) + allS2C := make([][]byte, 0, len(captures)) + for _, capture := range captures { + c2s, s2c := capture.snapshot() + if len(c2s) > 0 { + allC2S = append(allC2S, c2s) + } + if len(s2c) > 0 { + allS2C = append(allS2C, s2c) + } + } + + c2sMetrics := metricFromBytes(flattenChunks(allC2S)) + s2cMetrics := metricFromBytes(flattenChunks(allS2C)) + + c2sUsed, c2sDecoded, err := analyzePureChunks(tables, allC2S) + if err != nil { + t.Fatalf("%s %s pure decode failed: %v", protocol, mode.name, err) + } + s2cUsed, s2cDecoded, err := analyzePackedChunks(tables, allS2C) + if err != nil { + t.Fatalf("%s %s packed decode failed: %v", protocol, mode.name, err) + } + + allBytes := append(append([]byte{}, flattenChunks(allC2S)...), flattenChunks(allS2C)...) + totalMetrics := metricFromBytes(allBytes) + rotationSeen := len(unionKeys(c2sUsed, s2cUsed)) + + return caseResult{ + Protocol: protocol, + Mode: mode.name, + TotalBytes: len(allBytes), + ASCIIBytes: totalMetrics.asciiBytes, + ASCIIRatio: totalMetrics.asciiRatio, + AvgHammingOnes: totalMetrics.avgOnes, + RotationSeen: rotationSeen, + RotationExpected: expectedRotation(mode.config), + DecodedUnits: c2sDecoded + s2cDecoded, + ClientToServer: directionResult{ + RawBytes: len(flattenChunks(allC2S)), + ASCIIBytes: c2sMetrics.asciiBytes, + ASCIIRatio: c2sMetrics.asciiRatio, + AvgHammingOnes: c2sMetrics.avgOnes, + RotationSeen: len(c2sUsed), + DecodedUnits: c2sDecoded, + }, + ServerToClient: directionResult{ + RawBytes: len(flattenChunks(allS2C)), + ASCIIBytes: s2cMetrics.asciiBytes, + ASCIIRatio: s2cMetrics.asciiRatio, + AvgHammingOnes: s2cMetrics.avgOnes, + RotationSeen: len(s2cUsed), + DecodedUnits: s2cDecoded, + }, + } +} + +func analyzeUDPRelay(t *testing.T, protocol string, mode trafficMode, c2s [][]byte, s2c [][]byte) caseResult { + tables, err := getTables(mode.config) + if err != nil { + t.Fatal(err) + } + + c2sMetrics := metricFromBytes(flattenChunks(c2s)) + s2cMetrics := metricFromBytes(flattenChunks(s2c)) + + c2sUsed, c2sDecoded, err := analyzePureChunks(tables, c2s) + if err != nil { + t.Fatalf("%s %s udp c2s decode failed: %v", protocol, mode.name, err) + } + s2cUsed, s2cDecoded, err := analyzePureChunks(tables, s2c) + if err != nil { + t.Fatalf("%s %s udp s2c decode failed: %v", protocol, mode.name, err) + } + + allBytes := append(append([]byte{}, flattenChunks(c2s)...), flattenChunks(s2c)...) + totalMetrics := metricFromBytes(allBytes) + rotationSeen := len(unionKeys(c2sUsed, s2cUsed)) + + return caseResult{ + Protocol: protocol, + Mode: mode.name, + TotalBytes: len(allBytes), + ASCIIBytes: totalMetrics.asciiBytes, + ASCIIRatio: totalMetrics.asciiRatio, + AvgHammingOnes: totalMetrics.avgOnes, + RotationSeen: rotationSeen, + RotationExpected: expectedRotation(mode.config), + DecodedUnits: c2sDecoded + s2cDecoded, + ClientToServer: directionResult{ + RawBytes: len(flattenChunks(c2s)), + ASCIIBytes: c2sMetrics.asciiBytes, + ASCIIRatio: c2sMetrics.asciiRatio, + AvgHammingOnes: c2sMetrics.avgOnes, + RotationSeen: len(c2sUsed), + DecodedUnits: c2sDecoded, + }, + ServerToClient: directionResult{ + RawBytes: len(flattenChunks(s2c)), + ASCIIBytes: s2cMetrics.asciiBytes, + ASCIIRatio: s2cMetrics.asciiRatio, + AvgHammingOnes: s2cMetrics.avgOnes, + RotationSeen: len(s2cUsed), + DecodedUnits: s2cDecoded, + }, + } +} + +type byteMetrics struct { + asciiBytes int + asciiRatio float64 + avgOnes float64 +} + +func metricFromBytes(b []byte) byteMetrics { + if len(b) == 0 { + return byteMetrics{} + } + var ascii, ones int + for _, v := range b { + if v < 0x80 { + ascii++ + } + ones += bitsInByte(v) + } + return byteMetrics{ + asciiBytes: ascii, + asciiRatio: float64(ascii) / float64(len(b)), + avgOnes: float64(ones) / float64(len(b)), + } +} + +func bitsInByte(b byte) int { + n := 0 + for b != 0 { + n += int(b & 1) + b >>= 1 + } + return n +} + +func analyzePureChunks(tables []*table, chunks [][]byte) (map[int]int, int, error) { + if len(tables) == 0 { + return nil, 0, fmt.Errorf("no sudoku tables") + } + used := make(map[int]int) + decoded := 0 + for _, chunk := range chunks { + hintBuf := make([]byte, 0, 4) + tableIndex := 0 + for _, b := range chunk { + t := tables[tableIndex%len(tables)] + if !t.layout.isHint(b) { + continue + } + hintBuf = append(hintBuf, b) + if len(hintBuf) < 4 { + continue + } + keyBytes := sort4([4]byte{hintBuf[0], hintBuf[1], hintBuf[2], hintBuf[3]}) + key := packKey(keyBytes) + if _, ok := t.decode[key]; !ok { + return nil, 0, fmt.Errorf("invalid pure tuple at table %d", tableIndex%len(tables)) + } + used[tableIndex%len(tables)]++ + decoded++ + tableIndex++ + hintBuf = hintBuf[:0] + } + if len(hintBuf) != 0 { + return nil, 0, fmt.Errorf("leftover pure hints") + } + } + return used, decoded, nil +} + +func analyzePackedChunks(tables []*table, chunks [][]byte) (map[int]int, int, error) { + layouts := tablesToLayouts(tables) + if len(layouts) == 0 { + return nil, 0, fmt.Errorf("no sudoku layouts") + } + used := make(map[int]int) + decoded := 0 + for _, chunk := range chunks { + var bitBuf uint64 + var bitCount int + groupIndex := 0 + for _, b := range chunk { + layout := layouts[groupIndex%len(layouts)] + if !layout.isHint(b) { + if b == layout.padMarker { + bitBuf = 0 + bitCount = 0 + } + continue + } + group, ok := layout.decodeGroup(b) + if !ok { + return nil, 0, fmt.Errorf("invalid packed byte %d", b) + } + used[groupIndex%len(layouts)]++ + groupIndex++ + bitBuf = (bitBuf << 6) | uint64(group) + bitCount += 6 + for bitCount >= 8 { + bitCount -= 8 + decoded++ + if bitCount > 0 { + bitBuf &= (uint64(1) << bitCount) - 1 + } else { + bitBuf = 0 + } + } + } + } + return used, decoded, nil +} + +func expectedRotation(cfg *Config) int { + tables, err := getTables(cfg) + if err != nil { + return 0 + } + return len(tables) +} + +func unionKeys(a, b map[int]int) map[int]struct{} { + out := make(map[int]struct{}, len(a)+len(b)) + for k := range a { + out[k] = struct{}{} + } + for k := range b { + out[k] = struct{}{} + } + return out +} + +func flattenChunks(chunks [][]byte) []byte { + total := 0 + for _, chunk := range chunks { + total += len(chunk) + } + out := make([]byte, 0, total) + for _, chunk := range chunks { + out = append(out, chunk...) + } + return out +} + +func cloneConfig(cfg *Config) *Config { + if cfg == nil { + return nil + } + out := proto.Clone(cfg).(*Config) + return out +} + +func defaultApps(cfg *core.Config) *core.Config { + cfg.App = append(cfg.App, + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Warning, + ErrorLogType: log.LogType_Console, + }), + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.InboundConfig{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + ) + return cfg +} + +func buildE2EBinary(t *testing.T) string { + t.Helper() + e2eBinaryOnce.Do(func() { + tempDir, err := os.MkdirTemp("", "xray-sudoku-e2e-*") + if err != nil { + e2eBinaryErr = err + return + } + e2eBinaryPath = filepath.Join(tempDir, "xray.test") + cmd := exec.Command("go", "build", "-o", e2eBinaryPath, "./main") + cmd.Dir = repoRoot(t) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + e2eBinaryErr = cmd.Run() + }) + if e2eBinaryErr != nil { + t.Fatal(e2eBinaryErr) + } + return e2eBinaryPath +} + +func repoRoot(t *testing.T) string { + t.Helper() + dir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("failed to locate repo root") + } + dir = parent + } +} + +func runXrayPair(t *testing.T, bin string, serverCfg, clientCfg *core.Config) (*exec.Cmd, *exec.Cmd) { + t.Helper() + serverCmd := runXray(t, bin, serverCfg) + + time.Sleep(500 * time.Millisecond) + + clientCmd := runXray(t, bin, clientCfg) + + time.Sleep(1500 * time.Millisecond) + return serverCmd, clientCmd +} + +func runXray(t *testing.T, bin string, cfg *core.Config) *exec.Cmd { + t.Helper() + cfgBytes, err := proto.Marshal(cfg) + if err != nil { + t.Fatal(err) + } + cmd := exec.Command(bin, "-config=stdin:", "-format=pb") + cmd.Stdin = bytes.NewReader(cfgBytes) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + return cmd +} + +func stopCmd(cmd *exec.Cmd) { + if cmd == nil || cmd.Process == nil { + return + } + _ = cmd.Process.Signal(syscall.SIGTERM) + done := make(chan struct{}) + go func() { + _, _ = cmd.Process.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(3 * time.Second): + _ = cmd.Process.Kill() + <-done + } +} + +func startTCPRelay(t *testing.T, listenPort int, target string) *tcpRelay { + t.Helper() + ln, err := stdnet.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + t.Fatal(err) + } + r := &tcpRelay{ + listener: ln, + target: target, + stopCh: make(chan struct{}), + } + r.wg.Add(1) + go func() { + defer r.wg.Done() + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-r.stopCh: + return + default: + } + return + } + targetConn, err := stdnet.Dial("tcp", target) + if err != nil { + _ = conn.Close() + continue + } + capture := &tcpCapture{} + r.mu.Lock() + r.captures = append(r.captures, capture) + r.mu.Unlock() + r.wg.Add(1) + go func(client, server stdnet.Conn, cap *tcpCapture) { + defer r.wg.Done() + defer client.Close() + defer server.Close() + var inner sync.WaitGroup + inner.Add(2) + go func() { + defer inner.Done() + _, _ = io.Copy(server, io.TeeReader(client, &captureWriter{capture: cap, dir: "c2s"})) + if tcp, ok := server.(*stdnet.TCPConn); ok { + _ = tcp.CloseWrite() + } + }() + go func() { + defer inner.Done() + _, _ = io.Copy(client, io.TeeReader(server, &captureWriter{capture: cap, dir: "s2c"})) + if tcp, ok := client.(*stdnet.TCPConn); ok { + _ = tcp.CloseWrite() + } + }() + inner.Wait() + }(conn, targetConn, capture) + } + }() + return r +} + +func (r *tcpRelay) Close() { + close(r.stopCh) + _ = r.listener.Close() + r.wg.Wait() +} + +func (r *tcpRelay) Snapshots() []*tcpCapture { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]*tcpCapture, 0, len(r.captures)) + for _, capture := range r.captures { + out = append(out, capture) + } + return out +} + +func (c *tcpCapture) snapshot() ([]byte, []byte) { + c.mu.Lock() + defer c.mu.Unlock() + return append([]byte{}, c.c2s...), append([]byte{}, c.s2c...) +} + +type captureWriter struct { + capture *tcpCapture + dir string +} + +func (w *captureWriter) Write(p []byte) (int, error) { + w.capture.mu.Lock() + defer w.capture.mu.Unlock() + if w.dir == "c2s" { + w.capture.c2s = append(w.capture.c2s, p...) + } else { + w.capture.s2c = append(w.capture.s2c, p...) + } + return len(p), nil +} + +func startUDPRelay(t *testing.T, listenPort, targetPort int) *udpRelay { + t.Helper() + conn, err := stdnet.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + t.Fatal(err) + } + targetAddr := &stdnet.UDPAddr{IP: stdnet.IPv4(127, 0, 0, 1), Port: targetPort} + r := &udpRelay{ + conn: conn, + target: targetAddr, + stopCh: make(chan struct{}), + } + r.wg.Add(1) + go func() { + defer r.wg.Done() + buf := make([]byte, 64*1024) + for { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + select { + case <-r.stopCh: + return + default: + } + return + } + payload := append([]byte{}, buf[:n]...) + udpAddr := addr.(*stdnet.UDPAddr) + if udpAddr.IP.Equal(r.target.IP) && udpAddr.Port == r.target.Port { + r.captureMu.Lock() + r.s2c = append(r.s2c, payload) + r.captureMu.Unlock() + r.clientMu.Lock() + client := r.client + r.clientMu.Unlock() + if client != nil { + _, _ = conn.WriteTo(payload, client) + } + continue + } + r.clientMu.Lock() + r.client = udpAddr + r.clientMu.Unlock() + r.captureMu.Lock() + r.c2s = append(r.c2s, payload) + r.captureMu.Unlock() + _, _ = conn.WriteTo(payload, r.target) + } + }() + return r +} + +func (r *udpRelay) Close() { + close(r.stopCh) + _ = r.conn.Close() + r.wg.Wait() +} + +func (r *udpRelay) Snapshots() ([][]byte, [][]byte) { + r.captureMu.Lock() + defer r.captureMu.Unlock() + c2s := make([][]byte, 0, len(r.c2s)) + s2c := make([][]byte, 0, len(r.s2c)) + for _, packet := range r.c2s { + c2s = append(c2s, append([]byte{}, packet...)) + } + for _, packet := range r.s2c { + s2c = append(s2c, append([]byte{}, packet...)) + } + return c2s, s2c +} + +type xorEchoServer struct { + ln stdnet.Listener + wg sync.WaitGroup +} + +func startXOREchoServer(t *testing.T) *xorEchoServer { + t.Helper() + ln, err := stdnet.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + s := &xorEchoServer{ln: ln} + s.wg.Add(1) + go func() { + defer s.wg.Done() + for { + conn, err := ln.Accept() + if err != nil { + return + } + s.wg.Add(1) + go func(c stdnet.Conn) { + defer s.wg.Done() + defer c.Close() + buf := make([]byte, 4096) + for { + n, err := c.Read(buf) + if err != nil { + return + } + for i := 0; i < n; i++ { + buf[i] ^= 'c' + } + if _, err := c.Write(buf[:n]); err != nil { + return + } + for i := 0; i < n; i++ { + buf[i] ^= 'c' + } + } + }(conn) + } + }() + return s +} + +func (s *xorEchoServer) Address() xnet.Address { + return xnet.IPAddress(s.ln.Addr().(*stdnet.TCPAddr).IP) +} + +func (s *xorEchoServer) Port() xnet.Port { + return xnet.Port(s.ln.Addr().(*stdnet.TCPAddr).Port) +} + +func (s *xorEchoServer) Close() { + _ = s.ln.Close() + s.wg.Wait() +} + +func startTLSEchoDecoy(t *testing.T, c *cert.Certificate) *tlsDecoy { + t.Helper() + certPEM, keyPEM := c.ToPEM() + keyPair, err := cryptotls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatal(err) + } + config := &cryptotls.Config{ + Certificates: []cryptotls.Certificate{keyPair}, + } + ln, err := stdnet.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + tlsLn := cryptotls.NewListener(ln, config) + d := &tlsDecoy{ + ln: tlsLn, + done: make(chan struct{}), + } + d.wg.Add(1) + go func() { + defer d.wg.Done() + for { + conn, err := tlsLn.Accept() + if err != nil { + return + } + d.wg.Add(1) + go func(c stdnet.Conn) { + defer d.wg.Done() + defer c.Close() + buf := make([]byte, 2048) + _, _ = c.Read(buf) + _, _ = c.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")) + }(conn) + } + }() + return d +} + +func (d *tlsDecoy) Port() int { + return d.ln.Addr().(*stdnet.TCPAddr).Port +} + +func (d *tlsDecoy) Close() { + _ = d.ln.Close() + d.wg.Wait() +} + +func exerciseTCPClient(t *testing.T, port int, payloadSize int) { + t.Helper() + if err := exerciseTCPClientErr(t, port, payloadSize); err != nil { + t.Fatal(err) + } +} + +func exerciseTCPClientErr(t *testing.T, port int, payloadSize int) error { + conn := waitTCPConn(t, port, 10*time.Second) + defer conn.Close() + payload := make([]byte, payloadSize) + if _, err := rand.Read(payload); err != nil { + return err + } + offset := 0 + for offset < len(payload) { + chunk := 1024 + if remain := len(payload) - offset; remain < chunk { + chunk = remain + } + part := payload[offset : offset+chunk] + if _, err := conn.Write(part); err != nil { + return err + } + resp := make([]byte, chunk) + if _, err := io.ReadFull(conn, resp); err != nil { + return err + } + for i := range part { + if resp[i] != (part[i] ^ 'c') { + return fmt.Errorf("unexpected xor response at offset %d", offset+i) + } + } + offset += chunk + } + return nil +} + +func firstChunkLen(chunks [][]byte) int { + if len(chunks) == 0 { + return 0 + } + return len(chunks[0]) +} + +func waitTCPConn(t *testing.T, port int, timeout time.Duration) stdnet.Conn { + t.Helper() + deadline := time.Now().Add(timeout) + for { + conn, err := stdnet.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond) + if err == nil { + return conn + } + if time.Now().After(deadline) { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + } +} + +func mustX25519Keypair(t *testing.T) ([]byte, []byte) { + t.Helper() + priv, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + return priv.Bytes(), priv.PublicKey().Bytes() +} + +func mustDecodeHex(t *testing.T, s string) []byte { + t.Helper() + out := make([]byte, len(s)/2) + if _, err := hex.Decode(out, []byte(s)); err != nil { + t.Fatal(err) + } + return out +} + +func init() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-ch + os.Exit(130) + }() +} diff --git a/transport/internet/finalmask/sudoku/table.go b/transport/internet/finalmask/sudoku/table.go new file mode 100644 index 00000000..6396d864 --- /dev/null +++ b/transport/internet/finalmask/sudoku/table.go @@ -0,0 +1,580 @@ +package sudoku + +import ( + crypto_rand "crypto/rand" + "crypto/sha256" + "encoding/binary" + "fmt" + "math/bits" + "math/rand" + "sort" + "strings" + "sync" + "time" +) + +type table struct { + encode [256][][4]byte + decode map[uint32]byte + layout *byteLayout +} + +type tableCacheKey struct { + password string + ascii string + customTable string +} + +var ( + tableCache sync.Map + tableSetCache sync.Map + + basePatternsOnce sync.Once + basePatterns [][][4]byte + basePatternsErr error +) + +type byteLayout struct { + hintMask byte + hintValue byte + padMarker byte + paddingPool []byte + encodeHint func(group byte) byte + encodeGroup func(group byte) byte + decodeGroup func(b byte) (byte, bool) +} + +func (l *byteLayout) isHint(b byte) bool { + if (b & l.hintMask) == l.hintValue { + return true + } + // ASCII layout maps 0x7f to '\n' to avoid DEL on the wire. + return l.hintMask == 0x40 && b == '\n' +} + +func getTable(config *Config) (*table, error) { + tables, err := getTables(config) + if err != nil { + return nil, err + } + if len(tables) == 0 { + return nil, fmt.Errorf("empty sudoku table set") + } + return tables[0], nil +} + +func getTables(config *Config) ([]*table, error) { + if config == nil { + return nil, fmt.Errorf("nil sudoku config") + } + + mode, err := normalizeASCII(config.GetAscii()) + if err != nil { + return nil, err + } + + patterns, err := normalizedCustomPatterns(config, mode) + if err != nil { + return nil, err + } + + cacheKey := tableCacheKey{ + password: config.GetPassword(), + ascii: mode, + customTable: strings.Join(patterns, "\x00"), + } + if cached, ok := tableSetCache.Load(cacheKey); ok { + return cached.([]*table), nil + } + + tables := make([]*table, 0, len(patterns)) + for _, pattern := range patterns { + layout, err := resolveLayout(mode, pattern) + if err != nil { + return nil, err + } + t, err := buildTable(config.GetPassword(), layout) + if err != nil { + return nil, err + } + tables = append(tables, t) + } + + actual, _ := tableSetCache.LoadOrStore(cacheKey, tables) + return actual.([]*table), nil +} + +func normalizedCustomPatterns(config *Config, mode string) ([]string, error) { + if config == nil { + return []string{""}, nil + } + if mode == "prefer_ascii" { + return []string{""}, nil + } + + rawPatterns := config.GetCustomTables() + if len(rawPatterns) == 0 { + rawPatterns = []string{config.GetCustomTable()} + } + + patterns := make([]string, 0, len(rawPatterns)) + seen := make(map[string]struct{}, len(rawPatterns)) + for _, raw := range rawPatterns { + pattern := strings.TrimSpace(raw) + if pattern != "" { + var err error + pattern, err = normalizeCustomTable(pattern) + if err != nil { + return nil, err + } + } + if _, ok := seen[pattern]; ok { + continue + } + seen[pattern] = struct{}{} + patterns = append(patterns, pattern) + } + + if len(patterns) == 0 { + return []string{""}, nil + } + + return patterns, nil +} + +func normalizedPadding(config *Config) (int, int) { + if config == nil { + return 0, 0 + } + + pMin := int(config.GetPaddingMin()) + pMax := int(config.GetPaddingMax()) + + if pMin > 100 { + pMin = 100 + } + if pMax > 100 { + pMax = 100 + } + if pMax < pMin { + pMax = pMin + } + return pMin, pMax +} + +func normalizeASCII(mode string) (string, error) { + switch strings.ToLower(strings.TrimSpace(mode)) { + case "", "entropy", "prefer_entropy": + return "prefer_entropy", nil + case "ascii", "prefer_ascii": + return "prefer_ascii", nil + default: + return "", fmt.Errorf("invalid sudoku ascii mode: %s", mode) + } +} + +func normalizeCustomTable(pattern string) (string, error) { + cleaned := strings.ToLower(strings.TrimSpace(pattern)) + cleaned = strings.ReplaceAll(cleaned, " ", "") + if len(cleaned) != 8 { + return "", fmt.Errorf("customTable must be 8 chars, got %d", len(cleaned)) + } + + var xCount, pCount, vCount int + for _, ch := range cleaned { + switch ch { + case 'x': + xCount++ + case 'p': + pCount++ + case 'v': + vCount++ + default: + return "", fmt.Errorf("customTable has invalid char %q", ch) + } + } + if xCount != 2 || pCount != 2 || vCount != 4 { + return "", fmt.Errorf("customTable must contain exactly 2 x, 2 p and 4 v") + } + return cleaned, nil +} + +func resolveLayout(mode, customTable string) (*byteLayout, error) { + if mode == "prefer_ascii" { + return asciiLayout(), nil + } + + if customTable != "" { + return customLayout(customTable) + } + return entropyLayout(), nil +} + +func asciiLayout() *byteLayout { + padding := make([]byte, 0, 32) + for i := 0; i < 32; i++ { + padding = append(padding, byte(0x20+i)) + } + + encodeGroup := func(group byte) byte { + b := byte(0x40 | (group & 0x3f)) + if b == 0x7f { + return '\n' + } + return b + } + + return &byteLayout{ + hintMask: 0x40, + hintValue: 0x40, + padMarker: 0x3f, + paddingPool: padding, + encodeHint: encodeGroup, + encodeGroup: encodeGroup, + decodeGroup: func(b byte) (byte, bool) { + if b == '\n' { + return 0x3f, true + } + if (b & 0x40) == 0 { + return 0, false + } + return b & 0x3f, true + }, + } +} + +func entropyLayout() *byteLayout { + padding := make([]byte, 0, 16) + for i := 0; i < 8; i++ { + padding = append(padding, byte(0x80+i), byte(0x10+i)) + } + + encodeGroup := func(group byte) byte { + v := group & 0x3f + return ((v & 0x30) << 1) | (v & 0x0f) + } + + return &byteLayout{ + hintMask: 0x90, + hintValue: 0x00, + padMarker: 0x80, + paddingPool: padding, + encodeHint: encodeGroup, + encodeGroup: encodeGroup, + decodeGroup: func(b byte) (byte, bool) { + if (b & 0x90) != 0 { + return 0, false + } + return ((b >> 1) & 0x30) | (b & 0x0f), true + }, + } +} + +func customLayout(pattern string) (*byteLayout, error) { + pattern, err := normalizeCustomTable(pattern) + if err != nil { + return nil, err + } + + var xBits, pBits, vBits []uint8 + for i, c := range pattern { + bit := uint8(7 - i) + switch c { + case 'x': + xBits = append(xBits, bit) + case 'p': + pBits = append(pBits, bit) + case 'v': + vBits = append(vBits, bit) + } + } + + xMask := byte(0) + for _, bit := range xBits { + xMask |= 1 << bit + } + + encodeGroupWithDropX := func(group byte, dropX int) byte { + out := xMask + if dropX >= 0 { + out &^= 1 << xBits[dropX] + } + + val := (group >> 4) & 0x03 + pos := group & 0x0f + + if (val & 0x02) != 0 { + out |= 1 << pBits[0] + } + if (val & 0x01) != 0 { + out |= 1 << pBits[1] + } + for i, bit := range vBits { + if (pos>>(3-uint8(i)))&0x01 == 1 { + out |= 1 << bit + } + } + + return out + } + + paddingSet := make(map[byte]struct{}, 64) + padding := make([]byte, 0, 64) + for drop := range xBits { + for val := byte(0); val < 4; val++ { + for pos := byte(0); pos < 16; pos++ { + group := (val << 4) | pos + b := encodeGroupWithDropX(group, drop) + if bits.OnesCount8(b) >= 5 { + if _, exists := paddingSet[b]; !exists { + paddingSet[b] = struct{}{} + padding = append(padding, b) + } + } + } + } + } + sort.Slice(padding, func(i, j int) bool { return padding[i] < padding[j] }) + if len(padding) == 0 { + return nil, fmt.Errorf("customTable produced empty padding pool") + } + + decodeGroup := func(b byte) (byte, bool) { + if (b & xMask) != xMask { + return 0, false + } + + var val, pos byte + if b&(1< in[1] { + in[0], in[1] = in[1], in[0] + } + if in[2] > in[3] { + in[2], in[3] = in[3], in[2] + } + if in[0] > in[2] { + in[0], in[2] = in[2], in[0] + } + if in[1] > in[3] { + in[1], in[3] = in[3], in[1] + } + if in[1] > in[2] { + in[1], in[2] = in[2], in[1] + } + return in +} + +func newSeededRand() *rand.Rand { + seed := time.Now().UnixNano() + var seedBytes [8]byte + if _, err := crypto_rand.Read(seedBytes[:]); err == nil { + seed = int64(binary.BigEndian.Uint64(seedBytes[:])) + } + return rand.New(rand.NewSource(seed)) +} diff --git a/transport/internet/finalmask/udp_test.go b/transport/internet/finalmask/udp_test.go index 49cdd923..f2f18f2a 100644 --- a/transport/internet/finalmask/udp_test.go +++ b/transport/internet/finalmask/udp_test.go @@ -2,10 +2,13 @@ package finalmask_test import ( "bytes" + "io" "net" + "sync/atomic" "testing" "time" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" "github.com/xtls/xray-core/transport/internet/finalmask/header/dns" @@ -16,6 +19,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/aes128gcm" "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" "github.com/xtls/xray-core/transport/internet/finalmask/salamander" + "github.com/xtls/xray-core/transport/internet/finalmask/sudoku" ) func mustSendRecv( @@ -49,39 +53,93 @@ func mustSendRecv( } type layerMask struct { - name string - mask finalmask.Udpmask + name string + mask finalmask.Udpmask + layers int +} + +type countingConn struct { + net.Conn + written atomic.Int64 +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + c.written.Add(int64(n)) + return n, err +} + +func (c *countingConn) Written() int64 { + return c.written.Load() } func TestPacketConnReadWrite(t *testing.T) { cases := []layerMask{ { - name: "aes128gcm", - mask: &aes128gcm.Config{Password: "123"}, + name: "aes128gcm", + mask: &aes128gcm.Config{Password: "123"}, + layers: 2, }, { - name: "original", - mask: &original.Config{}, + name: "original", + mask: &original.Config{}, + layers: 2, }, { - name: "dns", - mask: &dns.Config{Domain: "www.baidu.com"}, + name: "dns", + mask: &dns.Config{Domain: "www.baidu.com"}, + layers: 2, }, { - name: "srtp", - mask: &srtp.Config{}, + name: "srtp", + mask: &srtp.Config{}, + layers: 2, }, { - name: "utp", - mask: &utp.Config{}, + name: "utp", + mask: &utp.Config{}, + layers: 2, }, { - name: "wechat", - mask: &wechat.Config{}, + name: "wechat", + mask: &wechat.Config{}, + layers: 2, }, { - name: "wireguard", - mask: &wireguard.Config{}, + name: "wireguard", + mask: &wireguard.Config{}, + layers: 2, + }, + { + name: "salamander", + mask: &salamander.Config{Password: "1234"}, + layers: 2, + }, + { + name: "sudoku-prefer-ascii", + mask: &sudoku.Config{ + Password: "sudoku-mask", + Ascii: "prefer_ascii", + }, + layers: 1, + }, + { + name: "sudoku-custom-table", + mask: &sudoku.Config{ + Password: "sudoku-mask", + Ascii: "prefer_entropy", + CustomTable: "xpxvvpvv", + }, + layers: 1, + }, + { + name: "sudoku-custom-tables", + mask: &sudoku.Config{ + Password: "sudoku-mask", + Ascii: "prefer_entropy", + CustomTables: []string{"xpxvvpvv", "vxpvxvvp"}, + }, + layers: 1, }, { name: "custom", @@ -103,18 +161,27 @@ func TestPacketConnReadWrite(t *testing.T) { }, }, }, + layers: 1, }, { - name: "salamander", - mask: &salamander.Config{Password: "1234"}, + name: "salamander-single", + mask: &salamander.Config{Password: "1234"}, + layers: 1, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { mask := c.mask - - maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{mask, mask}) + layers := c.layers + if layers <= 0 { + layers = 1 + } + masks := make([]finalmask.Udpmask, 0, layers) + for i := 0; i < layers; i++ { + masks = append(masks, mask) + } + maskManager := finalmask.NewUdpmaskManager(masks) client, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { @@ -147,3 +214,419 @@ func TestPacketConnReadWrite(t *testing.T) { }) } } + +func TestSudokuBDD(t *testing.T) { + t.Run("GivenSudokuTCPMask_WhenRoundTripWithAsciiPreference_ThenPayloadMatches", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-tcp", + Ascii: "prefer_ascii", + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := cfg.WrapConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + send := bytes.Repeat([]byte("client->server"), 1024) + recv := make([]byte, len(send)) + + writeErr := make(chan error, 1) + go func() { + _, wErr := clientConn.Write(send) + writeErr <- wErr + }() + + if _, err := io.ReadFull(serverConn, recv); err != nil { + t.Fatal(err) + } + if err := <-writeErr; err != nil { + t.Fatal(err) + } + if !bytes.Equal(send, recv) { + t.Fatal("tcp sudoku payload mismatch") + } + }) + + t.Run("GivenSudokuTCPMask_WhenRoundTrip_ThenBothDirectionsMatch", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-packed", + Ascii: "prefer_ascii", + PaddingMin: 0, + PaddingMax: 0, + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := cfg.WrapConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + clientToServer := bytes.Repeat([]byte("client-packed->server"), 257) + serverToClient := bytes.Repeat([]byte("server-packed->client"), 263) + + c2sRecv := make([]byte, len(clientToServer)) + c2sErr := make(chan error, 1) + go func() { + _, err := clientConn.Write(clientToServer) + c2sErr <- err + }() + if _, err := io.ReadFull(serverConn, c2sRecv); err != nil { + t.Fatal(err) + } + if err := <-c2sErr; err != nil { + t.Fatal(err) + } + if !bytes.Equal(clientToServer, c2sRecv) { + t.Fatal("tcp client->server payload mismatch") + } + + s2cRecv := make([]byte, len(serverToClient)) + s2cErr := make(chan error, 1) + go func() { + _, err := serverConn.Write(serverToClient) + s2cErr <- err + }() + if _, err := io.ReadFull(clientConn, s2cRecv); err != nil { + t.Fatal(err) + } + if err := <-s2cErr; err != nil { + t.Fatal(err) + } + if !bytes.Equal(serverToClient, s2cRecv) { + t.Fatal("tcp server->client payload mismatch") + } + }) + + t.Run("GivenSudokuTCPMask_WhenServerWritesDownlink_ThenWireBytesAreReduced", func(t *testing.T) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 192) // 3072 bytes, divisible by 3. + + countWireBytes := func(wrapServer func(net.Conn, *sudoku.Config) (net.Conn, error), cfg *sudoku.Config) int64 { + t.Helper() + + clientRaw, serverRaw := net.Pipe() + watchedServerRaw := &countingConn{Conn: serverRaw} + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := wrapServer(watchedServerRaw, cfg) + if err != nil { + t.Fatal(err) + } + + readErr := make(chan error, 1) + go func() { + _, err := io.CopyN(io.Discard, clientConn, int64(len(payload))) + readErr <- err + }() + + if _, err := serverConn.Write(payload); err != nil { + t.Fatal(err) + } + if err := <-readErr; err != nil { + t.Fatal(err) + } + + _ = clientConn.Close() + _ = serverConn.Close() + return watchedServerRaw.Written() + } + + pureUplinkPackedDownlink := &sudoku.Config{ + Password: "sudoku-bandwidth", + Ascii: "prefer_entropy", + PaddingMin: 0, + PaddingMax: 0, + } + packedDownlinkBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) { + return cfg.WrapConnServer(raw) + }, pureUplinkPackedDownlink) + legacyPureBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) { + return sudoku.NewTCPConn(raw, cfg) + }, pureUplinkPackedDownlink) + + if packedDownlinkBytes >= legacyPureBytes { + t.Fatalf("expected default packed downlink bytes < legacy pure bytes, got packed=%d pure=%d", packedDownlinkBytes, legacyPureBytes) + } + }) + + t.Run("GivenSudokuMultiTableTCPMask_WhenRoundTrip_ThenPayloadMatches", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-multi-tcp", + Ascii: "prefer_entropy", + CustomTables: []string{"xpxvvpvv", "vxpvxvvp"}, + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := cfg.WrapConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + send := bytes.Repeat([]byte("rotate-table"), 513) + recv := make([]byte, len(send)) + + writeErr := make(chan error, 1) + go func() { + _, wErr := clientConn.Write(send) + writeErr <- wErr + }() + + if _, err := io.ReadFull(serverConn, recv); err != nil { + t.Fatal(err) + } + if err := <-writeErr; err != nil { + t.Fatal(err) + } + if !bytes.Equal(send, recv) { + t.Fatal("multi-table tcp sudoku payload mismatch") + } + }) + + t.Run("GivenSudokuMultiTableTCPMask_WhenPackedDownlink_ThenPayloadMatches", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-multi-packed", + Ascii: "prefer_entropy", + CustomTables: []string{"xpxvvpvv", "vxpvxvvp"}, + PaddingMin: 0, + PaddingMax: 0, + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := cfg.WrapConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + send := bytes.Repeat([]byte("packed-rotate"), 257) + recv := make([]byte, len(send)) + + writeErr := make(chan error, 1) + go func() { + _, wErr := clientConn.Write(send) + writeErr <- wErr + }() + + if _, err := io.ReadFull(serverConn, recv); err != nil { + t.Fatal(err) + } + if err := <-writeErr; err != nil { + t.Fatal(err) + } + if !bytes.Equal(send, recv) { + t.Fatal("multi-table tcp sudoku payload mismatch") + } + }) + + t.Run("GivenSudokuUDPMask_WhenNotInnermost_ThenWrapFails", func(t *testing.T) { + cfg := &sudoku.Config{Password: "sudoku-udp"} + raw, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer raw.Close() + + if _, err := cfg.WrapPacketConnClient(raw, 0, 1); err == nil { + t.Fatal("expected innermost check failure") + } + }) + + t.Run("GivenSudokuMultiTableUDPMask_WhenClientSendsMultipleDatagrams_ThenPayloadMatches", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-udp-multi", + Ascii: "prefer_entropy", + CustomTables: []string{"xpxvvpvv", "vxpvxvvp"}, + PaddingMin: 0, + PaddingMax: 0, + } + maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{cfg}) + + clientRaw, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer clientRaw.Close() + + serverRaw, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer serverRaw.Close() + + client, err := maskManager.WrapPacketConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + server, err := maskManager.WrapPacketConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + _ = server.SetDeadline(time.Now().Add(2 * time.Second)) + + mustSendRecv(t, client, server, []byte("first-datagram")) + mustSendRecv(t, client, server, []byte("second-datagram")) + mustSendRecv(t, client, server, []byte("third-datagram")) + }) + + t.Run("GivenSudokuTCPMask_WhenCloseWriteIsCalled_ThenEOFPropagates", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-closewrite", + Ascii: "prefer_ascii", + PaddingMin: 0, + PaddingMax: 0, + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + acceptCh := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- err + return + } + acceptCh <- conn + }() + + clientRaw, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer clientRaw.Close() + + var serverRaw net.Conn + select { + case serverRaw = <-acceptCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("accept timeout") + } + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + serverConn, err := cfg.WrapConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + closeWriter, ok := clientConn.(interface{ CloseWrite() error }) + if !ok { + t.Fatalf("wrapped conn does not expose CloseWrite: %T", clientConn) + } + + writeErr := make(chan error, 1) + go func() { + if _, err := clientConn.Write([]byte("closewrite")); err != nil { + writeErr <- err + return + } + writeErr <- closeWriter.CloseWrite() + }() + + buf := make([]byte, len("closewrite")) + if _, err := io.ReadFull(serverConn, buf); err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, []byte("closewrite")) { + t.Fatal("unexpected payload before closewrite") + } + if err := <-writeErr; err != nil { + t.Fatal(err) + } + + one := make([]byte, 1) + n, err := serverConn.Read(one) + if n != 0 || err != io.EOF { + t.Fatalf("expected EOF after CloseWrite, got n=%d err=%v", n, err) + } + }) + + t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_ThenMaskConnIsRetained", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-unwrap", + Ascii: "prefer_entropy", + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + + unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn) + if readCounter != nil || writeCounter != nil { + t.Fatal("unexpected stat counters while unwrapping sudoku conn") + } + if unwrapped != clientConn { + t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped) + } + }) + + t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_AfterDownlinkOptimization_ThenMaskConnIsRetained", func(t *testing.T) { + cfg := &sudoku.Config{ + Password: "sudoku-packed-unwrap", + Ascii: "prefer_entropy", + } + + clientRaw, serverRaw := net.Pipe() + defer clientRaw.Close() + defer serverRaw.Close() + + clientConn, err := cfg.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + + unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn) + if readCounter != nil || writeCounter != nil { + t.Fatal("unexpected stat counters while unwrapping sudoku conn") + } + if unwrapped != clientConn { + t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped) + } + }) +}