mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)
Fixes https://github.com/XTLS/Xray-core/pull/5554 Fixes https://github.com/XTLS/Xray-core/issues/4760
This commit is contained in:
@@ -36,7 +36,7 @@ type serverityLogger struct {
|
|||||||
func NewLogger(logWriterCreator WriterCreator) Handler {
|
func NewLogger(logWriterCreator WriterCreator) Handler {
|
||||||
return &generalLogger{
|
return &generalLogger{
|
||||||
creator: logWriterCreator,
|
creator: logWriterCreator,
|
||||||
buffer: make(chan Message, 16),
|
buffer: make(chan Message, 128),
|
||||||
access: semaphore.New(1),
|
access: semaphore.New(1),
|
||||||
done: done.New(),
|
done: done.New(),
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
|
|||||||
w := CreateStdoutLogWriter()
|
w := CreateStdoutLogWriter()
|
||||||
g := &generalLogger{
|
g := &generalLogger{
|
||||||
creator: w,
|
creator: w,
|
||||||
buffer: make(chan Message, 16),
|
buffer: make(chan Message, 128),
|
||||||
access: semaphore.New(1),
|
access: semaphore.New(1),
|
||||||
done: done.New(),
|
done: done.New(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,27 +2,23 @@ package wireguard
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
gonet "net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common/errors"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"github.com/xtls/xray-core/features/dns"
|
"github.com/xtls/xray-core/features/dns"
|
||||||
"github.com/xtls/xray-core/transport/internet"
|
"github.com/xtls/xray-core/transport/internet"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netReadInfo struct {
|
type netReadInfo struct {
|
||||||
// status
|
|
||||||
waiter sync.WaitGroup
|
|
||||||
// param
|
|
||||||
buff []byte
|
buff []byte
|
||||||
// result
|
|
||||||
bytes int
|
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
err error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce duplicated code
|
// reduce duplicated code
|
||||||
@@ -32,6 +28,7 @@ type netBind struct {
|
|||||||
|
|
||||||
workers int
|
workers int
|
||||||
readQueue chan *netReadInfo
|
readQueue chan *netReadInfo
|
||||||
|
closedCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMark implements conn.Bind
|
// SetMark implements conn.Bind
|
||||||
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
|
|||||||
|
|
||||||
// Open implements conn.Bind
|
// Open implements conn.Bind
|
||||||
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||||
bind.readQueue = make(chan *netReadInfo)
|
bind.closedCh = make(chan struct{})
|
||||||
|
errors.LogDebug(context.Background(), "bind opened")
|
||||||
|
|
||||||
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||||
defer func() {
|
select {
|
||||||
if r := recover(); r != nil {
|
case r := <-bind.readQueue:
|
||||||
n = 0
|
sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
|
||||||
err = errors.New("channel closed")
|
return 1, nil
|
||||||
|
case <-bind.closedCh:
|
||||||
|
errors.LogDebug(context.Background(), "recv func closed")
|
||||||
|
return 0, gonet.ErrClosed
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
r, ok := <-bind.readQueue
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("channel closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(bufs[0], r.buff[:r.bytes])
|
|
||||||
sizes[0], eps[0] = r.bytes, r.endpoint
|
|
||||||
r.waiter.Done()
|
|
||||||
return 1, r.err
|
|
||||||
}
|
}
|
||||||
workers := bind.workers
|
workers := bind.workers
|
||||||
|
if workers <= 0 {
|
||||||
|
workers = runtime.NumCPU()
|
||||||
|
}
|
||||||
if workers <= 0 {
|
if workers <= 0 {
|
||||||
workers = 1
|
workers = 1
|
||||||
}
|
}
|
||||||
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|||||||
|
|
||||||
// Close implements conn.Bind
|
// Close implements conn.Bind
|
||||||
func (bind *netBind) Close() error {
|
func (bind *netBind) Close() error {
|
||||||
if bind.readQueue != nil {
|
errors.LogDebug(context.Background(), "bind closed")
|
||||||
close(bind.readQueue)
|
if bind.closedCh != nil {
|
||||||
|
close(bind.closedCh)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|||||||
}
|
}
|
||||||
endpoint.conn = c
|
endpoint.conn = c
|
||||||
|
|
||||||
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
|
go func() {
|
||||||
defer func() {
|
|
||||||
_ = recover() // handle send on closed channel
|
|
||||||
}()
|
|
||||||
for {
|
for {
|
||||||
buff := make([]byte, 1700)
|
buff := make([]byte, device.MaxMessageSize)
|
||||||
i, err := c.Read(buff)
|
n, err := c.Read(buff)
|
||||||
|
|
||||||
if i > 3 {
|
if err != nil {
|
||||||
|
endpoint.conn = nil
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 3 {
|
||||||
buff[1] = 0
|
buff[1] = 0
|
||||||
buff[2] = 0
|
buff[2] = 0
|
||||||
buff[3] = 0
|
buff[3] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
r := &netReadInfo{
|
select {
|
||||||
buff: buff,
|
case bind.readQueue <- &netReadInfo{
|
||||||
bytes: i,
|
buff: buff[:n],
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
err: err,
|
}:
|
||||||
}
|
case <-bind.closedCh:
|
||||||
r.waiter.Add(1)
|
|
||||||
readQueue <- r
|
|
||||||
r.waiter.Wait()
|
|
||||||
if err != nil {
|
|
||||||
endpoint.conn = nil
|
endpoint.conn = nil
|
||||||
|
c.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(bind.readQueue, endpoint)
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nend.conn == nil {
|
if nend.conn == nil {
|
||||||
return errors.New("connection not open yet")
|
errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
|
||||||
|
return errors.New("peer closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, buff := range buff {
|
for _, buff := range buff {
|
||||||
|
|||||||
@@ -122,6 +122,7 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
|
|||||||
IPv6Enable: h.hasIPv6,
|
IPv6Enable: h.hasIPv6,
|
||||||
},
|
},
|
||||||
workers: int(h.conf.NumWorkers),
|
workers: int(h.conf.NumWorkers),
|
||||||
|
readQueue: make(chan *netReadInfo),
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package wireguard
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
goerrors "errors"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common/buf"
|
"github.com/xtls/xray-core/common/buf"
|
||||||
c "github.com/xtls/xray-core/common/ctx"
|
c "github.com/xtls/xray-core/common/ctx"
|
||||||
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
|||||||
IPv4Enable: hasIPv4,
|
IPv4Enable: hasIPv4,
|
||||||
IPv6Enable: hasIPv6,
|
IPv6Enable: hasIPv6,
|
||||||
},
|
},
|
||||||
|
workers: int(conf.NumWorkers),
|
||||||
|
readQueue: make(chan *netReadInfo),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||||
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
|||||||
|
|
||||||
reader := buf.NewPacketReader(conn)
|
reader := buf.NewPacketReader(conn)
|
||||||
for {
|
for {
|
||||||
mpayload, err := reader.ReadMultiBuffer()
|
mb, err := reader.ReadMultiBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
nep.conn = nil
|
||||||
|
buf.ReleaseMulti(mb)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, payload := range mpayload {
|
for i, b := range mb {
|
||||||
v, ok := <-s.bindServer.readQueue
|
buff := b.Bytes()
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
i, err := payload.Read(v.buff)
|
|
||||||
|
|
||||||
v.bytes = i
|
if b.Len() > 3 {
|
||||||
v.endpoint = nep
|
buff[1] = 0
|
||||||
v.err = err
|
buff[2] = 0
|
||||||
v.waiter.Done()
|
buff[3] = 0
|
||||||
if err != nil && goerrors.Is(err, io.EOF) {
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.bindServer.readQueue <- &netReadInfo{
|
||||||
|
buff: buff,
|
||||||
|
endpoint: nep,
|
||||||
|
}:
|
||||||
|
case <-s.bindServer.closedCh:
|
||||||
nep.conn = nil
|
nep.conn = nil
|
||||||
return nil
|
buf.ReleaseMulti(mb[i:])
|
||||||
|
return errors.New("bind closed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
|||||||
// Currently we have no way to link to the original source address
|
// Currently we have no way to link to the original source address
|
||||||
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
|
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
|
||||||
ctx = session.ContextWithInbound(ctx, &inbound)
|
ctx = session.ContextWithInbound(ctx, &inbound)
|
||||||
|
content := new(session.Content)
|
||||||
if s.info.contentTag != nil {
|
if s.info.contentTag != nil {
|
||||||
ctx = session.ContextWithContent(ctx, s.info.contentTag)
|
content.SniffingRequest = s.info.contentTag.SniffingRequest
|
||||||
}
|
}
|
||||||
|
ctx = session.ContextWithContent(ctx, content)
|
||||||
ctx = session.SubContextFromMuxInbound(ctx)
|
ctx = session.SubContextFromMuxInbound(ctx)
|
||||||
|
|
||||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||||
|
|||||||
@@ -8,25 +8,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common"
|
"github.com/xtls/xray-core/common"
|
||||||
"github.com/xtls/xray-core/common/log"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var wgLogger = &device.Logger{
|
|
||||||
Verbosef: func(format string, args ...any) {
|
|
||||||
log.Record(&log.GeneralMessage{
|
|
||||||
Severity: log.Severity_Debug,
|
|
||||||
Content: fmt.Sprintf(format, args...),
|
|
||||||
})
|
|
||||||
},
|
|
||||||
Errorf: func(format string, args ...any) {
|
|
||||||
log.Record(&log.GeneralMessage{
|
|
||||||
Severity: log.Severity_Error,
|
|
||||||
Content: fmt.Sprintf(format, args...),
|
|
||||||
})
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||||
deviceConfig := config.(*DeviceConfig)
|
deviceConfig := config.(*DeviceConfig)
|
||||||
|
|||||||
Reference in New Issue
Block a user