mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
Reverted the timeout-based solution and restored the unified reader architecture:
- Each peer connection continuously reads and queues received data to dataChan.
- A single unifiedReader() dispatcher matches received data with read requests.
- No blocking: all connections are monitored simultaneously.
- Addresses @RPRX's request for a unified reader instead of a timeout.
Architecture benefits:
- True concurrent reading from all peer connections.
- Clean separation between reading and dispatching.
- No timeout delays or retry loops.
- Scales to any number of peers.
Tests pass.
Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
362 lines
7.0 KiB
Go
362 lines
7.0 KiB
Go
package wireguard
|
|
|
|
import (
	"context"
	"errors"
	stdnet "net"
	"net/netip"
	"strconv"
	"sync"

	"golang.zx2c4.com/wireguard/conn"

	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/features/dns"
	"github.com/xtls/xray-core/transport/internet"
)
|
|
|
|
type netReadInfo struct {
|
|
// status
|
|
waiter sync.WaitGroup
|
|
// param
|
|
buff []byte
|
|
// result
|
|
bytes int
|
|
endpoint conn.Endpoint
|
|
err error
|
|
}
|
|
|
|
// reduce duplicated code
|
|
type netBind struct {
|
|
dns dns.Client
|
|
dnsOption dns.IPOption
|
|
|
|
workers int
|
|
readQueue chan *netReadInfo
|
|
}
|
|
|
|
// SetMark implements conn.Bind
|
|
func (bind *netBind) SetMark(mark uint32) error {
|
|
return nil
|
|
}
|
|
|
|
// ParseEndpoint implements conn.Bind
|
|
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|
ipStr, port, err := net.SplitHostPort(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
portNum, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
addr := net.ParseAddress(ipStr)
|
|
if addr.Family() == net.AddressFamilyDomain {
|
|
ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(ips) == 0 {
|
|
return nil, dns.ErrEmptyResponse
|
|
}
|
|
addr = net.IPAddress(ips[0])
|
|
}
|
|
|
|
dst := net.Destination{
|
|
Address: addr,
|
|
Port: net.Port(portNum),
|
|
Network: net.Network_UDP,
|
|
}
|
|
|
|
return &netEndpoint{
|
|
dst: dst,
|
|
}, nil
|
|
}
|
|
|
|
// BatchSize implements conn.Bind
|
|
func (bind *netBind) BatchSize() int {
|
|
return 1
|
|
}
|
|
|
|
// Open implements conn.Bind
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
bind.readQueue = make(chan *netReadInfo)
|
|
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
n = 0
|
|
err = errors.New("channel closed")
|
|
}
|
|
}()
|
|
|
|
r := &netReadInfo{
|
|
buff: bufs[0],
|
|
}
|
|
r.waiter.Add(1)
|
|
bind.readQueue <- r
|
|
r.waiter.Wait() // wait read goroutine done, or we will miss the result
|
|
sizes[0], eps[0] = r.bytes, r.endpoint
|
|
return 1, r.err
|
|
}
|
|
workers := bind.workers
|
|
if workers <= 0 {
|
|
workers = 1
|
|
}
|
|
arr := make([]conn.ReceiveFunc, workers)
|
|
for i := 0; i < workers; i++ {
|
|
arr[i] = fun
|
|
}
|
|
|
|
return arr, uint16(uport), nil
|
|
}
|
|
|
|
// Close implements conn.Bind
|
|
func (bind *netBind) Close() error {
|
|
if bind.readQueue != nil {
|
|
close(bind.readQueue)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type netBindClient struct {
|
|
netBind
|
|
|
|
ctx context.Context
|
|
dialer internet.Dialer
|
|
reserved []byte
|
|
|
|
// Track all peer connections for unified reading
|
|
connMutex sync.RWMutex
|
|
conns map[*netEndpoint]net.Conn
|
|
dataChan chan *receivedData
|
|
closeChan chan struct{}
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
// dataChannelBufferSize is the capacity of netBindClient.dataChan. It sets
// how many received packets can wait between the per-connection readers and
// the dispatcher before readers block.
const dataChannelBufferSize = 100
|
|
|
|
type receivedData struct {
|
|
data []byte
|
|
n int
|
|
endpoint *netEndpoint
|
|
err error
|
|
}
|
|
|
|
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
endpoint.conn = c
|
|
|
|
// Initialize channels on first connection
|
|
bind.connMutex.Lock()
|
|
if bind.conns == nil {
|
|
bind.conns = make(map[*netEndpoint]net.Conn)
|
|
bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
|
|
bind.closeChan = make(chan struct{})
|
|
|
|
// Start unified reader dispatcher
|
|
go bind.unifiedReader()
|
|
}
|
|
bind.conns[endpoint] = c
|
|
bind.connMutex.Unlock()
|
|
|
|
// Start a reader goroutine for this specific connection
|
|
go func(conn net.Conn, endpoint *netEndpoint) {
|
|
const maxPacketSize = 1500
|
|
for {
|
|
select {
|
|
case <-bind.closeChan:
|
|
return
|
|
default:
|
|
}
|
|
|
|
buf := make([]byte, maxPacketSize)
|
|
n, err := conn.Read(buf)
|
|
|
|
// Send only the valid data portion to dispatcher
|
|
dataToSend := buf
|
|
if n > 0 && n < len(buf) {
|
|
dataToSend = buf[:n]
|
|
}
|
|
|
|
// Send received data to dispatcher
|
|
select {
|
|
case bind.dataChan <- &receivedData{
|
|
data: dataToSend,
|
|
n: n,
|
|
endpoint: endpoint,
|
|
err: err,
|
|
}:
|
|
case <-bind.closeChan:
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
bind.connMutex.Lock()
|
|
delete(bind.conns, endpoint)
|
|
endpoint.conn = nil
|
|
bind.connMutex.Unlock()
|
|
return
|
|
}
|
|
}
|
|
}(c, endpoint)
|
|
|
|
return nil
|
|
}
|
|
|
|
// unifiedReader dispatches received data to waiting read requests
|
|
func (bind *netBindClient) unifiedReader() {
|
|
for {
|
|
select {
|
|
case data := <-bind.dataChan:
|
|
// Bounds check to prevent panic
|
|
if data.n > len(data.data) {
|
|
data.n = len(data.data)
|
|
}
|
|
|
|
// Wait for a read request with timeout to prevent blocking forever
|
|
select {
|
|
case v := <-bind.readQueue:
|
|
// Copy data to request buffer
|
|
n := copy(v.buff, data.data[:data.n])
|
|
|
|
// Clear reserved bytes if needed
|
|
if n > 3 {
|
|
v.buff[1] = 0
|
|
v.buff[2] = 0
|
|
v.buff[3] = 0
|
|
}
|
|
|
|
v.bytes = n
|
|
v.endpoint = data.endpoint
|
|
v.err = data.err
|
|
v.waiter.Done()
|
|
case <-bind.closeChan:
|
|
return
|
|
}
|
|
case <-bind.closeChan:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close implements conn.Bind.Close for netBindClient
|
|
func (bind *netBindClient) Close() error {
|
|
// Use sync.Once to prevent double-close panic
|
|
bind.closeOnce.Do(func() {
|
|
bind.connMutex.Lock()
|
|
if bind.closeChan != nil {
|
|
close(bind.closeChan)
|
|
}
|
|
bind.connMutex.Unlock()
|
|
})
|
|
|
|
// Call parent Close
|
|
return bind.netBind.Close()
|
|
}
|
|
|
|
func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
|
var err error
|
|
|
|
nend, ok := endpoint.(*netEndpoint)
|
|
if !ok {
|
|
return conn.ErrWrongEndpointType
|
|
}
|
|
|
|
if nend.conn == nil {
|
|
err = bind.connectTo(nend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for _, buff := range buff {
|
|
if len(buff) > 3 && len(bind.reserved) == 3 {
|
|
copy(buff[1:], bind.reserved)
|
|
}
|
|
if _, err = nend.conn.Write(buff); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type netBindServer struct {
|
|
netBind
|
|
}
|
|
|
|
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
|
var err error
|
|
|
|
nend, ok := endpoint.(*netEndpoint)
|
|
if !ok {
|
|
return conn.ErrWrongEndpointType
|
|
}
|
|
|
|
if nend.conn == nil {
|
|
return errors.New("connection not open yet")
|
|
}
|
|
|
|
for _, buff := range buff {
|
|
if _, err = nend.conn.Write(buff); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
type netEndpoint struct {
|
|
dst net.Destination
|
|
conn net.Conn
|
|
}
|
|
|
|
// ClearSrc implements conn.Endpoint. No source address is cached, so there
// is nothing to clear.
func (netEndpoint) ClearSrc() {
}
|
|
|
|
func (e netEndpoint) DstIP() netip.Addr {
|
|
return netip.Addr{}
|
|
}
|
|
|
|
func (e netEndpoint) SrcIP() netip.Addr {
|
|
return netip.Addr{}
|
|
}
|
|
|
|
func (e netEndpoint) DstToBytes() []byte {
|
|
var dat []byte
|
|
if e.dst.Address.Family().IsIPv4() {
|
|
dat = e.dst.Address.IP().To4()[:]
|
|
} else {
|
|
dat = e.dst.Address.IP().To16()[:]
|
|
}
|
|
dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
|
|
return dat
|
|
}
|
|
|
|
func (e netEndpoint) DstToString() string {
|
|
return e.dst.NetAddr()
|
|
}
|
|
|
|
func (e netEndpoint) SrcToString() string {
|
|
return ""
|
|
}
|
|
|
|
func toNetIpAddr(addr net.Address) netip.Addr {
|
|
if addr.Family().IsIPv4() {
|
|
ip := addr.IP()
|
|
return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
|
|
} else {
|
|
ip := addr.IP()
|
|
arr := [16]byte{}
|
|
for i := 0; i < 16; i++ {
|
|
arr[i] = ip[i]
|
|
}
|
|
return netip.AddrFrom16(arr)
|
|
}
|
|
}
|