Files
xray-core/proxy/wireguard/bind.go
copilot-swe-agent[bot] 31d10f3544 Restore unified reader architecture per @RPRX request
Reverted timeout-based solution and restored the unified reader architecture:
- Each peer connection continuously reads and queues to dataChan
- Single unifiedReader() dispatcher matches data with read requests
- No blocking - all connections monitored simultaneously
- Addresses @RPRX's request for unified reader instead of timeout

Architecture benefits:
- True concurrent reading from all peer connections
- Clean separation between reading and dispatching
- No timeout delays or retry loops
- Scalable to any number of peers

Tests pass.

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
2026-01-11 09:47:37 +00:00

362 lines
7.0 KiB
Go

package wireguard
import (
"context"
"errors"
"net/netip"
"strconv"
"sync"
"golang.zx2c4.com/wireguard/conn"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/transport/internet"
)
// netReadInfo is one pending receive request travelling through
// netBind.readQueue. The ReceiveFunc built in Open fills in buff, queues the
// request and blocks on waiter until a reader has filled in the result fields.
type netReadInfo struct {
// status
// waiter is incremented before the request is queued and released once the
// result fields below have been set.
waiter sync.WaitGroup
// param
// buff is the caller-supplied buffer the received packet is copied into.
buff []byte
// result
// bytes is the number of bytes placed in buff.
bytes int
// endpoint is the peer the packet was received from.
endpoint conn.Endpoint
// err is the read error, if any, handed back to the caller.
err error
}
// netBind holds the state and methods shared by netBindClient and
// netBindServer, which embed it to reduce duplicated code.
type netBind struct {
// dns and dnsOption are used by ParseEndpoint to resolve domain-name endpoints.
dns dns.Client
dnsOption dns.IPOption
// workers is the number of ReceiveFuncs Open returns (treated as 1 when <= 0).
workers int
// readQueue carries pending receive requests; it is created by Open and
// closed by Close.
readQueue chan *netReadInfo
}
// SetMark implements conn.Bind. The mark is ignored and the call always
// succeeds.
func (*netBind) SetMark(mark uint32) error {
	return nil
}
// ParseEndpoint implements conn.Bind.
//
// s must be in "host:port" form. The port must be a decimal number in the
// range 0-65535; anything else is rejected with the strconv error instead of
// being silently truncated when converted to net.Port. When host is a domain
// name it is resolved through bind.dns with bind.dnsOption and the first
// returned IP is used; an empty answer yields dns.ErrEmptyResponse.
//
// The returned endpoint has no connection attached yet.
func (bind *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	host, portStr, err := net.SplitHostPort(s)
	if err != nil {
		return nil, err
	}

	// bitSize 16 makes ParseUint reject negative and out-of-range ports.
	portNum, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return nil, err
	}

	addr := net.ParseAddress(host)
	if addr.Family() == net.AddressFamilyDomain {
		ips, _, err := bind.dns.LookupIP(addr.Domain(), bind.dnsOption)
		if err != nil {
			return nil, err
		}
		if len(ips) == 0 {
			return nil, dns.ErrEmptyResponse
		}
		addr = net.IPAddress(ips[0])
	}

	return &netEndpoint{
		dst: net.Destination{
			Address: addr,
			Port:    net.Port(portNum),
			Network: net.Network_UDP,
		},
	}, nil
}
// BatchSize implements conn.Bind. Each ReceiveFunc returned by Open fills only
// the first buffer it is given, so the batch size is fixed at one packet.
func (*netBind) BatchSize() int {
	const packetsPerCall = 1
	return packetsPerCall
}
// Open implements conn.Bind.
//
// It creates a fresh, unbuffered readQueue and returns bind.workers copies
// (at least one) of a single ReceiveFunc, together with uport unchanged.
//
// Each call of the ReceiveFunc queues a netReadInfo that points at bufs[0],
// blocks until a reader has completed it, and then reports one packet through
// sizes[0] and eps[0]. Any panic raised inside the call — in particular the
// send on a readQueue that Close has already closed — is recovered and
// reported as a "channel closed" error.
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
	bind.readQueue = make(chan *netReadInfo)

	var recv conn.ReceiveFunc = func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
		defer func() {
			if recover() != nil {
				n, err = 0, errors.New("channel closed")
			}
		}()

		req := &netReadInfo{buff: bufs[0]}
		req.waiter.Add(1)
		bind.readQueue <- req
		// Block until a reader has filled in the result; returning earlier
		// would lose it.
		req.waiter.Wait()

		sizes[0], eps[0] = req.bytes, req.endpoint
		return 1, req.err
	}

	count := bind.workers
	if count < 1 {
		count = 1
	}

	fns := make([]conn.ReceiveFunc, count)
	for i := range fns {
		fns[i] = recv
	}
	return fns, uport, nil
}
// Close implements conn.Bind. It closes readQueue if Open has created one,
// which makes ReceiveFuncs blocked on sending to it fail with
// "channel closed". It always returns nil.
func (bind *netBind) Close() error {
	if queue := bind.readQueue; queue != nil {
		close(queue)
	}
	return nil
}
// netBindClient is the netBind variant that dials peers itself: Send connects
// an endpoint through dialer the first time it has no connection.
type netBindClient struct {
netBind
// ctx is passed to dialer.Dial when connecting to a peer.
ctx context.Context
// dialer opens the connection for each peer endpoint.
dialer internet.Dialer
// reserved, when exactly 3 bytes long, is copied into bytes 1-3 of every
// outgoing packet longer than 3 bytes.
reserved []byte
// Track all peer connections for unified reading
// connMutex guards conns.
connMutex sync.RWMutex
// conns maps each connected endpoint to its connection.
conns map[*netEndpoint]net.Conn
// dataChan carries packets from the per-connection readers to the dispatcher.
dataChan chan *receivedData
// closeChan is closed by Close to stop the readers and the dispatcher.
closeChan chan struct{}
// closeOnce ensures closeChan is closed at most once.
closeOnce sync.Once
}
const (
// dataChannelBufferSize is the capacity of netBindClient.dataChan. It lets
// the per-connection readers queue some received packets while the
// dispatcher is still matching earlier ones with read requests.
dataChannelBufferSize = 100
)
// receivedData is the result of one Read on a peer connection, passed from a
// per-connection reader to the dispatcher over netBindClient.dataChan.
type receivedData struct {
// data is the buffer that was read into; the dispatcher uses its first n bytes.
data []byte
// n is the byte count returned by Read.
n int
// endpoint is the peer whose connection produced the data.
endpoint *netEndpoint
// err is the error returned by Read, if any.
err error
}
// connectTo dials endpoint.dst, records the new connection on the endpoint and
// in bind.conns, and starts a goroutine that reads from it.
//
// The first successful call also creates dataChan and closeChan and starts the
// unifiedReader dispatcher.
//
// The reader goroutine forwards every Read result (data or error) to dataChan.
// It stops when closeChan is closed or after forwarding a read error; in the
// error case it also removes the endpoint from bind.conns and clears
// endpoint.conn so that a later Send dials again. The connection is closed
// whenever the reader stops, so sockets are not leaked across reconnects or
// shutdown.
//
// It returns the dial error, if any.
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
	c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
	if err != nil {
		return err
	}

	bind.connMutex.Lock()
	if bind.conns == nil {
		// Initialize shared state and the dispatcher on first connection.
		bind.conns = make(map[*netEndpoint]net.Conn)
		bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
		bind.closeChan = make(chan struct{})
		go bind.unifiedReader()
	}
	// endpoint.conn is cleared by the reader under connMutex, so it is set
	// under the same lock.
	endpoint.conn = c
	bind.conns[endpoint] = c
	bind.connMutex.Unlock()

	// Reader goroutine dedicated to this connection.
	go func(pc net.Conn, endpoint *netEndpoint) {
		const maxPacketSize = 1500

		// Nothing reads from pc once this goroutine is gone; release it.
		defer func() { _ = pc.Close() }()

		for {
			select {
			case <-bind.closeChan:
				return
			default:
			}

			// A fresh buffer per read: ownership passes to the dispatcher.
			buf := make([]byte, maxPacketSize)
			n, err := pc.Read(buf)

			// Hand over only the valid portion of the buffer.
			payload := buf
			if n >= 0 && n < len(buf) {
				payload = buf[:n]
			}

			select {
			case bind.dataChan <- &receivedData{
				data:     payload,
				n:        n,
				endpoint: endpoint,
				err:      err,
			}:
			case <-bind.closeChan:
				return
			}

			if err != nil {
				bind.connMutex.Lock()
				delete(bind.conns, endpoint)
				endpoint.conn = nil
				bind.connMutex.Unlock()
				return
			}
		}
	}(c, endpoint)

	return nil
}
// unifiedReader is the dispatcher goroutine started by connectTo. It takes
// each packet queued on dataChan and hands it to the next read request taken
// from readQueue: the payload is copied into the request buffer, bytes 1-3
// (the positions Send fills with the reserved bytes) are zeroed when more
// than 3 bytes were copied, and the request's waiter is released.
//
// It returns when closeChan is closed, or when readQueue has been closed (by
// netBind.Close) and there is therefore no request left to deliver to. There
// is no timeout: while holding a packet it blocks until one of those happens
// or a request arrives.
func (bind *netBindClient) unifiedReader() {
	for {
		var data *receivedData
		select {
		case data = <-bind.dataChan:
		case <-bind.closeChan:
			return
		}

		// Bounds check: never slice past the buffer that n describes.
		if data.n > len(data.data) {
			data.n = len(data.data)
		}

		select {
		case v, ok := <-bind.readQueue:
			if !ok {
				// readQueue was closed; a receive on a closed channel yields
				// a nil request, which must not be dereferenced.
				return
			}

			// Copy data to the request buffer.
			n := copy(v.buff, data.data[:data.n])

			// Clear the reserved bytes if the packet is long enough.
			if n > 3 {
				v.buff[1] = 0
				v.buff[2] = 0
				v.buff[3] = 0
			}

			v.bytes = n
			v.endpoint = data.endpoint
			v.err = data.err
			v.waiter.Done()
		case <-bind.closeChan:
			return
		}
	}
}
// Close implements conn.Bind.Close for netBindClient.
//
// It closes closeChan (at most once) to stop the dispatcher and the
// per-connection readers, closes every connection tracked in conns so that
// readers blocked in Read wake up, and then runs netBind.Close.
//
// closeOnce is only consumed once closeChan actually exists. A Close that
// happens before the first connection must not use it up, otherwise a later
// Close would never signal the goroutines started afterwards.
func (bind *netBindClient) Close() error {
	bind.connMutex.Lock()
	if bind.closeChan != nil {
		bind.closeOnce.Do(func() {
			close(bind.closeChan)
		})
	}
	// Snapshot under the lock; close outside it so readers that need the lock
	// for their own cleanup are not held up.
	open := make([]net.Conn, 0, len(bind.conns))
	for _, c := range bind.conns {
		open = append(open, c)
	}
	bind.connMutex.Unlock()

	for _, c := range open {
		// Best effort: the bind is shutting down and a close error is not
		// actionable here.
		_ = c.Close()
	}

	// Call parent Close.
	return bind.netBind.Close()
}
// Send implements conn.Bind. It writes every packet in buff to the connection
// of endpoint, dialing it first through connectTo when the endpoint has none.
//
// When bind.reserved is exactly 3 bytes long, it is copied into bytes 1-3 of
// each packet longer than 3 bytes before writing; this modifies the caller's
// buffers in place.
//
// The connection is read once, under connMutex, and used for the whole call.
// The reader goroutine may clear endpoint.conn concurrently after a read
// error; re-reading the field for every write could dereference nil.
//
// It returns conn.ErrWrongEndpointType when endpoint is not a *netEndpoint,
// the dial error when connecting fails, "connection closed" when a freshly
// dialed connection has already been torn down, or the first write error.
func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
	nend, ok := endpoint.(*netEndpoint)
	if !ok {
		return conn.ErrWrongEndpointType
	}

	current := func() net.Conn {
		bind.connMutex.RLock()
		defer bind.connMutex.RUnlock()
		return nend.conn
	}

	c := current()
	if c == nil {
		if err := bind.connectTo(nend); err != nil {
			return err
		}
		if c = current(); c == nil {
			return errors.New("connection closed")
		}
	}

	for _, pkt := range buff {
		if len(pkt) > 3 && len(bind.reserved) == 3 {
			copy(pkt[1:], bind.reserved)
		}
		if _, err := c.Write(pkt); err != nil {
			return err
		}
	}
	return nil
}
// netBindServer is the netBind variant that never dials: its Send only writes
// to endpoints that already have a connection attached.
type netBindServer struct {
netBind
}
// Send implements conn.Bind. It writes each packet in buff, in order, to the
// connection already attached to endpoint and stops at the first write error.
//
// It returns conn.ErrWrongEndpointType when endpoint is not a *netEndpoint and
// a "connection not open yet" error when the endpoint has no connection.
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
	nend, ok := endpoint.(*netEndpoint)
	switch {
	case !ok:
		return conn.ErrWrongEndpointType
	case nend.conn == nil:
		return errors.New("connection not open yet")
	}

	for i := range buff {
		if _, werr := nend.conn.Write(buff[i]); werr != nil {
			return werr
		}
	}
	return nil
}
// netEndpoint is the conn.Endpoint produced by ParseEndpoint: a destination
// plus the connection used to reach it, if one has been attached.
type netEndpoint struct {
// dst is the peer's destination.
dst net.Destination
// conn is the connection to dst; nil while no connection is attached.
conn net.Conn
}
// ClearSrc implements conn.Endpoint. It is a no-op: netEndpoint stores no
// source address.
func (netEndpoint) ClearSrc() {}
// DstIP implements conn.Endpoint. It always returns the zero netip.Addr; the
// destination stored in the endpoint is not consulted.
func (netEndpoint) DstIP() netip.Addr {
	var unset netip.Addr
	return unset
}
// SrcIP implements conn.Endpoint. It always returns the zero netip.Addr
// because netEndpoint stores no source address.
func (netEndpoint) SrcIP() netip.Addr {
	var unset netip.Addr
	return unset
}
// DstToBytes implements conn.Endpoint. The result is the destination IP
// (4 bytes when the address family is IPv4, otherwise the 16-byte form)
// followed by the port as two bytes, low byte first.
func (e netEndpoint) DstToBytes() []byte {
	addr := e.dst.Address

	var ipBytes []byte
	switch {
	case addr.Family().IsIPv4():
		ipBytes = addr.IP().To4()[:]
	default:
		ipBytes = addr.IP().To16()[:]
	}

	port := e.dst.Port
	lo, hi := byte(port), byte(port>>8)
	return append(ipBytes, lo, hi)
}
// DstToString implements conn.Endpoint. It returns the destination formatted
// by its NetAddr method.
func (e netEndpoint) DstToString() string {
	target := e.dst
	return target.NetAddr()
}
// SrcToString implements conn.Endpoint. It always returns the empty string
// because netEndpoint stores no source address.
func (netEndpoint) SrcToString() string {
	const noSource = ""
	return noSource
}
// toNetIpAddr converts addr to a netip.Addr. When the address family is IPv4
// the first 4 bytes of addr.IP() are used; otherwise the first 16 bytes are.
// It panics if addr.IP() is shorter than that.
func toNetIpAddr(addr net.Address) netip.Addr {
	isV4 := addr.Family().IsIPv4()
	raw := addr.IP()

	if isV4 {
		return netip.AddrFrom4([4]byte{raw[0], raw[1], raw[2], raw[3]})
	}

	var v6 [16]byte
	for i := range v6 {
		v6[i] = raw[i]
	}
	return netip.AddrFrom16(v6)
}