mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
Finalmask: Add Sudoku (TCP & UDP) (#5685)
https://github.com/SUDOKU-ASCII/sudoku/issues/23#issuecomment-3859972396
This commit is contained in:
163
transport/internet/finalmask/sudoku/codec.go
Normal file
163
transport/internet/finalmask/sudoku/codec.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
var perm4 = [24][4]byte{
|
||||
{0, 1, 2, 3},
|
||||
{0, 1, 3, 2},
|
||||
{0, 2, 1, 3},
|
||||
{0, 2, 3, 1},
|
||||
{0, 3, 1, 2},
|
||||
{0, 3, 2, 1},
|
||||
{1, 0, 2, 3},
|
||||
{1, 0, 3, 2},
|
||||
{1, 2, 0, 3},
|
||||
{1, 2, 3, 0},
|
||||
{1, 3, 0, 2},
|
||||
{1, 3, 2, 0},
|
||||
{2, 0, 1, 3},
|
||||
{2, 0, 3, 1},
|
||||
{2, 1, 0, 3},
|
||||
{2, 1, 3, 0},
|
||||
{2, 3, 0, 1},
|
||||
{2, 3, 1, 0},
|
||||
{3, 0, 1, 2},
|
||||
{3, 0, 2, 1},
|
||||
{3, 1, 0, 2},
|
||||
{3, 1, 2, 0},
|
||||
{3, 2, 0, 1},
|
||||
{3, 2, 1, 0},
|
||||
}
|
||||
|
||||
type codec struct {
|
||||
tables []*table
|
||||
rng *rand.Rand
|
||||
paddingChance int
|
||||
tableIndex int
|
||||
}
|
||||
|
||||
func newCodec(tables []*table, pMin, pMax int) *codec {
|
||||
if len(tables) == 0 {
|
||||
tables = nil
|
||||
}
|
||||
rng := newSeededRand()
|
||||
return &codec{
|
||||
tables: tables,
|
||||
rng: rng,
|
||||
paddingChance: pickPaddingChance(rng, pMin, pMax),
|
||||
}
|
||||
}
|
||||
|
||||
func pickPaddingChance(rng *rand.Rand, pMin, pMax int) int {
|
||||
if pMin < 0 {
|
||||
pMin = 0
|
||||
}
|
||||
if pMax < pMin {
|
||||
pMax = pMin
|
||||
}
|
||||
if pMin > 100 {
|
||||
pMin = 100
|
||||
}
|
||||
if pMax > 100 {
|
||||
pMax = 100
|
||||
}
|
||||
if pMax == pMin {
|
||||
return pMin
|
||||
}
|
||||
return pMin + rng.Intn(pMax-pMin+1)
|
||||
}
|
||||
|
||||
func (c *codec) shouldPad() bool {
|
||||
if c.paddingChance <= 0 {
|
||||
return false
|
||||
}
|
||||
if c.paddingChance >= 100 {
|
||||
return true
|
||||
}
|
||||
return c.rng.Intn(100) < c.paddingChance
|
||||
}
|
||||
|
||||
func (c *codec) currentTable() *table {
|
||||
if len(c.tables) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.tables[c.tableIndex%len(c.tables)]
|
||||
}
|
||||
|
||||
func (c *codec) randomPadding(t *table) byte {
|
||||
pool := t.layout.paddingPool
|
||||
return pool[c.rng.Intn(len(pool))]
|
||||
}
|
||||
|
||||
func (c *codec) encode(in []byte) ([]byte, error) {
|
||||
if len(in) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make([]byte, 0, len(in)*6+8)
|
||||
for _, b := range in {
|
||||
t := c.currentTable()
|
||||
if t == nil {
|
||||
return nil, fmt.Errorf("sudoku table set missing")
|
||||
}
|
||||
if c.shouldPad() {
|
||||
out = append(out, c.randomPadding(t))
|
||||
}
|
||||
|
||||
enc := t.encode[b]
|
||||
if len(enc) == 0 {
|
||||
return nil, fmt.Errorf("sudoku encode table missing for byte %d", b)
|
||||
}
|
||||
|
||||
hints := enc[c.rng.Intn(len(enc))]
|
||||
perm := perm4[c.rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
if c.shouldPad() {
|
||||
out = append(out, c.randomPadding(t))
|
||||
}
|
||||
out = append(out, hints[idx])
|
||||
}
|
||||
c.tableIndex++
|
||||
}
|
||||
|
||||
if c.shouldPad() {
|
||||
if t := c.currentTable(); t != nil {
|
||||
out = append(out, c.randomPadding(t))
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func decodeBytes(tables []*table, tableIndex *int, in []byte, hintBuf []byte, out []byte) ([]byte, []byte, error) {
|
||||
if len(tables) == 0 {
|
||||
return hintBuf, out, fmt.Errorf("sudoku table set missing")
|
||||
}
|
||||
for _, b := range in {
|
||||
t := tables[*tableIndex%len(tables)]
|
||||
if !t.layout.isHint(b) {
|
||||
continue
|
||||
}
|
||||
|
||||
hintBuf = append(hintBuf, b)
|
||||
if len(hintBuf) < 4 {
|
||||
continue
|
||||
}
|
||||
|
||||
keyBytes := sort4([4]byte{hintBuf[0], hintBuf[1], hintBuf[2], hintBuf[3]})
|
||||
key := packKey(keyBytes)
|
||||
decoded, ok := t.decode[key]
|
||||
if !ok {
|
||||
return hintBuf[:0], out, fmt.Errorf("invalid sudoku hint tuple")
|
||||
}
|
||||
|
||||
out = append(out, decoded)
|
||||
hintBuf = hintBuf[:0]
|
||||
*tableIndex++
|
||||
}
|
||||
|
||||
return hintBuf, out, nil
|
||||
}
|
||||
57
transport/internet/finalmask/sudoku/config.go
Normal file
57
transport/internet/finalmask/sudoku/config.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
||||
func (c *Config) TCP() {
|
||||
}
|
||||
|
||||
func (c *Config) UDP() {
|
||||
}
|
||||
|
||||
// Sudoku in finalmask mode is a pure appearance transform with no standalone handshake.
|
||||
// TCP always keeps classic sudoku on uplink and uses packed downlink optimization on server writes.
|
||||
func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) {
|
||||
return newPackedDirectionalConn(raw, c, true)
|
||||
}
|
||||
|
||||
func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) {
|
||||
return newPackedDirectionalConn(raw, c, false)
|
||||
}
|
||||
|
||||
func newPackedDirectionalConn(raw net.Conn, config *Config, readPacked bool) (net.Conn, error) {
|
||||
pureReader, pureWriter, err := newPureReaderWriter(raw, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packedReader, packedWriter, err := newPackedReaderWriter(raw, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader, writer := pureReader, pureWriter
|
||||
if readPacked {
|
||||
reader = packedReader
|
||||
} else {
|
||||
writer = packedWriter
|
||||
}
|
||||
|
||||
return newWrappedConn(raw, reader, writer), nil
|
||||
}
|
||||
|
||||
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
|
||||
if level != levelCount {
|
||||
return nil, errors.New("sudoku udp mask must be the innermost mask in chain")
|
||||
}
|
||||
return NewUDPConn(raw, c)
|
||||
}
|
||||
|
||||
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
|
||||
if level != levelCount {
|
||||
return nil, errors.New("sudoku udp mask must be the innermost mask in chain")
|
||||
}
|
||||
return NewUDPConn(raw, c)
|
||||
}
|
||||
170
transport/internet/finalmask/sudoku/config.pb.go
Normal file
170
transport/internet/finalmask/sudoku/config.pb.go
Normal file
@@ -0,0 +1,170 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.5
|
||||
// source: transport/internet/finalmask/sudoku/config.proto
|
||||
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"`
|
||||
Ascii string `protobuf:"bytes,2,opt,name=ascii,proto3" json:"ascii,omitempty"`
|
||||
CustomTable string `protobuf:"bytes,3,opt,name=custom_table,json=customTable,proto3" json:"custom_table,omitempty"`
|
||||
PaddingMin uint32 `protobuf:"varint,4,opt,name=padding_min,json=paddingMin,proto3" json:"padding_min,omitempty"`
|
||||
PaddingMax uint32 `protobuf:"varint,5,opt,name=padding_max,json=paddingMax,proto3" json:"padding_max,omitempty"`
|
||||
CustomTables []string `protobuf:"bytes,7,rep,name=custom_tables,json=customTables,proto3" json:"custom_tables,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Config) Reset() {
|
||||
*x = Config{}
|
||||
mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Config) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Config) ProtoMessage() {}
|
||||
|
||||
func (x *Config) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_transport_internet_finalmask_sudoku_config_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
|
||||
func (*Config) Descriptor() ([]byte, []int) {
|
||||
return file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Config) GetPassword() string {
|
||||
if x != nil {
|
||||
return x.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetAscii() string {
|
||||
if x != nil {
|
||||
return x.Ascii
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetCustomTable() string {
|
||||
if x != nil {
|
||||
return x.CustomTable
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetPaddingMin() uint32 {
|
||||
if x != nil {
|
||||
return x.PaddingMin
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *Config) GetPaddingMax() uint32 {
|
||||
if x != nil {
|
||||
return x.PaddingMax
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *Config) GetCustomTables() []string {
|
||||
if x != nil {
|
||||
return x.CustomTables
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_transport_internet_finalmask_sudoku_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_transport_internet_finalmask_sudoku_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"0transport/internet/finalmask/sudoku/config.proto\x12(xray.transport.internet.finalmask.sudoku\"\xc4\x01\n" +
|
||||
"\x06Config\x12\x1a\n" +
|
||||
"\bpassword\x18\x01 \x01(\tR\bpassword\x12\x14\n" +
|
||||
"\x05ascii\x18\x02 \x01(\tR\x05ascii\x12!\n" +
|
||||
"\fcustom_table\x18\x03 \x01(\tR\vcustomTable\x12\x1f\n" +
|
||||
"\vpadding_min\x18\x04 \x01(\rR\n" +
|
||||
"paddingMin\x12\x1f\n" +
|
||||
"\vpadding_max\x18\x05 \x01(\rR\n" +
|
||||
"paddingMax\x12#\n" +
|
||||
"\rcustom_tables\x18\a \x03(\tR\fcustomTablesB\x9a\x01\n" +
|
||||
",com.xray.transport.internet.finalmask.sudokuP\x01Z=github.com/xtls/xray-core/transport/internet/finalmask/sudoku\xaa\x02(Xray.Transport.Internet.Finalmask.Sudokub\x06proto3"
|
||||
|
||||
var (
|
||||
file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce sync.Once
|
||||
file_transport_internet_finalmask_sudoku_config_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_transport_internet_finalmask_sudoku_config_proto_rawDescGZIP() []byte {
|
||||
file_transport_internet_finalmask_sudoku_config_proto_rawDescOnce.Do(func() {
|
||||
file_transport_internet_finalmask_sudoku_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc)))
|
||||
})
|
||||
return file_transport_internet_finalmask_sudoku_config_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_transport_internet_finalmask_sudoku_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_transport_internet_finalmask_sudoku_config_proto_goTypes = []any{
|
||||
(*Config)(nil), // 0: xray.transport.internet.finalmask.sudoku.Config
|
||||
}
|
||||
var file_transport_internet_finalmask_sudoku_config_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_transport_internet_finalmask_sudoku_config_proto_init() }
|
||||
func file_transport_internet_finalmask_sudoku_config_proto_init() {
|
||||
if File_transport_internet_finalmask_sudoku_config_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_sudoku_config_proto_rawDesc), len(file_transport_internet_finalmask_sudoku_config_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_transport_internet_finalmask_sudoku_config_proto_goTypes,
|
||||
DependencyIndexes: file_transport_internet_finalmask_sudoku_config_proto_depIdxs,
|
||||
MessageInfos: file_transport_internet_finalmask_sudoku_config_proto_msgTypes,
|
||||
}.Build()
|
||||
File_transport_internet_finalmask_sudoku_config_proto = out.File
|
||||
file_transport_internet_finalmask_sudoku_config_proto_goTypes = nil
|
||||
file_transport_internet_finalmask_sudoku_config_proto_depIdxs = nil
|
||||
}
|
||||
16
transport/internet/finalmask/sudoku/config.proto
Normal file
16
transport/internet/finalmask/sudoku/config.proto
Normal file
@@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package xray.transport.internet.finalmask.sudoku;
|
||||
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Sudoku";
|
||||
option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/sudoku";
|
||||
option java_package = "com.xray.transport.internet.finalmask.sudoku";
|
||||
option java_multiple_files = true;
|
||||
|
||||
message Config {
|
||||
string password = 1;
|
||||
string ascii = 2;
|
||||
string custom_table = 3;
|
||||
uint32 padding_min = 4;
|
||||
uint32 padding_max = 5;
|
||||
repeated string custom_tables = 7;
|
||||
}
|
||||
212
transport/internet/finalmask/sudoku/conn_tcp.go
Normal file
212
transport/internet/finalmask/sudoku/conn_tcp.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask"
|
||||
)
|
||||
|
||||
const ioBufferSize = 32 * 1024
|
||||
|
||||
var _ finalmask.TcpMaskConn = (*wrappedConn)(nil)
|
||||
|
||||
type streamDecoder interface {
|
||||
decodeChunk(in []byte, pending []byte) ([]byte, error)
|
||||
reset()
|
||||
}
|
||||
|
||||
type streamReader struct {
|
||||
reader *bufio.Reader
|
||||
rawBuf []byte
|
||||
pending []byte
|
||||
decode streamDecoder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStreamReader(raw net.Conn, decode streamDecoder) io.Reader {
|
||||
return &streamReader{
|
||||
reader: bufio.NewReaderSize(raw, ioBufferSize),
|
||||
rawBuf: make([]byte, ioBufferSize),
|
||||
pending: make([]byte, 0, 4096),
|
||||
decode: decode,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *streamReader) Read(p []byte) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if n, ok := drainPending(p, &r.pending); ok {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
for len(r.pending) == 0 {
|
||||
nr, rErr := r.reader.Read(r.rawBuf)
|
||||
if nr > 0 {
|
||||
var dErr error
|
||||
r.pending, dErr = r.decode.decodeChunk(r.rawBuf[:nr], r.pending)
|
||||
if dErr != nil {
|
||||
return 0, dErr
|
||||
}
|
||||
}
|
||||
|
||||
if rErr != nil {
|
||||
if rErr == io.EOF {
|
||||
r.decode.reset()
|
||||
if len(r.pending) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, rErr
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := drainPending(p, &r.pending)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type streamWriter struct {
|
||||
conn net.Conn
|
||||
encode func([]byte) ([]byte, error)
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStreamWriter(raw net.Conn, encode func([]byte) ([]byte, error)) io.Writer {
|
||||
return &streamWriter{
|
||||
conn: raw,
|
||||
encode: encode,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *streamWriter) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
encoded, err := w.encode(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := writeAll(w.conn, encoded); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
type closeWriteConn interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
func newWrappedConn(raw net.Conn, reader io.Reader, writer io.Writer) net.Conn {
|
||||
return &wrappedConn{
|
||||
Conn: raw,
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Write(p []byte) (int, error) {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
func (c *wrappedConn) TcpMaskConn() {}
|
||||
|
||||
func (c *wrappedConn) RawConn() net.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Splice() bool {
|
||||
// Sudoku transforms the entire stream; bypassing it would disable masking.
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *wrappedConn) CloseWrite() error {
|
||||
if raw, ok := c.Conn.(closeWriteConn); ok {
|
||||
return raw.CloseWrite()
|
||||
}
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
func NewTCPConn(raw net.Conn, config *Config) (net.Conn, error) {
|
||||
reader, writer, err := newPureReaderWriter(raw, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newWrappedConn(raw, reader, writer), nil
|
||||
}
|
||||
|
||||
func newPureReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) {
|
||||
tables, err := getTables(config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pMin, pMax := normalizedPadding(config)
|
||||
c := newCodec(tables, pMin, pMax)
|
||||
return newStreamReader(raw, newHintStreamDecoder(tables)), newStreamWriter(raw, c.encode), nil
|
||||
}
|
||||
|
||||
type hintStreamDecoder struct {
|
||||
tables []*table
|
||||
tableIndex int
|
||||
hintBuf []byte
|
||||
}
|
||||
|
||||
func newHintStreamDecoder(tables []*table) *hintStreamDecoder {
|
||||
return &hintStreamDecoder{
|
||||
tables: tables,
|
||||
hintBuf: make([]byte, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *hintStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) {
|
||||
var err error
|
||||
d.hintBuf, pending, err = decodeBytes(d.tables, &d.tableIndex, in, d.hintBuf, pending)
|
||||
return pending, err
|
||||
}
|
||||
|
||||
func (d *hintStreamDecoder) reset() {}
|
||||
|
||||
func drainPending(p []byte, pending *[]byte) (int, bool) {
|
||||
if len(*pending) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
n := copy(p, *pending)
|
||||
if n >= len(*pending) {
|
||||
*pending = (*pending)[:0]
|
||||
return n, true
|
||||
}
|
||||
|
||||
remaining := len(*pending) - n
|
||||
copy(*pending, (*pending)[n:])
|
||||
*pending = (*pending)[:remaining]
|
||||
return n, true
|
||||
}
|
||||
|
||||
func writeAll(conn net.Conn, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
n, err := conn.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
182
transport/internet/finalmask/sudoku/conn_tcp_packed.go
Normal file
182
transport/internet/finalmask/sudoku/conn_tcp_packed.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type packedEncoder struct {
|
||||
layouts []*byteLayout
|
||||
codec *codec
|
||||
groupIndex int
|
||||
}
|
||||
|
||||
func newPackedEncoder(tables []*table, pMin, pMax int) *packedEncoder {
|
||||
layouts := make([]*byteLayout, 0, len(tables))
|
||||
for _, t := range tables {
|
||||
layouts = append(layouts, t.layout)
|
||||
}
|
||||
if len(layouts) == 0 {
|
||||
layouts = append(layouts, entropyLayout())
|
||||
}
|
||||
return &packedEncoder{
|
||||
layouts: layouts,
|
||||
codec: newCodec(nil, pMin, pMax),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *packedEncoder) encode(p []byte) ([]byte, error) {
|
||||
out := make([]byte, 0, len(p)*2+8)
|
||||
var bitBuf uint64
|
||||
var bitCount uint8
|
||||
|
||||
for _, b := range p {
|
||||
bitBuf = (bitBuf << 8) | uint64(b)
|
||||
bitCount += 8
|
||||
|
||||
for bitCount >= 6 {
|
||||
bitCount -= 6
|
||||
layout := e.layouts[e.groupIndex%len(e.layouts)]
|
||||
group := byte(bitBuf >> bitCount)
|
||||
out = e.maybePad(out, layout)
|
||||
out = append(out, layout.encodeGroup(group&0x3f))
|
||||
e.groupIndex++
|
||||
if bitCount > 0 {
|
||||
bitBuf &= (uint64(1) << bitCount) - 1
|
||||
} else {
|
||||
bitBuf = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bitCount > 0 {
|
||||
layout := e.layouts[e.groupIndex%len(e.layouts)]
|
||||
group := byte(bitBuf << (6 - bitCount))
|
||||
out = e.maybePad(out, layout)
|
||||
out = append(out, layout.encodeGroup(group&0x3f))
|
||||
e.groupIndex++
|
||||
nextLayout := e.layouts[e.groupIndex%len(e.layouts)]
|
||||
out = append(out, nextLayout.padMarker)
|
||||
}
|
||||
|
||||
out = e.maybePad(out, e.layouts[e.groupIndex%len(e.layouts)])
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (e *packedEncoder) maybePad(out []byte, layout *byteLayout) []byte {
|
||||
if !e.codec.shouldPad() {
|
||||
return out
|
||||
}
|
||||
if len(layout.paddingPool) == 1 {
|
||||
return append(out, layout.paddingPool[0])
|
||||
}
|
||||
for {
|
||||
b := layout.paddingPool[e.codec.rng.Intn(len(layout.paddingPool))]
|
||||
if b != layout.padMarker {
|
||||
return append(out, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type packedStreamDecoder struct {
|
||||
layouts []*byteLayout
|
||||
groupIndex int
|
||||
bitBuf uint64
|
||||
bitCount int
|
||||
}
|
||||
|
||||
func (d *packedStreamDecoder) decodeChunk(in []byte, pending []byte) ([]byte, error) {
|
||||
var err error
|
||||
d.bitBuf, d.bitCount, d.groupIndex, pending, err = decodePackedBytes(
|
||||
d.layouts,
|
||||
in,
|
||||
d.bitBuf,
|
||||
d.bitCount,
|
||||
d.groupIndex,
|
||||
pending,
|
||||
)
|
||||
return pending, err
|
||||
}
|
||||
|
||||
func (d *packedStreamDecoder) reset() {
|
||||
d.bitBuf = 0
|
||||
d.bitCount = 0
|
||||
}
|
||||
|
||||
func NewPackedTCPConn(raw net.Conn, config *Config) (net.Conn, error) {
|
||||
reader, writer, err := newPackedReaderWriter(raw, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newWrappedConn(raw, reader, writer), nil
|
||||
}
|
||||
|
||||
func newPackedReaderWriter(raw net.Conn, config *Config) (io.Reader, io.Writer, error) {
|
||||
tables, err := getTables(config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pMin, pMax := normalizedPadding(config)
|
||||
encoder := newPackedEncoder(tables, pMin, pMax)
|
||||
decoder := &packedStreamDecoder{
|
||||
layouts: tablesToLayouts(tables),
|
||||
}
|
||||
return newStreamReader(raw, decoder), newStreamWriter(raw, encoder.encode), nil
|
||||
}
|
||||
|
||||
func tablesToLayouts(tables []*table) []*byteLayout {
|
||||
layouts := make([]*byteLayout, 0, len(tables))
|
||||
for _, t := range tables {
|
||||
layouts = append(layouts, t.layout)
|
||||
}
|
||||
if len(layouts) == 0 {
|
||||
layouts = append(layouts, entropyLayout())
|
||||
}
|
||||
return layouts
|
||||
}
|
||||
|
||||
func decodePackedBytes(
|
||||
layouts []*byteLayout,
|
||||
in []byte,
|
||||
bitBuf uint64,
|
||||
bitCount int,
|
||||
groupIndex int,
|
||||
out []byte,
|
||||
) (uint64, int, int, []byte, error) {
|
||||
if len(layouts) == 0 {
|
||||
return bitBuf, bitCount, groupIndex, out, fmt.Errorf("sudoku layout set missing")
|
||||
}
|
||||
for _, b := range in {
|
||||
layout := layouts[groupIndex%len(layouts)]
|
||||
if !layout.isHint(b) {
|
||||
if b == layout.padMarker {
|
||||
bitBuf = 0
|
||||
bitCount = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
group, ok := layout.decodeGroup(b)
|
||||
if !ok {
|
||||
return bitBuf, bitCount, groupIndex, out, fmt.Errorf("invalid packed sudoku byte: %d", b)
|
||||
}
|
||||
groupIndex++
|
||||
|
||||
bitBuf = (bitBuf << 6) | uint64(group)
|
||||
bitCount += 6
|
||||
|
||||
for bitCount >= 8 {
|
||||
bitCount -= 8
|
||||
out = append(out, byte(bitBuf>>bitCount))
|
||||
if bitCount > 0 {
|
||||
bitBuf &= (uint64(1) << bitCount) - 1
|
||||
} else {
|
||||
bitBuf = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bitBuf, bitCount, groupIndex, out, nil
|
||||
}
|
||||
106
transport/internet/finalmask/sudoku/conn_udp.go
Normal file
106
transport/internet/finalmask/sudoku/conn_udp.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type udpConn struct {
|
||||
conn net.PacketConn
|
||||
tables []*table
|
||||
pMin int
|
||||
pMax int
|
||||
|
||||
readBuf []byte
|
||||
|
||||
readMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewUDPConn(raw net.PacketConn, config *Config) (net.PacketConn, error) {
|
||||
tables, err := getTables(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pMin, pMax := normalizedPadding(config)
|
||||
return &udpConn{
|
||||
conn: raw,
|
||||
tables: tables,
|
||||
pMin: pMin,
|
||||
pMax: pMax,
|
||||
readBuf: make([]byte, 65535),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Size() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
|
||||
n, addr, err = c.conn.ReadFrom(c.readBuf)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
decoded := make([]byte, 0, n/4+1)
|
||||
hints := make([]byte, 0, 4)
|
||||
tableIndex := 0
|
||||
hints, decoded, err = decodeBytes(c.tables, &tableIndex, c.readBuf[:n], hints, decoded)
|
||||
if err != nil {
|
||||
return 0, addr, err
|
||||
}
|
||||
if len(hints) != 0 {
|
||||
return 0, addr, io.ErrUnexpectedEOF
|
||||
}
|
||||
if len(p) < len(decoded) {
|
||||
return 0, addr, io.ErrShortBuffer
|
||||
}
|
||||
copy(p, decoded)
|
||||
return len(decoded), addr, nil
|
||||
}
|
||||
|
||||
func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// UDP decoding restarts at table 0 for every datagram, so encoding must do the same.
|
||||
encoded, err := newCodec(c.tables, c.pMin, c.pMax).encode(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
nn, err := c.conn.WriteTo(encoded, addr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if nn != len(encoded) {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *udpConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
1396
transport/internet/finalmask/sudoku/sudoku_test.go
Normal file
1396
transport/internet/finalmask/sudoku/sudoku_test.go
Normal file
File diff suppressed because it is too large
Load Diff
580
transport/internet/finalmask/sudoku/table.go
Normal file
580
transport/internet/finalmask/sudoku/table.go
Normal file
@@ -0,0 +1,580 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
crypto_rand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type table struct {
|
||||
encode [256][][4]byte
|
||||
decode map[uint32]byte
|
||||
layout *byteLayout
|
||||
}
|
||||
|
||||
type tableCacheKey struct {
|
||||
password string
|
||||
ascii string
|
||||
customTable string
|
||||
}
|
||||
|
||||
var (
|
||||
tableCache sync.Map
|
||||
tableSetCache sync.Map
|
||||
|
||||
basePatternsOnce sync.Once
|
||||
basePatterns [][][4]byte
|
||||
basePatternsErr error
|
||||
)
|
||||
|
||||
type byteLayout struct {
|
||||
hintMask byte
|
||||
hintValue byte
|
||||
padMarker byte
|
||||
paddingPool []byte
|
||||
encodeHint func(group byte) byte
|
||||
encodeGroup func(group byte) byte
|
||||
decodeGroup func(b byte) (byte, bool)
|
||||
}
|
||||
|
||||
func (l *byteLayout) isHint(b byte) bool {
|
||||
if (b & l.hintMask) == l.hintValue {
|
||||
return true
|
||||
}
|
||||
// ASCII layout maps 0x7f to '\n' to avoid DEL on the wire.
|
||||
return l.hintMask == 0x40 && b == '\n'
|
||||
}
|
||||
|
||||
func getTable(config *Config) (*table, error) {
|
||||
tables, err := getTables(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(tables) == 0 {
|
||||
return nil, fmt.Errorf("empty sudoku table set")
|
||||
}
|
||||
return tables[0], nil
|
||||
}
|
||||
|
||||
func getTables(config *Config) ([]*table, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("nil sudoku config")
|
||||
}
|
||||
|
||||
mode, err := normalizeASCII(config.GetAscii())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
patterns, err := normalizedCustomPatterns(config, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheKey := tableCacheKey{
|
||||
password: config.GetPassword(),
|
||||
ascii: mode,
|
||||
customTable: strings.Join(patterns, "\x00"),
|
||||
}
|
||||
if cached, ok := tableSetCache.Load(cacheKey); ok {
|
||||
return cached.([]*table), nil
|
||||
}
|
||||
|
||||
tables := make([]*table, 0, len(patterns))
|
||||
for _, pattern := range patterns {
|
||||
layout, err := resolveLayout(mode, pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t, err := buildTable(config.GetPassword(), layout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
actual, _ := tableSetCache.LoadOrStore(cacheKey, tables)
|
||||
return actual.([]*table), nil
|
||||
}
|
||||
|
||||
func normalizedCustomPatterns(config *Config, mode string) ([]string, error) {
|
||||
if config == nil {
|
||||
return []string{""}, nil
|
||||
}
|
||||
if mode == "prefer_ascii" {
|
||||
return []string{""}, nil
|
||||
}
|
||||
|
||||
rawPatterns := config.GetCustomTables()
|
||||
if len(rawPatterns) == 0 {
|
||||
rawPatterns = []string{config.GetCustomTable()}
|
||||
}
|
||||
|
||||
patterns := make([]string, 0, len(rawPatterns))
|
||||
seen := make(map[string]struct{}, len(rawPatterns))
|
||||
for _, raw := range rawPatterns {
|
||||
pattern := strings.TrimSpace(raw)
|
||||
if pattern != "" {
|
||||
var err error
|
||||
pattern, err = normalizeCustomTable(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if _, ok := seen[pattern]; ok {
|
||||
continue
|
||||
}
|
||||
seen[pattern] = struct{}{}
|
||||
patterns = append(patterns, pattern)
|
||||
}
|
||||
|
||||
if len(patterns) == 0 {
|
||||
return []string{""}, nil
|
||||
}
|
||||
|
||||
return patterns, nil
|
||||
}
|
||||
|
||||
func normalizedPadding(config *Config) (int, int) {
|
||||
if config == nil {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
pMin := int(config.GetPaddingMin())
|
||||
pMax := int(config.GetPaddingMax())
|
||||
|
||||
if pMin > 100 {
|
||||
pMin = 100
|
||||
}
|
||||
if pMax > 100 {
|
||||
pMax = 100
|
||||
}
|
||||
if pMax < pMin {
|
||||
pMax = pMin
|
||||
}
|
||||
return pMin, pMax
|
||||
}
|
||||
|
||||
func normalizeASCII(mode string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "", "entropy", "prefer_entropy":
|
||||
return "prefer_entropy", nil
|
||||
case "ascii", "prefer_ascii":
|
||||
return "prefer_ascii", nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid sudoku ascii mode: %s", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeCustomTable(pattern string) (string, error) {
|
||||
cleaned := strings.ToLower(strings.TrimSpace(pattern))
|
||||
cleaned = strings.ReplaceAll(cleaned, " ", "")
|
||||
if len(cleaned) != 8 {
|
||||
return "", fmt.Errorf("customTable must be 8 chars, got %d", len(cleaned))
|
||||
}
|
||||
|
||||
var xCount, pCount, vCount int
|
||||
for _, ch := range cleaned {
|
||||
switch ch {
|
||||
case 'x':
|
||||
xCount++
|
||||
case 'p':
|
||||
pCount++
|
||||
case 'v':
|
||||
vCount++
|
||||
default:
|
||||
return "", fmt.Errorf("customTable has invalid char %q", ch)
|
||||
}
|
||||
}
|
||||
if xCount != 2 || pCount != 2 || vCount != 4 {
|
||||
return "", fmt.Errorf("customTable must contain exactly 2 x, 2 p and 4 v")
|
||||
}
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func resolveLayout(mode, customTable string) (*byteLayout, error) {
|
||||
if mode == "prefer_ascii" {
|
||||
return asciiLayout(), nil
|
||||
}
|
||||
|
||||
if customTable != "" {
|
||||
return customLayout(customTable)
|
||||
}
|
||||
return entropyLayout(), nil
|
||||
}
|
||||
|
||||
func asciiLayout() *byteLayout {
|
||||
padding := make([]byte, 0, 32)
|
||||
for i := 0; i < 32; i++ {
|
||||
padding = append(padding, byte(0x20+i))
|
||||
}
|
||||
|
||||
encodeGroup := func(group byte) byte {
|
||||
b := byte(0x40 | (group & 0x3f))
|
||||
if b == 0x7f {
|
||||
return '\n'
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
return &byteLayout{
|
||||
hintMask: 0x40,
|
||||
hintValue: 0x40,
|
||||
padMarker: 0x3f,
|
||||
paddingPool: padding,
|
||||
encodeHint: encodeGroup,
|
||||
encodeGroup: encodeGroup,
|
||||
decodeGroup: func(b byte) (byte, bool) {
|
||||
if b == '\n' {
|
||||
return 0x3f, true
|
||||
}
|
||||
if (b & 0x40) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return b & 0x3f, true
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func entropyLayout() *byteLayout {
|
||||
padding := make([]byte, 0, 16)
|
||||
for i := 0; i < 8; i++ {
|
||||
padding = append(padding, byte(0x80+i), byte(0x10+i))
|
||||
}
|
||||
|
||||
encodeGroup := func(group byte) byte {
|
||||
v := group & 0x3f
|
||||
return ((v & 0x30) << 1) | (v & 0x0f)
|
||||
}
|
||||
|
||||
return &byteLayout{
|
||||
hintMask: 0x90,
|
||||
hintValue: 0x00,
|
||||
padMarker: 0x80,
|
||||
paddingPool: padding,
|
||||
encodeHint: encodeGroup,
|
||||
encodeGroup: encodeGroup,
|
||||
decodeGroup: func(b byte) (byte, bool) {
|
||||
if (b & 0x90) != 0 {
|
||||
return 0, false
|
||||
}
|
||||
return ((b >> 1) & 0x30) | (b & 0x0f), true
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func customLayout(pattern string) (*byteLayout, error) {
|
||||
pattern, err := normalizeCustomTable(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var xBits, pBits, vBits []uint8
|
||||
for i, c := range pattern {
|
||||
bit := uint8(7 - i)
|
||||
switch c {
|
||||
case 'x':
|
||||
xBits = append(xBits, bit)
|
||||
case 'p':
|
||||
pBits = append(pBits, bit)
|
||||
case 'v':
|
||||
vBits = append(vBits, bit)
|
||||
}
|
||||
}
|
||||
|
||||
xMask := byte(0)
|
||||
for _, bit := range xBits {
|
||||
xMask |= 1 << bit
|
||||
}
|
||||
|
||||
encodeGroupWithDropX := func(group byte, dropX int) byte {
|
||||
out := xMask
|
||||
if dropX >= 0 {
|
||||
out &^= 1 << xBits[dropX]
|
||||
}
|
||||
|
||||
val := (group >> 4) & 0x03
|
||||
pos := group & 0x0f
|
||||
|
||||
if (val & 0x02) != 0 {
|
||||
out |= 1 << pBits[0]
|
||||
}
|
||||
if (val & 0x01) != 0 {
|
||||
out |= 1 << pBits[1]
|
||||
}
|
||||
for i, bit := range vBits {
|
||||
if (pos>>(3-uint8(i)))&0x01 == 1 {
|
||||
out |= 1 << bit
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
paddingSet := make(map[byte]struct{}, 64)
|
||||
padding := make([]byte, 0, 64)
|
||||
for drop := range xBits {
|
||||
for val := byte(0); val < 4; val++ {
|
||||
for pos := byte(0); pos < 16; pos++ {
|
||||
group := (val << 4) | pos
|
||||
b := encodeGroupWithDropX(group, drop)
|
||||
if bits.OnesCount8(b) >= 5 {
|
||||
if _, exists := paddingSet[b]; !exists {
|
||||
paddingSet[b] = struct{}{}
|
||||
padding = append(padding, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(padding, func(i, j int) bool { return padding[i] < padding[j] })
|
||||
if len(padding) == 0 {
|
||||
return nil, fmt.Errorf("customTable produced empty padding pool")
|
||||
}
|
||||
|
||||
decodeGroup := func(b byte) (byte, bool) {
|
||||
if (b & xMask) != xMask {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var val, pos byte
|
||||
if b&(1<<pBits[0]) != 0 {
|
||||
val |= 0x02
|
||||
}
|
||||
if b&(1<<pBits[1]) != 0 {
|
||||
val |= 0x01
|
||||
}
|
||||
for i, bit := range vBits {
|
||||
if b&(1<<bit) != 0 {
|
||||
pos |= 1 << (3 - uint8(i))
|
||||
}
|
||||
}
|
||||
|
||||
return ((val & 0x03) << 4) | (pos & 0x0f), true
|
||||
}
|
||||
encodeGroup := func(group byte) byte {
|
||||
return encodeGroupWithDropX(group, -1)
|
||||
}
|
||||
|
||||
return &byteLayout{
|
||||
hintMask: xMask,
|
||||
hintValue: xMask,
|
||||
padMarker: padding[0],
|
||||
paddingPool: padding,
|
||||
encodeHint: encodeGroup,
|
||||
encodeGroup: encodeGroup,
|
||||
decodeGroup: decodeGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildTable(password string, layout *byteLayout) (*table, error) {
|
||||
patterns, err := getBasePatterns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(patterns) < 256 {
|
||||
return nil, fmt.Errorf("not enough sudoku grids: %d", len(patterns))
|
||||
}
|
||||
|
||||
order := make([]int, len(patterns))
|
||||
for i := range order {
|
||||
order[i] = i
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(password))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
rng.Shuffle(len(order), func(i, j int) {
|
||||
order[i], order[j] = order[j], order[i]
|
||||
})
|
||||
|
||||
t := &table{
|
||||
decode: make(map[uint32]byte, 1<<16),
|
||||
layout: layout,
|
||||
}
|
||||
for b := 0; b < 256; b++ {
|
||||
patList := patterns[order[b]]
|
||||
if len(patList) == 0 {
|
||||
return nil, fmt.Errorf("grid %d has no valid clue set", order[b])
|
||||
}
|
||||
|
||||
enc := make([][4]byte, 0, len(patList))
|
||||
for _, groups := range patList {
|
||||
hints := [4]byte{
|
||||
layout.encodeHint(groups[0]),
|
||||
layout.encodeHint(groups[1]),
|
||||
layout.encodeHint(groups[2]),
|
||||
layout.encodeHint(groups[3]),
|
||||
}
|
||||
sortedHints := sort4(hints)
|
||||
key := packKey(sortedHints)
|
||||
if old, exists := t.decode[key]; exists && old != byte(b) {
|
||||
return nil, fmt.Errorf("decode key collision for byte %d and %d", old, b)
|
||||
}
|
||||
t.decode[key] = byte(b)
|
||||
enc = append(enc, hints)
|
||||
}
|
||||
|
||||
t.encode[b] = enc
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func getBasePatterns() ([][][4]byte, error) {
|
||||
basePatternsOnce.Do(func() {
|
||||
basePatterns, basePatternsErr = buildBasePatterns()
|
||||
})
|
||||
return basePatterns, basePatternsErr
|
||||
}
|
||||
|
||||
type grid [16]byte
|
||||
|
||||
func buildBasePatterns() ([][][4]byte, error) {
|
||||
grids := generateAllGrids()
|
||||
positions := hintPositions()
|
||||
|
||||
patterns := make([][][4]byte, len(grids))
|
||||
for _, ps := range positions {
|
||||
counts := make(map[uint32]uint16, len(grids))
|
||||
keys := make([]uint32, len(grids))
|
||||
groupsByGrid := make([][4]byte, len(grids))
|
||||
|
||||
for gi, g := range grids {
|
||||
groups := [4]byte{
|
||||
clueGroup(g, ps[0]),
|
||||
clueGroup(g, ps[1]),
|
||||
clueGroup(g, ps[2]),
|
||||
clueGroup(g, ps[3]),
|
||||
}
|
||||
groups = sort4(groups)
|
||||
key := packKey(groups)
|
||||
keys[gi] = key
|
||||
groupsByGrid[gi] = groups
|
||||
counts[key]++
|
||||
}
|
||||
|
||||
for gi, key := range keys {
|
||||
if counts[key] == 1 {
|
||||
patterns[gi] = append(patterns[gi], groupsByGrid[gi])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for gi, list := range patterns {
|
||||
if len(list) == 0 {
|
||||
return nil, fmt.Errorf("grid %d has no uniquely decodable clue set", gi)
|
||||
}
|
||||
}
|
||||
|
||||
return patterns, nil
|
||||
}
|
||||
|
||||
func clueGroup(g grid, pos byte) byte {
|
||||
// 2 bits of value + 4 bits of position.
|
||||
return ((g[pos] - 1) << 4) | (pos & 0x0f)
|
||||
}
|
||||
|
||||
func generateAllGrids() []grid {
|
||||
grids := make([]grid, 0, 288)
|
||||
var g grid
|
||||
|
||||
var dfs func(idx int)
|
||||
dfs = func(idx int) {
|
||||
if idx == 16 {
|
||||
grids = append(grids, g)
|
||||
return
|
||||
}
|
||||
|
||||
row := idx / 4
|
||||
col := idx % 4
|
||||
boxRow := (row / 2) * 2
|
||||
boxCol := (col / 2) * 2
|
||||
|
||||
for num := byte(1); num <= 4; num++ {
|
||||
valid := true
|
||||
for i := 0; i < 4; i++ {
|
||||
if g[row*4+i] == num || g[i*4+col] == num {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
|
||||
for r := 0; r < 2 && valid; r++ {
|
||||
for c := 0; c < 2; c++ {
|
||||
if g[(boxRow+r)*4+(boxCol+c)] == num {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
|
||||
g[idx] = num
|
||||
dfs(idx + 1)
|
||||
g[idx] = 0
|
||||
}
|
||||
}
|
||||
|
||||
dfs(0)
|
||||
return grids
|
||||
}
|
||||
|
||||
func hintPositions() [][4]byte {
|
||||
// C(16, 4) = 1820.
|
||||
positions := make([][4]byte, 0, 1820)
|
||||
for a := 0; a < 13; a++ {
|
||||
for b := a + 1; b < 14; b++ {
|
||||
for c := b + 1; c < 15; c++ {
|
||||
for d := c + 1; d < 16; d++ {
|
||||
positions = append(positions, [4]byte{byte(a), byte(b), byte(c), byte(d)})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return positions
|
||||
}
|
||||
|
||||
func packKey(in [4]byte) uint32 {
|
||||
return uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
|
||||
}
|
||||
|
||||
func sort4(in [4]byte) [4]byte {
|
||||
if in[0] > in[1] {
|
||||
in[0], in[1] = in[1], in[0]
|
||||
}
|
||||
if in[2] > in[3] {
|
||||
in[2], in[3] = in[3], in[2]
|
||||
}
|
||||
if in[0] > in[2] {
|
||||
in[0], in[2] = in[2], in[0]
|
||||
}
|
||||
if in[1] > in[3] {
|
||||
in[1], in[3] = in[3], in[1]
|
||||
}
|
||||
if in[1] > in[2] {
|
||||
in[1], in[2] = in[2], in[1]
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
func newSeededRand() *rand.Rand {
|
||||
seed := time.Now().UnixNano()
|
||||
var seedBytes [8]byte
|
||||
if _, err := crypto_rand.Read(seedBytes[:]); err == nil {
|
||||
seed = int64(binary.BigEndian.Uint64(seedBytes[:]))
|
||||
}
|
||||
return rand.New(rand.NewSource(seed))
|
||||
}
|
||||
@@ -2,10 +2,13 @@ package finalmask_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/proxy"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/header/custom"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/header/dns"
|
||||
@@ -16,6 +19,7 @@ import (
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/mkcp/aes128gcm"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/salamander"
|
||||
"github.com/xtls/xray-core/transport/internet/finalmask/sudoku"
|
||||
)
|
||||
|
||||
func mustSendRecv(
|
||||
@@ -49,39 +53,93 @@ func mustSendRecv(
|
||||
}
|
||||
|
||||
type layerMask struct {
|
||||
name string
|
||||
mask finalmask.Udpmask
|
||||
name string
|
||||
mask finalmask.Udpmask
|
||||
layers int
|
||||
}
|
||||
|
||||
type countingConn struct {
|
||||
net.Conn
|
||||
written atomic.Int64
|
||||
}
|
||||
|
||||
func (c *countingConn) Write(p []byte) (int, error) {
|
||||
n, err := c.Conn.Write(p)
|
||||
c.written.Add(int64(n))
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *countingConn) Written() int64 {
|
||||
return c.written.Load()
|
||||
}
|
||||
|
||||
func TestPacketConnReadWrite(t *testing.T) {
|
||||
cases := []layerMask{
|
||||
{
|
||||
name: "aes128gcm",
|
||||
mask: &aes128gcm.Config{Password: "123"},
|
||||
name: "aes128gcm",
|
||||
mask: &aes128gcm.Config{Password: "123"},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "original",
|
||||
mask: &original.Config{},
|
||||
name: "original",
|
||||
mask: &original.Config{},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "dns",
|
||||
mask: &dns.Config{Domain: "www.baidu.com"},
|
||||
name: "dns",
|
||||
mask: &dns.Config{Domain: "www.baidu.com"},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "srtp",
|
||||
mask: &srtp.Config{},
|
||||
name: "srtp",
|
||||
mask: &srtp.Config{},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "utp",
|
||||
mask: &utp.Config{},
|
||||
name: "utp",
|
||||
mask: &utp.Config{},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "wechat",
|
||||
mask: &wechat.Config{},
|
||||
name: "wechat",
|
||||
mask: &wechat.Config{},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "wireguard",
|
||||
mask: &wireguard.Config{},
|
||||
name: "wireguard",
|
||||
mask: &wireguard.Config{},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "salamander",
|
||||
mask: &salamander.Config{Password: "1234"},
|
||||
layers: 2,
|
||||
},
|
||||
{
|
||||
name: "sudoku-prefer-ascii",
|
||||
mask: &sudoku.Config{
|
||||
Password: "sudoku-mask",
|
||||
Ascii: "prefer_ascii",
|
||||
},
|
||||
layers: 1,
|
||||
},
|
||||
{
|
||||
name: "sudoku-custom-table",
|
||||
mask: &sudoku.Config{
|
||||
Password: "sudoku-mask",
|
||||
Ascii: "prefer_entropy",
|
||||
CustomTable: "xpxvvpvv",
|
||||
},
|
||||
layers: 1,
|
||||
},
|
||||
{
|
||||
name: "sudoku-custom-tables",
|
||||
mask: &sudoku.Config{
|
||||
Password: "sudoku-mask",
|
||||
Ascii: "prefer_entropy",
|
||||
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
|
||||
},
|
||||
layers: 1,
|
||||
},
|
||||
{
|
||||
name: "custom",
|
||||
@@ -103,18 +161,27 @@ func TestPacketConnReadWrite(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
layers: 1,
|
||||
},
|
||||
{
|
||||
name: "salamander",
|
||||
mask: &salamander.Config{Password: "1234"},
|
||||
name: "salamander-single",
|
||||
mask: &salamander.Config{Password: "1234"},
|
||||
layers: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
mask := c.mask
|
||||
|
||||
maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{mask, mask})
|
||||
layers := c.layers
|
||||
if layers <= 0 {
|
||||
layers = 1
|
||||
}
|
||||
masks := make([]finalmask.Udpmask, 0, layers)
|
||||
for i := 0; i < layers; i++ {
|
||||
masks = append(masks, mask)
|
||||
}
|
||||
maskManager := finalmask.NewUdpmaskManager(masks)
|
||||
|
||||
client, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
@@ -147,3 +214,419 @@ func TestPacketConnReadWrite(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSudokuBDD(t *testing.T) {
|
||||
t.Run("GivenSudokuTCPMask_WhenRoundTripWithAsciiPreference_ThenPayloadMatches", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-tcp",
|
||||
Ascii: "prefer_ascii",
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := cfg.WrapConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
send := bytes.Repeat([]byte("client->server"), 1024)
|
||||
recv := make([]byte, len(send))
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, wErr := clientConn.Write(send)
|
||||
writeErr <- wErr
|
||||
}()
|
||||
|
||||
if _, err := io.ReadFull(serverConn, recv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(send, recv) {
|
||||
t.Fatal("tcp sudoku payload mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuTCPMask_WhenRoundTrip_ThenBothDirectionsMatch", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-packed",
|
||||
Ascii: "prefer_ascii",
|
||||
PaddingMin: 0,
|
||||
PaddingMax: 0,
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := cfg.WrapConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientToServer := bytes.Repeat([]byte("client-packed->server"), 257)
|
||||
serverToClient := bytes.Repeat([]byte("server-packed->client"), 263)
|
||||
|
||||
c2sRecv := make([]byte, len(clientToServer))
|
||||
c2sErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := clientConn.Write(clientToServer)
|
||||
c2sErr <- err
|
||||
}()
|
||||
if _, err := io.ReadFull(serverConn, c2sRecv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-c2sErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(clientToServer, c2sRecv) {
|
||||
t.Fatal("tcp client->server payload mismatch")
|
||||
}
|
||||
|
||||
s2cRecv := make([]byte, len(serverToClient))
|
||||
s2cErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := serverConn.Write(serverToClient)
|
||||
s2cErr <- err
|
||||
}()
|
||||
if _, err := io.ReadFull(clientConn, s2cRecv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-s2cErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(serverToClient, s2cRecv) {
|
||||
t.Fatal("tcp server->client payload mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuTCPMask_WhenServerWritesDownlink_ThenWireBytesAreReduced", func(t *testing.T) {
|
||||
payload := bytes.Repeat([]byte("0123456789abcdef"), 192) // 3072 bytes, divisible by 3.
|
||||
|
||||
countWireBytes := func(wrapServer func(net.Conn, *sudoku.Config) (net.Conn, error), cfg *sudoku.Config) int64 {
|
||||
t.Helper()
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
watchedServerRaw := &countingConn{Conn: serverRaw}
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := wrapServer(watchedServerRaw, cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
readErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.CopyN(io.Discard, clientConn, int64(len(payload)))
|
||||
readErr <- err
|
||||
}()
|
||||
|
||||
if _, err := serverConn.Write(payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-readErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return watchedServerRaw.Written()
|
||||
}
|
||||
|
||||
pureUplinkPackedDownlink := &sudoku.Config{
|
||||
Password: "sudoku-bandwidth",
|
||||
Ascii: "prefer_entropy",
|
||||
PaddingMin: 0,
|
||||
PaddingMax: 0,
|
||||
}
|
||||
packedDownlinkBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) {
|
||||
return cfg.WrapConnServer(raw)
|
||||
}, pureUplinkPackedDownlink)
|
||||
legacyPureBytes := countWireBytes(func(raw net.Conn, cfg *sudoku.Config) (net.Conn, error) {
|
||||
return sudoku.NewTCPConn(raw, cfg)
|
||||
}, pureUplinkPackedDownlink)
|
||||
|
||||
if packedDownlinkBytes >= legacyPureBytes {
|
||||
t.Fatalf("expected default packed downlink bytes < legacy pure bytes, got packed=%d pure=%d", packedDownlinkBytes, legacyPureBytes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuMultiTableTCPMask_WhenRoundTrip_ThenPayloadMatches", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-multi-tcp",
|
||||
Ascii: "prefer_entropy",
|
||||
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := cfg.WrapConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
send := bytes.Repeat([]byte("rotate-table"), 513)
|
||||
recv := make([]byte, len(send))
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, wErr := clientConn.Write(send)
|
||||
writeErr <- wErr
|
||||
}()
|
||||
|
||||
if _, err := io.ReadFull(serverConn, recv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(send, recv) {
|
||||
t.Fatal("multi-table tcp sudoku payload mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuMultiTableTCPMask_WhenPackedDownlink_ThenPayloadMatches", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-multi-packed",
|
||||
Ascii: "prefer_entropy",
|
||||
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
|
||||
PaddingMin: 0,
|
||||
PaddingMax: 0,
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := cfg.WrapConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
send := bytes.Repeat([]byte("packed-rotate"), 257)
|
||||
recv := make([]byte, len(send))
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, wErr := clientConn.Write(send)
|
||||
writeErr <- wErr
|
||||
}()
|
||||
|
||||
if _, err := io.ReadFull(serverConn, recv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(send, recv) {
|
||||
t.Fatal("multi-table tcp sudoku payload mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuUDPMask_WhenNotInnermost_ThenWrapFails", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{Password: "sudoku-udp"}
|
||||
raw, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer raw.Close()
|
||||
|
||||
if _, err := cfg.WrapPacketConnClient(raw, 0, 1); err == nil {
|
||||
t.Fatal("expected innermost check failure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuMultiTableUDPMask_WhenClientSendsMultipleDatagrams_ThenPayloadMatches", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-udp-multi",
|
||||
Ascii: "prefer_entropy",
|
||||
CustomTables: []string{"xpxvvpvv", "vxpvxvvp"},
|
||||
PaddingMin: 0,
|
||||
PaddingMax: 0,
|
||||
}
|
||||
maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{cfg})
|
||||
|
||||
clientRaw, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer clientRaw.Close()
|
||||
|
||||
serverRaw, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer serverRaw.Close()
|
||||
|
||||
client, err := maskManager.WrapPacketConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server, err := maskManager.WrapPacketConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
_ = server.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
|
||||
mustSendRecv(t, client, server, []byte("first-datagram"))
|
||||
mustSendRecv(t, client, server, []byte("second-datagram"))
|
||||
mustSendRecv(t, client, server, []byte("third-datagram"))
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuTCPMask_WhenCloseWriteIsCalled_ThenEOFPropagates", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-closewrite",
|
||||
Ascii: "prefer_ascii",
|
||||
PaddingMin: 0,
|
||||
PaddingMax: 0,
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
acceptCh <- conn
|
||||
}()
|
||||
|
||||
clientRaw, err := net.Dial("tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer clientRaw.Close()
|
||||
|
||||
var serverRaw net.Conn
|
||||
select {
|
||||
case serverRaw = <-acceptCh:
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("accept timeout")
|
||||
}
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverConn, err := cfg.WrapConnServer(serverRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
closeWriter, ok := clientConn.(interface{ CloseWrite() error })
|
||||
if !ok {
|
||||
t.Fatalf("wrapped conn does not expose CloseWrite: %T", clientConn)
|
||||
}
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
go func() {
|
||||
if _, err := clientConn.Write([]byte("closewrite")); err != nil {
|
||||
writeErr <- err
|
||||
return
|
||||
}
|
||||
writeErr <- closeWriter.CloseWrite()
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("closewrite"))
|
||||
if _, err := io.ReadFull(serverConn, buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf, []byte("closewrite")) {
|
||||
t.Fatal("unexpected payload before closewrite")
|
||||
}
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
one := make([]byte, 1)
|
||||
n, err := serverConn.Read(one)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("expected EOF after CloseWrite, got n=%d err=%v", n, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_ThenMaskConnIsRetained", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-unwrap",
|
||||
Ascii: "prefer_entropy",
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn)
|
||||
if readCounter != nil || writeCounter != nil {
|
||||
t.Fatal("unexpected stat counters while unwrapping sudoku conn")
|
||||
}
|
||||
if unwrapped != clientConn {
|
||||
t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GivenSudokuTCPMask_WhenProxyUnwrapRawConn_AfterDownlinkOptimization_ThenMaskConnIsRetained", func(t *testing.T) {
|
||||
cfg := &sudoku.Config{
|
||||
Password: "sudoku-packed-unwrap",
|
||||
Ascii: "prefer_entropy",
|
||||
}
|
||||
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer clientRaw.Close()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn, err := cfg.WrapConnClient(clientRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unwrapped, readCounter, writeCounter := proxy.UnwrapRawConn(clientConn)
|
||||
if readCounter != nil || writeCounter != nil {
|
||||
t.Fatal("unexpected stat counters while unwrapping sudoku conn")
|
||||
}
|
||||
if unwrapped != clientConn {
|
||||
t.Fatalf("expected sudoku conn to stay wrapped, got %T", unwrapped)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user