mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
WireGuard: Implement UDP FullCone NAT (#5833)
Fixes https://github.com/XTLS/Xray-core/issues/5601 --------- Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
@@ -53,7 +53,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
|
||||
return
|
||||
}
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
|
||||
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
|
||||
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun" || inbound.Name == "wireguard") {
|
||||
h := blake3.New(8, BaseKey)
|
||||
h.Write([]byte(inbound.Source.String()))
|
||||
copy(globalID[:], h.Sum(nil))
|
||||
|
||||
2
go.mod
2
go.mod
@@ -26,7 +26,7 @@ require (
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
google.golang.org/grpc v1.79.3
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -131,6 +131,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
|
||||
|
||||
@@ -130,7 +130,7 @@ func ParseWireGuardKey(str string) (string, error) {
|
||||
return "", errors.New("key must not be empty")
|
||||
}
|
||||
|
||||
if len(str)%2 == 0 {
|
||||
if len(str) == 64 {
|
||||
_, err = hex.DecodeString(str)
|
||||
if err == nil {
|
||||
return str, nil
|
||||
|
||||
@@ -227,6 +227,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn = &udpConnClient{
|
||||
Conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
@@ -336,3 +341,34 @@ func (h *Handler) createIPCRequest() string {
|
||||
|
||||
return request.String()[:request.Len()]
|
||||
}
|
||||
|
||||
type udpConnClient struct {
|
||||
net.Conn
|
||||
dest net.Destination
|
||||
}
|
||||
|
||||
func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
b := buf.New()
|
||||
b.Resize(0, buf.Size)
|
||||
n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes())
|
||||
if err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
if addr == nil { // should never hit
|
||||
addr = c.dest.RawNetAddr()
|
||||
}
|
||||
b.Resize(0, int32(n))
|
||||
|
||||
b.UDP = &net.Destination{
|
||||
Address: net.IPAddress(addr.(*net.UDPAddr).IP),
|
||||
Port: net.Port(addr.(*net.UDPAddr).Port),
|
||||
Network: net.Network_UDP,
|
||||
}
|
||||
|
||||
return buf.MultiBuffer{b}, nil
|
||||
}
|
||||
|
||||
func (c *udpConnClient) Write(p []byte) (int, error) {
|
||||
return c.Conn.(net.PacketConn).WriteTo(p, c.dest.RawNetAddr())
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ type netTun struct {
|
||||
ep *channel.Endpoint
|
||||
stack *stack.Stack
|
||||
events chan tun.Event
|
||||
notifyHandle *channel.NotificationHandle
|
||||
incomingPacket chan *buffer.View
|
||||
mtu int
|
||||
hasV4, hasV6 bool
|
||||
@@ -48,12 +49,17 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
|
||||
dev := &netTun{
|
||||
ep: channel.New(1024, uint32(mtu), ""),
|
||||
stack: stack.New(opts),
|
||||
events: make(chan tun.Event, 1),
|
||||
events: make(chan tun.Event, 10),
|
||||
incomingPacket: make(chan *buffer.View),
|
||||
mtu: mtu,
|
||||
}
|
||||
dev.ep.AddNotify(dev)
|
||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
||||
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
||||
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
||||
}
|
||||
dev.notifyHandle = dev.ep.AddNotify(dev)
|
||||
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||
}
|
||||
@@ -90,20 +96,10 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
|
||||
dev.stack.SetSpoofing(1, true)
|
||||
}
|
||||
|
||||
opt := tcpip.CongestionControlOption("cubic")
|
||||
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
|
||||
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
|
||||
}
|
||||
|
||||
dev.events <- tun.EventUp
|
||||
return dev, (*Net)(dev), dev.stack, nil
|
||||
}
|
||||
|
||||
// BatchSize implements tun.Device
|
||||
func (tun *netTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Name implements tun.Device
|
||||
func (tun *netTun) Name() (string, error) {
|
||||
return "go", nil
|
||||
@@ -120,7 +116,6 @@ func (tun *netTun) Events() <-chan tun.Event {
|
||||
}
|
||||
|
||||
// Read implements tun.Device
|
||||
|
||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||
view, ok := <-tun.incomingPacket
|
||||
if !ok {
|
||||
@@ -169,20 +164,16 @@ func (tun *netTun) WriteNotify() {
|
||||
tun.incomingPacket <- view
|
||||
}
|
||||
|
||||
// Flush implements tun.Device
|
||||
func (tun *netTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements tun.Device
|
||||
func (tun *netTun) Close() error {
|
||||
tun.closeOnce.Do(func() {
|
||||
tun.stack.RemoveNIC(1)
|
||||
tun.stack.Close()
|
||||
tun.ep.RemoveNotify(tun.notifyHandle)
|
||||
tun.ep.Close()
|
||||
|
||||
close(tun.events)
|
||||
|
||||
tun.ep.Close()
|
||||
|
||||
close(tun.incomingPacket)
|
||||
})
|
||||
return nil
|
||||
@@ -193,6 +184,11 @@ func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
// BatchSize implements tun.Device
|
||||
func (tun *netTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
@@ -224,6 +220,7 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
|
||||
var addr tcpip.FullAddress
|
||||
addr, pn = convertToFullAddr(raddr)
|
||||
rfa = &addr
|
||||
rfa = nil // do not ep connect
|
||||
}
|
||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||
}
|
||||
|
||||
@@ -5,19 +5,17 @@ import (
|
||||
goerrors "errors"
|
||||
"io"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
c "github.com/xtls/xray-core/common/ctx"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
)
|
||||
|
||||
@@ -31,10 +29,10 @@ type Server struct {
|
||||
}
|
||||
|
||||
type routingInfo struct {
|
||||
ctx context.Context
|
||||
dispatcher routing.Dispatcher
|
||||
inboundTag *session.Inbound
|
||||
contentTag *session.Content
|
||||
ctx context.Context
|
||||
dispatcher routing.Dispatcher
|
||||
inboundTag *session.Inbound
|
||||
contentTag *session.Content
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
@@ -124,7 +122,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
||||
sid := session.NewID()
|
||||
@@ -146,9 +143,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
}
|
||||
ctx = session.SubContextFromMuxInbound(ctx)
|
||||
|
||||
plcy := s.policyManager.ForLevel(0)
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
From: nullDestination,
|
||||
To: dest,
|
||||
@@ -156,35 +150,15 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
Reason: "",
|
||||
})
|
||||
|
||||
link, err := s.info.dispatcher.Dispatch(ctx, dest)
|
||||
err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||
Reader: buf.NewReader(conn),
|
||||
Writer: buf.NewWriter(conn),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "dispatch connection")
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all TCP request").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
errors.LogInfoInner(ctx, err, "connection ends")
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all TCP response").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
errors.LogDebugInner(ctx, err, "connection ends")
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package wireguard
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -10,12 +11,17 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
|
||||
|
||||
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
||||
out := &gvisorNet{}
|
||||
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||
tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
// handler is only used for promiscuous mode
|
||||
// capture all packets and send to handler
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||
tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||
go func(r *tcp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
var wq waiter.Queue
|
||||
var id = r.ID()
|
||||
|
||||
// Perform a TCP three-way handshake.
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
errors.LogError(context.Background(), err.String())
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
defer ep.Close()
|
||||
|
||||
// enable tcp keep-alive to prevent hanging connections
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
options := ep.SocketOptions()
|
||||
options.SetKeepAlive(false)
|
||||
options.SetReuseAddress(true)
|
||||
options.SetReusePort(true)
|
||||
|
||||
// local address is actually destination
|
||||
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
||||
|
||||
ep.Close()
|
||||
r.Complete(false)
|
||||
}(r)
|
||||
})
|
||||
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
|
||||
go func(r *udp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
errors.LogError(context.Background(), err.String())
|
||||
return
|
||||
}
|
||||
defer ep.Close()
|
||||
|
||||
// prevents hanging connections and ensure timely release
|
||||
ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
||||
Enabled: true,
|
||||
Timeout: 15 * time.Second,
|
||||
})
|
||||
|
||||
handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
|
||||
}(r)
|
||||
manager := &udpManager{
|
||||
stack: gstack,
|
||||
handler: handler,
|
||||
m: make(map[string]*udpConn),
|
||||
}
|
||||
|
||||
gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
data := pkt.Clone().Data().AsRange().ToSlice()
|
||||
// if len(data) == 0 {
|
||||
// return false
|
||||
// }
|
||||
src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
||||
dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
|
||||
manager.feed(src, dst, data)
|
||||
return true
|
||||
})
|
||||
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
}
|
||||
|
||||
out.tun, out.net = tun, n
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type udpManager struct {
|
||||
stack *stack.Stack
|
||||
handler func(dest net.Destination, conn net.Conn)
|
||||
m map[string]*udpConn
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
|
||||
m.mutex.RLock()
|
||||
uc, ok := m.m[src.NetAddr()]
|
||||
if ok {
|
||||
select {
|
||||
case uc.ch <- data:
|
||||
default:
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
uc, ok = m.m[src.NetAddr()]
|
||||
if !ok {
|
||||
uc = &udpConn{
|
||||
ch: make(chan []byte, 1024),
|
||||
src: src,
|
||||
dst: dst,
|
||||
}
|
||||
uc.writeFunc = m.writeRawUDPPacket
|
||||
uc.closeFunc = func() {
|
||||
m.mutex.Lock()
|
||||
m.close(uc)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
m.m[src.NetAddr()] = uc
|
||||
go m.handler(dst, uc)
|
||||
}
|
||||
|
||||
select {
|
||||
case uc.ch <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpManager) close(uc *udpConn) {
|
||||
if !uc.closed {
|
||||
uc.closed = true
|
||||
close(uc.ch)
|
||||
delete(m.m, uc.src.NetAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
srcIP := tcpip.AddrFromSlice(src.Address.IP())
|
||||
dstIP := tcpip.AddrFromSlice(dst.Address.IP())
|
||||
|
||||
// build packet with appropriate IP header size
|
||||
isIPv4 := dst.Address.Family().IsIPv4()
|
||||
ipHdrSize := header.IPv6MinimumSize
|
||||
ipProtocol := header.IPv6ProtocolNumber
|
||||
if isIPv4 {
|
||||
ipHdrSize = header.IPv4MinimumSize
|
||||
ipProtocol = header.IPv4ProtocolNumber
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: uint16(src.Port),
|
||||
DstPort: uint16(dst.Port),
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Calculate and set UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
|
||||
|
||||
// Build IP header
|
||||
if isIPv4 {
|
||||
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + udpLen),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.UDPProtocolNumber),
|
||||
SrcAddr: srcIP,
|
||||
DstAddr: dstIP,
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
} else {
|
||||
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(udpLen),
|
||||
TransportProtocol: header.UDPProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: srcIP,
|
||||
DstAddr: dstIP,
|
||||
})
|
||||
}
|
||||
|
||||
// dispatch the packet
|
||||
err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
|
||||
if err != nil {
|
||||
return errors.New("failed to write raw udp packet back to stack err ", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
ch chan []byte
|
||||
src net.Destination
|
||||
dst net.Destination
|
||||
writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
|
||||
closeFunc func()
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *udpConn) Read(p []byte) (int, error) {
|
||||
b, ok := <-c.ch
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, b)
|
||||
if n != len(b) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
for i, b := range mb {
|
||||
dst := c.dst
|
||||
if b.UDP != nil {
|
||||
dst = *b.UDP
|
||||
}
|
||||
err := c.writeFunc(b.Bytes(), dst, c.src)
|
||||
if err != nil {
|
||||
buf.ReleaseMulti(mb[i:])
|
||||
return err
|
||||
}
|
||||
b.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Write(p []byte) (int, error) {
|
||||
err := c.writeFunc(p, c.dst, c.src)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Close() error {
|
||||
c.closeFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) LocalAddr() net.Addr {
|
||||
return c.src.RawNetAddr() // fake
|
||||
}
|
||||
|
||||
func (c *udpConn) RemoteAddr() net.Addr {
|
||||
return c.src.RawNetAddr() // src
|
||||
}
|
||||
|
||||
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -100,32 +100,39 @@ func (m *udpSessionManagerServer) run() {
|
||||
func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
|
||||
m.mutex.RLock()
|
||||
udpConn, ok := m.m[id]
|
||||
if ok {
|
||||
select {
|
||||
case udpConn.ch <- d:
|
||||
default:
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
udpConn, ok = m.m[id]
|
||||
if !ok {
|
||||
m.mutex.Lock()
|
||||
udpConn, ok = m.m[id]
|
||||
if !ok {
|
||||
udpConn = &InterUdpConn{
|
||||
conn: m.conn,
|
||||
local: m.conn.LocalAddr(),
|
||||
remote: m.conn.RemoteAddr(),
|
||||
udpConn = &InterUdpConn{
|
||||
conn: m.conn,
|
||||
local: m.conn.LocalAddr(),
|
||||
remote: m.conn.RemoteAddr(),
|
||||
|
||||
id: id,
|
||||
ch: make(chan []byte, udpMessageChanSize),
|
||||
last: time.Now(),
|
||||
id: id,
|
||||
ch: make(chan []byte, udpMessageChanSize),
|
||||
last: time.Now(),
|
||||
|
||||
user: m.user,
|
||||
}
|
||||
udpConn.closeFunc = func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.close(udpConn)
|
||||
}
|
||||
m.m[id] = udpConn
|
||||
m.addConn(udpConn)
|
||||
user: m.user,
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
udpConn.closeFunc = func() {
|
||||
m.mutex.Lock()
|
||||
m.close(udpConn)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
m.m[id] = udpConn
|
||||
m.addConn(udpConn)
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user