XDNS finalmask: Support resolvers (client) and domains (server) instead of domain (#5872)

https://github.com/XTLS/Xray-core/pull/5872#issuecomment-4192730898

Example: https://github.com/XTLS/Xray-core/pull/5872#issuecomment-4196172391

---------

Co-authored-by: LjhAUMEM <llnu14702@gmail.com>
This commit is contained in:
Nikita Nemirovsky
2026-04-12 03:37:32 +08:00
committed by GitHub
parent a91a88c7b2
commit 1642fdfbdd
10 changed files with 266 additions and 112 deletions

View File

@@ -1662,16 +1662,30 @@ func (c *Sudoku) Build() (proto.Message, error) {
}
type Xdns struct {
Domain string `json:"domain"`
Domain json.RawMessage `json:"domain"`
Domains []string `json:"domains"`
Resolvers []string `json:"resolvers"`
}
func (c *Xdns) Build() (proto.Message, error) {
if c.Domain == "" {
return nil, errors.New("empty domain")
if c.Domain != nil {
return nil, errors.PrintRemovedFeatureError("domain", "domains(server) & resolvers(client)")
}
if len(c.Domains) == 0 && len(c.Resolvers) == 0 {
return nil, errors.New("empty domains & empty resolvers")
}
for _, r := range c.Resolvers {
if !strings.Contains(r, "+udp://") {
return nil, errors.New("invalid resolver ", r)
}
}
return &xdns.Config{
Domain: c.Domain,
Domains: c.Domains,
Resolvers: c.Resolvers,
}, nil
}

View File

@@ -9,7 +9,10 @@ import (
go_errors "errors"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/xtls/xray-core/common"
@@ -34,10 +37,14 @@ type packet struct {
}
type xdnsConnClient struct {
net.PacketConn
conn net.PacketConn
resolverConns []net.PacketConn
resolverAddrs []*net.UDPAddr
resolverIdx uint32
resolverSend []atomic.Uint32
clientID []byte
domain Name
domains []Name
pollChan chan struct{}
readQueue chan *packet
@@ -48,16 +55,66 @@ type xdnsConnClient struct {
}
func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
domain, err := ParseName(c.Domain)
if err != nil {
return nil, err
if len(c.Resolvers) == 0 {
return nil, errors.New("empty resolvers")
}
var domains []Name
var servers []string
for _, rs := range c.Resolvers {
parts := strings.Split(rs, "+udp://")
if len(parts) != 2 {
return nil, errors.New("invalid resolvers")
}
domain, err := ParseName(parts[0])
if err != nil {
return nil, err
}
domains = append(domains, domain)
servers = append(servers, parts[1])
}
var resolverConns []net.PacketConn
var resolverAddrs []*net.UDPAddr
var resolverSend []atomic.Uint32
for _, rs := range servers {
h, p, err := net.SplitHostPort(rs)
if err != nil {
return nil, err
}
ip := net.ParseIP(h)
if ip == nil {
return nil, errors.New("invalid ip address")
}
port, _ := strconv.Atoi(p)
if port == 0 {
return nil, errors.New("invalid port")
}
var uc net.PacketConn
if ip.To4() != nil {
uc, err = net.ListenPacket("udp4", ":0")
} else {
uc, err = net.ListenPacket("udp6", ":0")
}
if err != nil {
for _, rc := range resolverConns {
rc.Close()
}
return nil, errors.New("failed to create resolver socket: ", err)
}
resolverConns = append(resolverConns, uc)
resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port})
}
resolverSend = make([]atomic.Uint32, len(resolverConns))
conn := &xdnsConnClient{
PacketConn: raw,
conn: raw,
resolverConns: resolverConns,
resolverAddrs: resolverAddrs,
resolverSend: resolverSend,
clientID: make([]byte, 8),
domain: domain,
domains: domains,
pollChan: make(chan struct{}, pollLimit),
readQueue: make(chan *packet, 256),
@@ -73,58 +130,70 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
}
func (c *xdnsConnClient) recvLoop() {
var buf [finalmask.UDPSize]byte
var wg sync.WaitGroup
for {
if c.closed {
break
}
for i, rc := range c.resolverConns {
wg.Add(1)
go func() {
defer wg.Done()
n, addr, err := c.PacketConn.ReadFrom(buf[:])
if err != nil || n == 0 {
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
break
var buf [finalmask.UDPSize]byte
for {
if c.closed {
break
}
n, addr, err := rc.ReadFrom(buf[:])
if err != nil {
if go_errors.Is(err, net.ErrClosed) {
break
}
continue
}
resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}
payload := dnsResponsePayload(&resp, c.domains)
r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}
if anyPacket {
c.resolverSend[i].Store(0)
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
continue
}
resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}
payload := dnsResponsePayload(&resp, c.domain)
r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}
if anyPacket {
select {
case c.pollChan <- struct{}{}:
default:
}
}
}()
}
wg.Wait()
errors.LogDebug(context.Background(), "xdns closed")
close(c.pollChan)
@@ -138,8 +207,6 @@ func (c *xdnsConnClient) recvLoop() {
}
func (c *xdnsConnClient) sendLoop() {
var addr net.Addr
pollDelay := initPollDelay
pollTimer := time.NewTimer(pollDelay)
for {
@@ -158,17 +225,14 @@ func (c *xdnsConnClient) sendLoop() {
}
if p != nil {
addr = p.addr
select {
case <-c.pollChan:
default:
}
} else if addr != nil {
encoded, _ := encode(nil, c.clientID, c.domain)
} else {
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx])
p = &packet{
p: encoded,
addr: addr,
p: encoded,
}
}
@@ -189,10 +253,16 @@ func (c *xdnsConnClient) sendLoop() {
return
}
if p != nil {
_, err := c.PacketConn.WriteTo(p.p, p.addr)
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
c.closed = true
cur := c.resolverIdx
curSend := c.resolverSend[c.resolverIdx].Add(1)
_, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx])
for {
c.resolverIdx += 1
c.resolverIdx %= uint32(len(c.resolverConns))
if c.resolverIdx == cur {
break
}
if c.resolverSend[c.resolverIdx].Load() < curSend {
break
}
}
@@ -220,7 +290,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}
encoded, err := encode(p, c.clientID, c.domain)
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
return 0, nil
@@ -240,7 +310,35 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
func (c *xdnsConnClient) Close() error {
c.closed = true
return c.PacketConn.Close()
for _, rc := range c.resolverConns {
rc.Close()
}
return c.conn.Close()
}
func (c *xdnsConnClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *xdnsConnClient) SetDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetDeadline(t)
}
return c.conn.SetDeadline(t)
}
func (c *xdnsConnClient) SetReadDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetReadDeadline(t)
}
return c.conn.SetReadDeadline(t)
}
func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetWriteDeadline(t)
}
return c.conn.SetWriteDeadline(t)
}
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
@@ -332,7 +430,7 @@ func nextPacket(r *bytes.Reader) ([]byte, error) {
return p, err
}
func dnsResponsePayload(resp *Message, domain Name) []byte {
func dnsResponsePayload(resp *Message, domains []Name) []byte {
if resp.Flags&0x8000 != 0x8000 {
return nil
}
@@ -345,7 +443,13 @@ func dnsResponsePayload(resp *Message, domain Name) []byte {
}
answer := resp.Answer[0]
_, ok := answer.Name.TrimSuffix(domain)
var ok bool
for _, domain := range domains {
_, ok = answer.Name.TrimSuffix(domain)
if ok {
break
}
}
if !ok {
return nil
}

View File

@@ -2,15 +2,27 @@ package xdns
import (
"net"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria/udphop"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
_, ok1 := raw.(*internet.FakePacketConn)
_, ok2 := raw.(*udphop.UdpHopPacketConn)
if level != 0 || ok1 || ok2 {
return nil, errors.New("xdns requires being at the outermost level")
}
return NewConnClient(c, raw)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
if level != 0 {
return nil, errors.New("xdns requires being at the outermost level")
}
return NewConnServer(c, raw)
}

View File

@@ -23,7 +23,8 @@ const (
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"`
Domains []string `protobuf:"bytes,1,rep,name=domains,proto3" json:"domains,omitempty"`
Resolvers []string `protobuf:"bytes,2,rep,name=resolvers,proto3" json:"resolvers,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -58,20 +59,28 @@ func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetDomain() string {
func (x *Config) GetDomains() []string {
if x != nil {
return x.Domain
return x.Domains
}
return ""
return nil
}
func (x *Config) GetResolvers() []string {
if x != nil {
return x.Resolvers
}
return nil
}
var File_transport_internet_finalmask_xdns_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_xdns_config_proto_rawDesc = "" +
"\n" +
".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\" \n" +
"\x06Config\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domainB\x94\x01\n" +
".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\"@\n" +
"\x06Config\x12\x18\n" +
"\adomains\x18\x01 \x03(\tR\adomains\x12\x1c\n" +
"\tresolvers\x18\x02 \x03(\tR\tresolversB\x94\x01\n" +
"*com.xray.transport.internet.finalmask.xdnsP\x01Z;github.com/xtls/xray-core/transport/internet/finalmask/xdns\xaa\x02&Xray.Transport.Internet.Finalmask.Xdnsb\x06proto3"
var (

View File

@@ -7,6 +7,6 @@ option java_package = "com.xray.transport.internet.finalmask.xdns";
option java_multiple_files = true;
message Config {
string domain = 1;
}
repeated string domains = 1;
repeated string resolvers = 2;
}

View File

@@ -559,6 +559,7 @@ func TestEncodeRDataTXT(t *testing.T) {
}
fmt.Println(EncodeRDataTXT(nil))
fmt.Println(computeMaxEncodedPayload(maxUDPPayload))
}
func TestRDataTXTRoundTrip(t *testing.T) {

View File

@@ -52,7 +52,7 @@ type queue struct {
type xdnsConnServer struct {
net.PacketConn
domain Name
domains []Name
ch chan *record
readQueue chan *packet
@@ -63,15 +63,22 @@ type xdnsConnServer struct {
}
func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) {
domain, err := ParseName(c.Domain)
if err != nil {
return nil, err
if len(c.Domains) == 0 {
return nil, errors.New("empty domains")
}
domains := make([]Name, 0, len(c.Domains))
for _, domain := range c.Domains {
domain, err := ParseName(domain)
if err != nil {
return nil, err
}
domains = append(domains, domain)
}
conn := &xdnsConnServer{
PacketConn: raw,
domain: domain,
domains: domains,
ch: make(chan *record, 500),
readQueue: make(chan *packet, 512),
@@ -156,8 +163,8 @@ func (c *xdnsConnServer) recvLoop() {
}
n, addr, err := c.PacketConn.ReadFrom(buf[:])
if err != nil || n == 0 {
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
if err != nil {
if go_errors.Is(err, net.ErrClosed) {
break
}
continue
@@ -169,7 +176,7 @@ func (c *xdnsConnServer) recvLoop() {
continue
}
resp, payload := responseFor(&query, c.domain)
resp, payload := responseFor(&query, c.domains)
var clientID [8]byte
n = copy(clientID[:], payload)
@@ -321,7 +328,7 @@ func (c *xdnsConnServer) sendLoop() {
}
_, err = c.PacketConn.WriteTo(buf, rec.Addr)
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
if go_errors.Is(err, net.ErrClosed) {
c.closed = true
break
}
@@ -399,7 +406,7 @@ func nextPacketServer(r *bytes.Reader) ([]byte, error) {
}
}
func responseFor(query *Message, domain Name) (*Message, []byte) {
func responseFor(query *Message, domains []Name) (*Message, []byte) {
resp := &Message{
ID: query.ID,
Flags: 0x8000,
@@ -447,7 +454,14 @@ func responseFor(query *Message, domain Name) (*Message, []byte) {
}
question := query.Question[0]
prefix, ok := question.Name.TrimSuffix(domain)
var prefix Name
var ok bool
for _, domain := range domains {
prefix, ok = question.Name.TrimSuffix(domain)
if ok {
break
}
}
if !ok {
resp.Flags |= RcodeNameError
return resp, nil
@@ -525,7 +539,7 @@ func computeMaxEncodedPayload(limit int) int {
},
},
}
resp, _ := responseFor(query, [][]byte{})
resp, _ := responseFor(query, []Name{[][]byte{}})
resp.Answer = []RR{
{

View File

@@ -10,9 +10,7 @@ import (
"github.com/xtls/xray-core/common/crypto"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/finalmask"
"github.com/xtls/xray-core/transport/internet/hysteria/udphop"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@@ -54,13 +52,7 @@ type xicmpConnClient struct {
mutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) {
_, ok1 := raw.(*internet.FakePacketConn)
_, ok2 := raw.(*udphop.UdpHopPacketConn)
if level != 0 || ok1 || ok2 {
return nil, errors.New("xicmp requires being at the outermost level")
}
func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
network := "ip4:icmp"
typ := icmp.Type(ipv4.ICMPTypeEcho)
proto := 1

View File

@@ -2,15 +2,27 @@ package xicmp
import (
"net"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria/udphop"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
return NewConnClient(c, raw, level)
_, ok1 := raw.(*internet.FakePacketConn)
_, ok2 := raw.(*udphop.UdpHopPacketConn)
if level != 0 || ok1 || ok2 {
return nil, errors.New("xicmp requires being at the outermost level")
}
return NewConnClient(c, raw)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
return NewConnServer(c, raw, level)
if level != 0 {
return nil, errors.New("xicmp requires being at the outermost level")
}
return NewConnServer(c, raw)
}

View File

@@ -50,11 +50,7 @@ type xicmpConnServer struct {
mutex sync.Mutex
}
func NewConnServer(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) {
if level != 0 {
return nil, errors.New("xicmp requires being at the outermost level")
}
func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) {
network := "ip4:icmp"
typ := icmp.Type(ipv4.ICMPTypeEchoReply)
proto := 1