mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)
Fixes https://github.com/XTLS/Xray-core/pull/5554 Fixes https://github.com/XTLS/Xray-core/issues/4760
This commit is contained in:
@@ -36,7 +36,7 @@ type serverityLogger struct {
|
||||
func NewLogger(logWriterCreator WriterCreator) Handler {
|
||||
return &generalLogger{
|
||||
creator: logWriterCreator,
|
||||
buffer: make(chan Message, 16),
|
||||
buffer: make(chan Message, 128),
|
||||
access: semaphore.New(1),
|
||||
done: done.New(),
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
|
||||
w := CreateStdoutLogWriter()
|
||||
g := &generalLogger{
|
||||
creator: w,
|
||||
buffer: make(chan Message, 16),
|
||||
buffer: make(chan Message, 128),
|
||||
access: semaphore.New(1),
|
||||
done: done.New(),
|
||||
}
|
||||
|
||||
@@ -2,27 +2,23 @@ package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
gonet "net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
)
|
||||
|
||||
type netReadInfo struct {
|
||||
// status
|
||||
waiter sync.WaitGroup
|
||||
// param
|
||||
buff []byte
|
||||
// result
|
||||
bytes int
|
||||
buff []byte
|
||||
endpoint conn.Endpoint
|
||||
err error
|
||||
}
|
||||
|
||||
// reduce duplicated code
|
||||
@@ -32,6 +28,7 @@ type netBind struct {
|
||||
|
||||
workers int
|
||||
readQueue chan *netReadInfo
|
||||
closedCh chan struct{}
|
||||
}
|
||||
|
||||
// SetMark implements conn.Bind
|
||||
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
|
||||
|
||||
// Open implements conn.Bind
|
||||
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
bind.readQueue = make(chan *netReadInfo)
|
||||
bind.closedCh = make(chan struct{})
|
||||
errors.LogDebug(context.Background(), "bind opened")
|
||||
|
||||
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
n = 0
|
||||
err = errors.New("channel closed")
|
||||
}
|
||||
}()
|
||||
|
||||
r, ok := <-bind.readQueue
|
||||
if !ok {
|
||||
return 0, errors.New("channel closed")
|
||||
select {
|
||||
case r := <-bind.readQueue:
|
||||
sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
|
||||
return 1, nil
|
||||
case <-bind.closedCh:
|
||||
errors.LogDebug(context.Background(), "recv func closed")
|
||||
return 0, gonet.ErrClosed
|
||||
}
|
||||
|
||||
copy(bufs[0], r.buff[:r.bytes])
|
||||
sizes[0], eps[0] = r.bytes, r.endpoint
|
||||
r.waiter.Done()
|
||||
return 1, r.err
|
||||
}
|
||||
workers := bind.workers
|
||||
if workers <= 0 {
|
||||
workers = runtime.NumCPU()
|
||||
}
|
||||
if workers <= 0 {
|
||||
workers = 1
|
||||
}
|
||||
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
|
||||
// Close implements conn.Bind
|
||||
func (bind *netBind) Close() error {
|
||||
if bind.readQueue != nil {
|
||||
close(bind.readQueue)
|
||||
errors.LogDebug(context.Background(), "bind closed")
|
||||
if bind.closedCh != nil {
|
||||
close(bind.closedCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
||||
}
|
||||
endpoint.conn = c
|
||||
|
||||
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
|
||||
defer func() {
|
||||
_ = recover() // handle send on closed channel
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
buff := make([]byte, 1700)
|
||||
i, err := c.Read(buff)
|
||||
buff := make([]byte, device.MaxMessageSize)
|
||||
n, err := c.Read(buff)
|
||||
|
||||
if i > 3 {
|
||||
if err != nil {
|
||||
endpoint.conn = nil
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if n > 3 {
|
||||
buff[1] = 0
|
||||
buff[2] = 0
|
||||
buff[3] = 0
|
||||
}
|
||||
|
||||
r := &netReadInfo{
|
||||
buff: buff,
|
||||
bytes: i,
|
||||
select {
|
||||
case bind.readQueue <- &netReadInfo{
|
||||
buff: buff[:n],
|
||||
endpoint: endpoint,
|
||||
err: err,
|
||||
}
|
||||
r.waiter.Add(1)
|
||||
readQueue <- r
|
||||
r.waiter.Wait()
|
||||
if err != nil {
|
||||
}:
|
||||
case <-bind.closedCh:
|
||||
endpoint.conn = nil
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}(bind.readQueue, endpoint)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||
}
|
||||
|
||||
if nend.conn == nil {
|
||||
return errors.New("connection not open yet")
|
||||
errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
|
||||
return errors.New("peer closed")
|
||||
}
|
||||
|
||||
for _, buff := range buff {
|
||||
|
||||
@@ -121,7 +121,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
|
||||
IPv4Enable: h.hasIPv4,
|
||||
IPv6Enable: h.hasIPv6,
|
||||
},
|
||||
workers: int(h.conf.NumWorkers),
|
||||
workers: int(h.conf.NumWorkers),
|
||||
readQueue: make(chan *netReadInfo),
|
||||
},
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
|
||||
@@ -2,8 +2,6 @@ package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
goerrors "errors"
|
||||
"io"
|
||||
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
c "github.com/xtls/xray-core/common/ctx"
|
||||
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
IPv4Enable: hasIPv4,
|
||||
IPv6Enable: hasIPv6,
|
||||
},
|
||||
workers: int(conf.NumWorkers),
|
||||
readQueue: make(chan *netReadInfo),
|
||||
},
|
||||
},
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
||||
|
||||
reader := buf.NewPacketReader(conn)
|
||||
for {
|
||||
mpayload, err := reader.ReadMultiBuffer()
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
nep.conn = nil
|
||||
buf.ReleaseMulti(mb)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, payload := range mpayload {
|
||||
v, ok := <-s.bindServer.readQueue
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
i, err := payload.Read(v.buff)
|
||||
for i, b := range mb {
|
||||
buff := b.Bytes()
|
||||
|
||||
v.bytes = i
|
||||
v.endpoint = nep
|
||||
v.err = err
|
||||
v.waiter.Done()
|
||||
if err != nil && goerrors.Is(err, io.EOF) {
|
||||
if b.Len() > 3 {
|
||||
buff[1] = 0
|
||||
buff[2] = 0
|
||||
buff[3] = 0
|
||||
}
|
||||
|
||||
select {
|
||||
case s.bindServer.readQueue <- &netReadInfo{
|
||||
buff: buff,
|
||||
endpoint: nep,
|
||||
}:
|
||||
case <-s.bindServer.closedCh:
|
||||
nep.conn = nil
|
||||
return nil
|
||||
buf.ReleaseMulti(mb[i:])
|
||||
return errors.New("bind closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
// Currently we have no way to link to the original source address
|
||||
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
|
||||
ctx = session.ContextWithInbound(ctx, &inbound)
|
||||
content := new(session.Content)
|
||||
if s.info.contentTag != nil {
|
||||
ctx = session.ContextWithContent(ctx, s.info.contentTag)
|
||||
content.SniffingRequest = s.info.contentTag.SniffingRequest
|
||||
}
|
||||
ctx = session.ContextWithContent(ctx, content)
|
||||
ctx = session.SubContextFromMuxInbound(ctx)
|
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
|
||||
@@ -8,25 +8,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
var wgLogger = &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
deviceConfig := config.(*DeviceConfig)
|
||||
|
||||
Reference in New Issue
Block a user