From 05e259c8e4cc25da48059694cc7cf5b2f556fee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=98=D0=B2=D0=B0=D0=BD?= <82300276+fatyzzz@users.noreply.github.com> Date: Wed, 15 Apr 2026 23:21:23 +0700 Subject: [PATCH] header-custom finalmask: Add UDP standalone handshake mode (#5945) https://github.com/XTLS/Xray-core/commit/175502d8079aa5a151242ed911d01a1b90b98b28 --- common/serial/typed_message_test.go | 29 ++ infra/conf/transport_internet.go | 8 + infra/conf/transport_test.go | 4 +- transport/internet/finalmask/finalmask.go | 34 +- .../finalmask/header/custom/config.go | 10 + .../finalmask/header/custom/config.pb.go | 13 +- .../finalmask/header/custom/config.proto | 1 + .../finalmask/header/custom/evaluator.go | 10 + .../finalmask/header/custom/metadata_test.go | 94 +++++ .../internet/finalmask/header/custom/udp.go | 233 +++++++++++ transport/internet/finalmask/tcp_test.go | 68 ++++ transport/internet/finalmask/udp_test.go | 383 ++++++++++++++++++ transport/internet/udp/dialer.go | 5 +- 13 files changed, 886 insertions(+), 6 deletions(-) diff --git a/common/serial/typed_message_test.go b/common/serial/typed_message_test.go index 726a7733..75d529fd 100644 --- a/common/serial/typed_message_test.go +++ b/common/serial/typed_message_test.go @@ -4,6 +4,7 @@ import ( "testing" . "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" ) func TestGetInstance(t *testing.T) { @@ -22,3 +23,31 @@ func TestConvertingNilMessage(t *testing.T) { t.Error("expect nil, but actually not") } } + +func TestTypedMessageRoundTripPreservesFinalmaskCustomUDPMode(t *testing.T) { + msg := &custom.UDPConfig{ + Mode: "standalone", + Client: []*custom.UDPItem{ + {Rand: 12, Save: "txid"}, + }, + } + + tm := ToTypedMessage(msg) + if tm == nil { + t.Fatal("expected typed message") + } + + roundTrip, err := tm.GetInstance() + if err != nil { + t.Fatalf("GetInstance() failed: %v", err) + } + + udp, ok := roundTrip.(*custom.UDPConfig) + if !ok { + t.Fatalf("unexpected round-trip type: %T", roundTrip) + } + + if udp.GetMode() != "standalone" { + t.Fatalf("mode lost during typed message round-trip: got %q", udp.GetMode()) + } +} diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 23df27a8..15d5cfbd 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -1657,11 +1657,18 @@ func buildCustomTransformArg(arg CustomTransformArg) (*custom.ExprArg, error) { } type HeaderCustomUDP struct { + Mode string `json:"mode"` Client []UDPItem `json:"client"` Server []UDPItem `json:"server"` } func (c *HeaderCustomUDP) Build() (proto.Message, error) { + switch c.Mode { + case "", "prefix", "standalone": + default: + return nil, errors.New("unknown udp mode") + } + for _, item := range c.Client { if err := validateCustomItemSpec(item.Capture, item.Packet, item.Rand, item.Reuse, item.Transform); err != nil { return nil, err @@ -1730,6 +1737,7 @@ func (c *HeaderCustomUDP) Build() (proto.Message, error) { return &custom.UDPConfig{ Client: client, Server: server, + Mode: c.Mode, }, nil } diff --git a/infra/conf/transport_test.go b/infra/conf/transport_test.go index 1912126a..15c9aba3 100644 --- a/infra/conf/transport_test.go +++ b/infra/conf/transport_test.go @@ -191,10 +191,12 @@ func TestHeaderCustomUDPBuild(t *testing.T) { { "reuse": "txid" } - ] + ], + "mode": "standalone" }`, Parser: parser, Output: &finalmaskcustom.UDPConfig{ + Mode: "standalone", Client: []*finalmaskcustom.UDPItem{ { RandMax: 255, diff --git a/transport/internet/finalmask/finalmask.go b/transport/internet/finalmask/finalmask.go index f9c92f66..e7bf0080 100644 --- a/transport/internet/finalmask/finalmask.go +++ b/transport/internet/finalmask/finalmask.go @@ -31,6 +31,19 @@ func (m *UdpmaskManager) WrapPacketConnClient(raw net.PacketConn) (net.PacketCon var conns []net.PacketConn for i, mask := range m.udpmasks { if _, ok := mask.(headerConn); ok { + if mode, ok := mask.(headerConnMode); ok && !mode.UseHeaderConn() { + if len(conns) > 0 { + raw = &headerManagerConn{sizes: sizes, conns: conns, PacketConn: raw} + sizes = nil + conns = nil + } + var err error + raw, err = mask.WrapPacketConnClient(raw, i, len(m.udpmasks)-1) + if err != nil { + return nil, err + } + continue + } conn, err := mask.WrapPacketConnClient(nil, i, len(m.udpmasks)-1) if err != nil { return nil, err @@ -64,6 +77,19 @@ func (m *UdpmaskManager) WrapPacketConnServer(raw net.PacketConn) (net.PacketCon var conns []net.PacketConn for i, mask := range m.udpmasks { if _, ok := mask.(headerConn); ok { + if mode, ok := mask.(headerConnMode); ok && !mode.UseHeaderConn() { + if len(conns) > 0 { + raw = &headerManagerConn{sizes: sizes, conns: conns, PacketConn: raw} + sizes = nil + conns = nil + } + var err error + raw, err = mask.WrapPacketConnServer(raw, i, len(m.udpmasks)-1) + if err != nil { + return nil, err + } + continue + } conn, err := mask.WrapPacketConnServer(nil, i, len(m.udpmasks)-1) if err != nil { return nil, err @@ -100,6 +126,10 @@ type headerConn interface { HeaderConn() } +type headerConnMode interface { + UseHeaderConn() bool +} + type headerSize interface { Size() int } @@ -262,8 +292,8 @@ func (l *tcpListener) Accept() (net.Conn, error) { newConn, err := l.m.WrapConnServer(conn) if err != nil { errors.LogDebugInner(context.Background(), err, "mask err") - // conn.Close() - return conn, nil + _ = conn.Close() + return nil, err } return newConn, nil diff --git a/transport/internet/finalmask/header/custom/config.go b/transport/internet/finalmask/header/custom/config.go index be094f21..0427051b 100644 --- a/transport/internet/finalmask/header/custom/config.go +++ b/transport/internet/finalmask/header/custom/config.go @@ -19,12 +19,22 @@ func (c *UDPConfig) UDP() { } func (c *UDPConfig) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if c.Mode == "standalone" { + return NewConnClientUDPStandalone(c, raw) + } return NewConnClientUDP(c, raw) } func (c *UDPConfig) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if c.Mode == "standalone" { + return NewConnServerUDPStandalone(c, raw) + } return NewConnServerUDP(c, raw) } func (c *UDPConfig) HeaderConn() { } + +func (c *UDPConfig) UseHeaderConn() bool { + return c.Mode != "standalone" +} diff --git a/transport/internet/finalmask/header/custom/config.pb.go b/transport/internet/finalmask/header/custom/config.pb.go index 2340a544..43423860 100644 --- a/transport/internet/finalmask/header/custom/config.pb.go +++ b/transport/internet/finalmask/header/custom/config.pb.go @@ -511,6 +511,7 @@ type UDPConfig struct { state protoimpl.MessageState `protogen:"open.v1"` Client []*UDPItem `protobuf:"bytes,1,rep,name=client,proto3" json:"client,omitempty"` Server []*UDPItem `protobuf:"bytes,2,rep,name=server,proto3" json:"server,omitempty"` + Mode string `protobuf:"bytes,3,opt,name=mode,proto3" json:"mode,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -559,6 +560,13 @@ func (x *UDPConfig) GetServer() []*UDPItem { return nil } +func (x *UDPConfig) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + var File_transport_internet_finalmask_header_custom_config_proto protoreflect.FileDescriptor const file_transport_internet_finalmask_header_custom_config_proto_rawDesc = "" + @@ -597,10 +605,11 @@ const file_transport_internet_finalmask_header_custom_config_proto_rawDesc = "" "\x06packet\x18\x04 \x01(\fR\x06packet\x12\x12\n" + "\x04save\x18\x05 \x01(\tR\x04save\x12\x10\n" + "\x03var\x18\x06 \x01(\tR\x03var\x12I\n" + - "\x04expr\x18\a \x01(\v25.xray.transport.internet.finalmask.header.custom.ExprR\x04expr\"\xaf\x01\n" + + "\x04expr\x18\a \x01(\v25.xray.transport.internet.finalmask.header.custom.ExprR\x04expr\"\xc3\x01\n" + "\tUDPConfig\x12P\n" + "\x06client\x18\x01 \x03(\v28.xray.transport.internet.finalmask.header.custom.UDPItemR\x06client\x12P\n" + - "\x06server\x18\x02 \x03(\v28.xray.transport.internet.finalmask.header.custom.UDPItemR\x06serverB\xaf\x01\n" + + "\x06server\x18\x02 \x03(\v28.xray.transport.internet.finalmask.header.custom.UDPItemR\x06server\x12\x12\n" + + "\x04mode\x18\x03 \x01(\tR\x04modeB\xaf\x01\n" + "3com.xray.transport.internet.finalmask.header.customP\x01ZDgithub.com/xtls/xray-core/transport/internet/finalmask/header/custom\xaa\x02/Xray.Transport.Internet.Finalmask.Header.Customb\x06proto3" var ( diff --git a/transport/internet/finalmask/header/custom/config.proto b/transport/internet/finalmask/header/custom/config.proto index 3602e522..d350e76f 100644 --- a/transport/internet/finalmask/header/custom/config.proto +++ b/transport/internet/finalmask/header/custom/config.proto @@ -56,4 +56,5 @@ message UDPItem { message UDPConfig { repeated UDPItem client = 1; repeated UDPItem server = 2; + string mode = 3; } diff --git a/transport/internet/finalmask/header/custom/evaluator.go b/transport/internet/finalmask/header/custom/evaluator.go index 0fba850a..46655700 100644 --- a/transport/internet/finalmask/header/custom/evaluator.go +++ b/transport/internet/finalmask/header/custom/evaluator.go @@ -398,9 +398,19 @@ func loadMetadata(dst map[string]evalValue, prefix string, addr net.Addr) { func loadIPPortMetadata(dst map[string]evalValue, prefix string, ip net.IP, port int) { portValue := uint64(port) dst[prefix+"_port"] = evalValue{u64: &portValue} + if prefix == "remote" { + dst["src_port_u16"] = evalValue{u64: &portValue} + } else if prefix == "local" { + dst["dst_port_u16"] = evalValue{u64: &portValue} + } if ip4 := ip.To4(); ip4 != nil { ipValue := uint64(binary.BigEndian.Uint32(ip4)) dst[prefix+"_ip4_u32"] = evalValue{u64: &ipValue} + if prefix == "remote" { + dst["src_ip4_u32"] = evalValue{u64: &ipValue} + } else if prefix == "local" { + dst["dst_ip4_u32"] = evalValue{u64: &ipValue} + } } } diff --git a/transport/internet/finalmask/header/custom/metadata_test.go b/transport/internet/finalmask/header/custom/metadata_test.go index ee300bab..78633e61 100644 --- a/transport/internet/finalmask/header/custom/metadata_test.go +++ b/transport/internet/finalmask/header/custom/metadata_test.go @@ -30,6 +30,100 @@ func TestMetadataEvaluatorRejectsUnknownName(t *testing.T) { } } +func TestMetadataAliasesExposeSrcAndDstNames(t *testing.T) { + ctx := newEvalContextWithAddrs( + &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 3478}, + &net.UDPAddr{IP: net.IPv4(203, 0, 113, 9), Port: 54321}, + ) + + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "concat", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + {Value: &ExprArg_Metadata{Metadata: "src_port_u16"}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be32", + Args: []*ExprArg{ + {Value: &ExprArg_Metadata{Metadata: "src_ip4_u32"}}, + }, + }, + }, + }, + }, + }, + }, + } + + got, err := evaluateUDPItemsWithContext(items, ctx) + if err != nil { + t.Fatal(err) + } + + want := []byte{0xD4, 0x31, 203, 0, 113, 9} + if !bytes.Equal(got, want) { + t.Fatalf("unexpected alias output: got=%x want=%x", got, want) + } +} + +func TestMetadataAliasesExposeDstNames(t *testing.T) { + ctx := newEvalContextWithAddrs( + &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 3478}, + &net.UDPAddr{IP: net.IPv4(203, 0, 113, 9), Port: 54321}, + ) + + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "concat", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + {Value: &ExprArg_Metadata{Metadata: "dst_port_u16"}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be32", + Args: []*ExprArg{ + {Value: &ExprArg_Metadata{Metadata: "dst_ip4_u32"}}, + }, + }, + }, + }, + }, + }, + }, + } + + got, err := evaluateUDPItemsWithContext(items, ctx) + if err != nil { + t.Fatal(err) + } + + want := []byte{0x0D, 0x96, 10, 0, 0, 1} + if !bytes.Equal(got, want) { + t.Fatalf("unexpected alias output: got=%x want=%x", got, want) + } +} + func TestMetadataUDPWriteUsesRemotePort(t *testing.T) { cfg := &UDPConfig{ Client: []*UDPItem{ diff --git a/transport/internet/finalmask/header/custom/udp.go b/transport/internet/finalmask/header/custom/udp.go index b3d8d122..b16c7c9a 100644 --- a/transport/internet/finalmask/header/custom/udp.go +++ b/transport/internet/finalmask/header/custom/udp.go @@ -3,11 +3,14 @@ package custom import ( "bytes" "net" + "sync" "time" "github.com/xtls/xray-core/common/errors" ) +const udpStandaloneBufferSize = 4096 + type udpCustomClient struct { client []*UDPItem server []*UDPItem @@ -267,3 +270,233 @@ func udpStateKey(addr net.Addr) string { } return addr.String() } + +type udpCustomStandaloneClientConn struct { + net.PacketConn + client []*UDPItem + server []*UDPItem + state *stateStore + read int + mu sync.Mutex + once sync.Once + queue chan udpStandalonePacket + wait map[string]*udpStandaloneWaiter +} + +type udpStandalonePacket struct { + data []byte + addr net.Addr + err error +} + +type udpStandaloneWaiter struct { + vars map[string][]byte + done chan error +} + +func NewConnClientUDPStandalone(c *UDPConfig, raw net.PacketConn) (net.PacketConn, error) { + clientSavedSizes := collectSavedUDPSizes(c.Client) + read, err := measureUDPItemsWithFallback(c.Server, clientSavedSizes) + if err != nil { + return nil, err + } + + return &udpCustomStandaloneClientConn{ + PacketConn: raw, + client: c.Client, + server: c.Server, + state: newStateStore(5 * time.Second), + read: read, + queue: make(chan udpStandalonePacket, 16), + wait: make(map[string]*udpStandaloneWaiter), + }, nil +} + +func (c *udpCustomStandaloneClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + c.ensureReader() + packet, ok := <-c.queue + if !ok { + return 0, nil, net.ErrClosed + } + if packet.err != nil { + return 0, packet.addr, packet.err + } + if len(packet.data) > len(p) { + copy(p, packet.data[:len(p)]) + return len(p), packet.addr, nil + } + copy(p, packet.data) + return len(packet.data), packet.addr, nil +} + +func (c *udpCustomStandaloneClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.ensureReader() + key := udpStateKey(addr) + if _, ok := c.state.get(key); !ok { + var localAddr net.Addr + if c.PacketConn != nil { + localAddr = c.PacketConn.LocalAddr() + } + + ctx := newEvalContextWithAddrs(localAddr, addr) + request, err := evaluateUDPItemsWithContext(c.client, ctx) + if err != nil { + return 0, err + } + waiter := c.registerWaiter(key, ctx.vars) + if _, err := c.PacketConn.WriteTo(request, addr); err != nil { + c.unregisterWaiter(key, waiter) + return 0, err + } + if err := <-waiter.done; err != nil { + return 0, err + } + } + + return c.PacketConn.WriteTo(p, addr) +} + +func (c *udpCustomStandaloneClientConn) ensureReader() { + c.once.Do(func() { + go c.readerLoop(c.queue) + }) +} + +func (c *udpCustomStandaloneClientConn) registerWaiter(key string, vars map[string][]byte) *udpStandaloneWaiter { + waiter := &udpStandaloneWaiter{ + vars: cloneVars(vars), + done: make(chan error, 1), + } + c.mu.Lock() + c.wait[key] = waiter + c.mu.Unlock() + return waiter +} + +func (c *udpCustomStandaloneClientConn) unregisterWaiter(key string, waiter *udpStandaloneWaiter) { + c.mu.Lock() + if c.wait[key] == waiter { + delete(c.wait, key) + } + c.mu.Unlock() +} + +func (c *udpCustomStandaloneClientConn) readerLoop(queue chan udpStandalonePacket) { + buf := make([]byte, udpStandaloneBufferSize) + for { + n, addr, err := c.PacketConn.ReadFrom(buf) + if err != nil { + c.failWaiters(err) + queue <- udpStandalonePacket{addr: addr, err: err} + close(queue) + return + } + data := append([]byte(nil), buf[:n]...) + if c.tryCompleteHandshake(addr, data) { + continue + } + queue <- udpStandalonePacket{data: data, addr: addr} + } +} + +func (c *udpCustomStandaloneClientConn) tryCompleteHandshake(addr net.Addr, data []byte) bool { + key := udpStateKey(addr) + c.mu.Lock() + waiter, ok := c.wait[key] + c.mu.Unlock() + if !ok || len(data) != c.read { + return false + } + + vars, matched := matchUDPItems(c.server, data, c.read, waiter.vars) + if !matched { + return false + } + + c.state.set(key, vars) + c.mu.Lock() + if c.wait[key] == waiter { + delete(c.wait, key) + } + c.mu.Unlock() + waiter.done <- nil + return true +} + +func (c *udpCustomStandaloneClientConn) failWaiters(err error) { + c.mu.Lock() + waiters := c.wait + c.wait = make(map[string]*udpStandaloneWaiter) + c.mu.Unlock() + for _, waiter := range waiters { + waiter.done <- err + } +} + +type udpCustomStandaloneServerConn struct { + net.PacketConn + client []*UDPItem + server []*UDPItem + state *stateStore + read int +} + +func NewConnServerUDPStandalone(c *UDPConfig, raw net.PacketConn) (net.PacketConn, error) { + read, err := measureUDPItems(c.Client) + if err != nil { + return nil, err + } + + return &udpCustomStandaloneServerConn{ + PacketConn: raw, + client: c.Client, + server: c.Server, + state: newStateStore(5 * time.Second), + read: read, + }, nil +} + +func (c *udpCustomStandaloneServerConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buf := p + copyBack := false + if len(buf) < udpStandaloneBufferSize { + buf = make([]byte, udpStandaloneBufferSize) + copyBack = true + } + + for { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil { + return 0, addr, err + } + if n == c.read { + vars, ok := matchUDPItems(c.client, buf[:n], c.read, nil) + if ok { + var localAddr net.Addr + if c.PacketConn != nil { + localAddr = c.PacketConn.LocalAddr() + } + ctx := newEvalContextWithAddrs(localAddr, addr) + ctx.vars = cloneVars(vars) + response, err := evaluateUDPItemsWithContext(c.server, ctx) + if err != nil { + return 0, addr, err + } + if _, err := c.PacketConn.WriteTo(response, addr); err != nil { + return 0, addr, err + } + c.state.set(udpStateKey(addr), ctx.vars) + continue + } + } + + if copyBack { + copy(p, buf[:n]) + } + return n, addr, nil + } +} + +func (c *udpCustomStandaloneServerConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.PacketConn.WriteTo(p, addr) +} diff --git a/transport/internet/finalmask/tcp_test.go b/transport/internet/finalmask/tcp_test.go index c6125aa9..4c07d983 100644 --- a/transport/internet/finalmask/tcp_test.go +++ b/transport/internet/finalmask/tcp_test.go @@ -47,6 +47,14 @@ type layerMaskTcp struct { mask finalmask.Tcpmask } +type failingWrapMask struct{} + +func (failingWrapMask) TCP() {} +func (f failingWrapMask) WrapConnClient(raw net.Conn) (net.Conn, error) { return raw, nil } +func (f failingWrapMask) WrapConnServer(raw net.Conn) (net.Conn, error) { + return nil, io.ErrClosedPipe +} + func TestConnReadWrite(t *testing.T) { cases := []layerMaskTcp{ { @@ -247,3 +255,63 @@ func TestTCPcustomClientRejectsMismatchedServerSequence(t *testing.T) { t.Fatalf("expected server timeout after client auth failure, got %v", readErr) } } + +func TestTCPWrapListenerRejectsImmediateWrapErrors(t *testing.T) { + clientManager := finalmask.NewTcpmaskManager([]finalmask.Tcpmask{failingWrapMask{}}) + serverManager := finalmask.NewTcpmaskManager([]finalmask.Tcpmask{failingWrapMask{}}) + + rawLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer rawLn.Close() + + ln, err := serverManager.WrapListener(rawLn) + if err != nil { + t.Fatal(err) + } + + accepted := make(chan struct { + conn net.Conn + err error + }, 1) + go func() { + conn, err := ln.Accept() + accepted <- struct { + conn net.Conn + err error + }{conn: conn, err: err} + }() + + clientRaw, err := net.Dial("tcp", rawLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer clientRaw.Close() + + client, err := clientManager.WrapConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + + _ = client.SetDeadline(time.Now().Add(time.Second)) + + writeErr := make(chan error, 1) + go func() { + _, err := client.Write([]byte("payload")) + writeErr <- err + }() + + result := <-accepted + if result.err == nil { + if result.conn != nil { + result.conn.Close() + } + t.Fatal("expected wrapped listener accept to fail") + } + if result.conn != nil { + result.conn.Close() + t.Fatalf("expected no raw conn on wrapped listener failure, got %T", result.conn) + } + <-writeErr +} diff --git a/transport/internet/finalmask/udp_test.go b/transport/internet/finalmask/udp_test.go index f7a40f2d..7b7b7566 100644 --- a/transport/internet/finalmask/udp_test.go +++ b/transport/internet/finalmask/udp_test.go @@ -2,12 +2,16 @@ package finalmask_test import ( "bytes" + "context" + "encoding/binary" "io" "net" "sync/atomic" "testing" "time" + singM "github.com/sagernet/sing/common/metadata" + singN "github.com/sagernet/sing/common/network" "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" @@ -73,6 +77,194 @@ func (c *countingConn) Written() int64 { return c.written.Load() } +type recordedPacketWrite struct { + payload []byte + addr net.Addr +} + +type scriptedPacketConn struct { + local *net.UDPAddr + writes chan recordedPacketWrite + reads chan recordedPacketWrite + closed atomic.Bool + deadline atomic.Int64 +} + +func newScriptedPacketConn() *scriptedPacketConn { + return &scriptedPacketConn{ + local: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000}, + writes: make(chan recordedPacketWrite, 8), + reads: make(chan recordedPacketWrite, 8), + } +} + +func (c *scriptedPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + item, ok := <-c.reads + if !ok { + return 0, nil, io.EOF + } + copy(p, item.payload) + return len(item.payload), item.addr, nil +} + +func (c *scriptedPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.writes <- recordedPacketWrite{ + payload: append([]byte(nil), p...), + addr: addr, + } + return len(p), nil +} + +func (c *scriptedPacketConn) Close() error { + if c.closed.CompareAndSwap(false, true) { + close(c.reads) + } + return nil +} + +func (c *scriptedPacketConn) LocalAddr() net.Addr { return c.local } +func (c *scriptedPacketConn) SetDeadline(t time.Time) error { + c.deadline.Store(t.UnixNano()) + return nil +} +func (c *scriptedPacketConn) SetReadDeadline(t time.Time) error { + c.deadline.Store(t.UnixNano()) + return nil +} +func (c *scriptedPacketConn) SetWriteDeadline(t time.Time) error { + c.deadline.Store(t.UnixNano()) + return nil +} + +type captureUDPHandler struct { + gotMetadata chan singM.Metadata +} + +func (h *captureUDPHandler) NewConnection(_ context.Context, _ net.Conn, _ singM.Metadata) error { + return nil +} + +func (h *captureUDPHandler) NewPacketConnection(_ context.Context, _ singN.PacketConn, metadata singM.Metadata) error { + select { + case h.gotMetadata <- metadata: + default: + } + return nil +} + +func (h *captureUDPHandler) NewError(_ context.Context, _ error) {} + +func newStandaloneEchoUDPConfig() *custom.UDPConfig { + return &custom.UDPConfig{ + Mode: "standalone", + Client: []*custom.UDPItem{ + {Packet: []byte{0xAA}}, + {Rand: 2, Save: "txid"}, + }, + Server: []*custom.UDPItem{ + {Packet: []byte{0xBB}}, + {Var: "txid"}, + }, + } +} + +func newStandaloneStunLikeUDPConfig() *custom.UDPConfig { + return &custom.UDPConfig{ + Mode: "standalone", + Client: []*custom.UDPItem{ + {Packet: []byte{0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xA4, 0x42}}, + {Rand: 12, RandMin: 0x2A, RandMax: 0x2A, Save: "txid"}, + }, + Server: []*custom.UDPItem{ + {Packet: []byte{0x01, 0x01, 0x00, 0x0C, 0x21, 0x12, 0xA4, 0x42}}, + {Var: "txid"}, + {Packet: []byte{0x00, 0x20, 0x00, 0x08, 0x00, 0x01}}, + {Rand: 2, Save: "mapped_port"}, + {Rand: 4, Save: "mapped_ip"}, + }, + } +} + +func newStandaloneStunLikeUDPServerConfig() *custom.UDPConfig { + return &custom.UDPConfig{ + Mode: "standalone", + Client: []*custom.UDPItem{ + {Packet: []byte{0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xA4, 0x42}}, + {Rand: 12, RandMin: 0x2A, RandMax: 0x2A, Save: "txid"}, + }, + Server: []*custom.UDPItem{ + {Packet: []byte{0x01, 0x01, 0x00, 0x0C, 0x21, 0x12, 0xA4, 0x42}}, + {Var: "txid"}, + {Packet: []byte{0x00, 0x20, 0x00, 0x08, 0x00, 0x01}}, + { + Expr: &custom.Expr{ + Op: "be16", + Args: []*custom.ExprArg{ + { + Value: &custom.ExprArg_Expr{ + Expr: &custom.Expr{ + Op: "xor16", + Args: []*custom.ExprArg{ + {Value: &custom.ExprArg_Metadata{Metadata: "src_port_u16"}}, + {Value: &custom.ExprArg_U64{U64: 0x2112}}, + }, + }, + }, + }, + }, + }, + }, + { + Expr: &custom.Expr{ + Op: "be32", + Args: []*custom.ExprArg{ + { + Value: &custom.ExprArg_Expr{ + Expr: &custom.Expr{ + Op: "xor32", + Args: []*custom.ExprArg{ + {Value: &custom.ExprArg_Metadata{Metadata: "src_ip4_u32"}}, + {Value: &custom.ExprArg_U64{U64: 0x2112A442}}, + }, + }, + }, + }, + }, + }, + }, + }, + } +} + +func newUDPClientServerPair(t *testing.T, cfg *custom.UDPConfig) (net.PacketConn, net.PacketConn, net.PacketConn, net.PacketConn) { + t.Helper() + + clientRaw, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = clientRaw.Close() }) + + serverRaw, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = serverRaw.Close() }) + + maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{cfg}) + + client, err := maskManager.WrapPacketConnClient(clientRaw) + if err != nil { + t.Fatal(err) + } + server, err := maskManager.WrapPacketConnServer(serverRaw) + if err != nil { + t.Fatal(err) + } + + return clientRaw, serverRaw, client, server +} + func TestPacketConnReadWrite(t *testing.T) { cases := []layerMask{ { @@ -317,6 +509,197 @@ func TestUDPcustomServerRejectsMismatchedStaticHeader(t *testing.T) { } } +func TestUDPcustomStandaloneClientSendsDetachedHandshakeBeforePayload(t *testing.T) { + _, serverRaw, client, _ := newUDPClientServerPair(t, newStandaloneEchoUDPConfig()) + + payload := []byte("standalone-payload") + writeErr := make(chan error, 1) + go func() { + _, err := client.WriteTo(payload, serverRaw.LocalAddr()) + writeErr <- err + }() + + wire := make([]byte, 128) + _ = serverRaw.SetDeadline(time.Now().Add(time.Second)) + n, addr, err := serverRaw.ReadFrom(wire) + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Fatalf("unexpected handshake size: got=%d want=3", n) + } + if !bytes.Equal(wire[:1], []byte{0xAA}) { + t.Fatalf("unexpected handshake prefix: %x", wire[:1]) + } + txid := append([]byte(nil), wire[1:n]...) + + if _, err := serverRaw.WriteTo(append([]byte{0xBB}, txid...), addr); err != nil { + t.Fatal(err) + } + + n, _, err = serverRaw.ReadFrom(wire) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(wire[:n], payload) { + t.Fatalf("unexpected payload after handshake: %q", wire[:n]) + } + + if err := <-writeErr; err != nil { + t.Fatal(err) + } +} + +func TestUDPcustomStandaloneServerConsumesHandshakeAndAutoResponds(t *testing.T) { + clientRaw, _, _, server := newUDPClientServerPair(t, newStandaloneEchoUDPConfig()) + + _ = clientRaw.SetDeadline(time.Now().Add(time.Second)) + _ = server.SetDeadline(time.Now().Add(time.Second)) + + readPayload := make(chan []byte, 1) + readErr := make(chan error, 1) + go func() { + buf := make([]byte, 128) + n, _, err := server.ReadFrom(buf) + if err != nil { + readErr <- err + return + } + readPayload <- append([]byte(nil), buf[:n]...) + }() + + txid := []byte{0x10, 0x20} + if _, err := clientRaw.WriteTo(append([]byte{0xAA}, txid...), server.LocalAddr()); err != nil { + t.Fatal(err) + } + + buf := make([]byte, 128) + n, _, err := clientRaw.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf[:n], append([]byte{0xBB}, txid...)) { + t.Fatalf("unexpected auto-response: %x", buf[:n]) + } + + payload := []byte("server-side-standalone") + if _, err := clientRaw.WriteTo(payload, server.LocalAddr()); err != nil { + t.Fatal(err) + } + + select { + case got := <-readPayload: + if !bytes.Equal(got, payload) { + t.Fatalf("unexpected payload: %q", got) + } + case err := <-readErr: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("payload read timeout") + } +} + +func TestUDPcustomStandaloneStunLikeExchangeUsesSavedTxidAndSrcMetadata(t *testing.T) { + clientRaw, _, _, server := newUDPClientServerPair(t, newStandaloneStunLikeUDPServerConfig()) + + _ = clientRaw.SetDeadline(time.Now().Add(time.Second)) + _ = server.SetDeadline(time.Now().Add(time.Second)) + + readPayload := make(chan []byte, 1) + readErr := make(chan error, 1) + go func() { + buf := make([]byte, 64) + n, _, err := server.ReadFrom(buf) + if err != nil { + readErr <- err + return + } + readPayload <- append([]byte(nil), buf[:n]...) + }() + + txid := bytes.Repeat([]byte{0x2A}, 12) + request := append([]byte{0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xA4, 0x42}, txid...) + if _, err := clientRaw.WriteTo(request, server.LocalAddr()); err != nil { + t.Fatal(err) + } + + buf := make([]byte, 64) + n, _, err := clientRaw.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + + want := make([]byte, 0, 32) + want = append(want, []byte{0x01, 0x01, 0x00, 0x0C, 0x21, 0x12, 0xA4, 0x42}...) + want = append(want, txid...) + want = append(want, []byte{0x00, 0x20, 0x00, 0x08, 0x00, 0x01}...) + + clientAddr := clientRaw.LocalAddr().(*net.UDPAddr) + xPort := uint16(clientAddr.Port) ^ 0x2112 + xIP := binary.BigEndian.Uint32(clientAddr.IP.To4()) ^ 0x2112A442 + want = append(want, byte(xPort>>8), byte(xPort)) + want = append(want, byte(xIP>>24), byte(xIP>>16), byte(xIP>>8), byte(xIP)) + + if !bytes.Equal(buf[:n], want) { + t.Fatalf("unexpected stun-like response: got=%x want=%x", buf[:n], want) + } + + payload := []byte("after-standalone-stun") + if _, err := clientRaw.WriteTo(payload, server.LocalAddr()); err != nil { + t.Fatal(err) + } + + select { + case got := <-readPayload: + if !bytes.Equal(got, payload) { + t.Fatalf("unexpected payload after stun exchange: %q", got) + } + case err := <-readErr: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("payload read timeout") + } +} + +func TestUDPcustomStandaloneClientHandshakeSurvivesConcurrentReader(t *testing.T) { + _, serverRaw, clientMask, serverMask := newUDPClientServerPair(t, newStandaloneStunLikeUDPConfig()) + + go func() { + buf := make([]byte, 2048) + _ = clientMask.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, _, _ = clientMask.ReadFrom(buf) + }() + + go func() { + buf := make([]byte, 2048) + for { + n, addr, err := serverMask.ReadFrom(buf) + if err != nil { + return + } + if n == len([]byte("dns-payload")) && string(buf[:n]) == "dns-payload" { + return + } + _ = addr + } + }() + + writeDone := make(chan error, 1) + go func() { + _, err := clientMask.WriteTo([]byte("dns-payload"), serverRaw.LocalAddr()) + writeDone <- err + }() + + select { + case err := <-writeDone: + if err != nil { + t.Fatal(err) + } + case <-time.After(2 * time.Second): + t.Fatal("expected handshake to complete even with concurrent reader") + } +} + func TestSudokuBDD(t *testing.T) { t.Run("GivenSudokuTCPMask_WhenRoundTripWithAsciiPreference_ThenPayloadMatches", func(t *testing.T) { cfg := &sudoku.Config{ diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index c930c355..06f4f474 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -2,7 +2,7 @@ package udp import ( "context" - reflect "reflect" + "reflect" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" @@ -33,6 +33,7 @@ func init() { return nil, errors.New("mask err").Base(err) } c.PacketConn = pktConn + errors.LogInfo(ctx, "finalmask udp dialer: wrapped existing PacketConnWrapper with ", reflect.TypeOf(pktConn)) case *net.UDPConn: pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) if err != nil { @@ -43,6 +44,7 @@ func init() { PacketConn: pktConn, Dest: c.RemoteAddr().(*net.UDPAddr), } + errors.LogInfo(ctx, "finalmask udp dialer: wrapped UDPConn with ", reflect.TypeOf(pktConn)) case *cnc.Connection: fakeConn := &internet.FakePacketConn{Conn: c} pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) @@ -57,6 +59,7 @@ func init() { Port: 0, }, } + errors.LogInfo(ctx, "finalmask udp dialer: wrapped cnc.Connection with ", reflect.TypeOf(pktConn)) default: conn.Close() return nil, errors.New("unknown conn ", reflect.TypeOf(c))