diff --git a/proxy/tun/stack_gvisor.go b/proxy/tun/stack_gvisor.go index ab767c61..8bcc4ebe 100644 --- a/proxy/tun/stack_gvisor.go +++ b/proxy/tun/stack_gvisor.go @@ -105,17 +105,23 @@ func (t *stackGVisor) Start() error { // Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support udpForwarder := newUdpConnectionHandler(t.handler.HandleConnection, t.writeRawUDPPacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - data := pkt.Data().AsRange().ToSlice() - if len(data) == 0 { - return false - } + data := pkt.Clone().Data().AsRange().ToSlice() + // if len(data) == 0 { + // return false + // } // source/destination of the packet we process as incoming, on gVisor side are Remote/Local // in other terms, src is the side behind tun, dst is the side behind gVisor // this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement - src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)) - dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) - - return udpForwarder.HandlePacket(src, dst, data) + srcIP := net.IPAddress(id.RemoteAddress.AsSlice()) + dstIP := net.IPAddress(id.LocalAddress.AsSlice()) + if srcIP == nil || dstIP == nil { + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " > invalid ip address ", id.RemoteAddress.AsSlice(), " ", id.LocalAddress.AsSlice()) + return true + } + src := net.UDPDestination(srcIP, net.Port(id.RemotePort)) + dst := net.UDPDestination(dstIP, net.Port(id.LocalPort)) + udpForwarder.HandlePacket(src, dst, data) + return true }) t.stack = ipStack diff --git a/proxy/tun/udp_fullcone.go b/proxy/tun/udp_fullcone.go index df58ce4e..44612100 100644 --- a/proxy/tun/udp_fullcone.go +++ b/proxy/tun/udp_fullcone.go @@ -1,16 +1,24 @@ package tun import ( + "context" "io" "sync" + "time" "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" ) +type packet struct { + data []byte + dest *net.Destination +} + // sub-handler specifically for udp connections under main handler type udpConnectionHandler struct { - sync.Mutex + sync.RWMutex udpConns map[net.Destination]*udpConn @@ -30,25 +38,44 @@ func newUdpConnectionHandler(handleConnection func(conn net.Conn, dest net.Desti // HandlePacket handles UDP packets coming from tun, to forward to the dispatcher // this custom handler support FullCone NAT of returning packets, binding connection only by the source addr:port -func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool { - u.Lock() +func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) { + u.RLock() conn, found := u.udpConns[src] + if found { + select { + case conn.egress <- &packet{ + data: data, + dest: &dst, + }: + default: + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " to ", dst.NetAddr(), " original ", conn.dst.NetAddr(), " > queue full") + } + u.RUnlock() + return + } + u.RUnlock() + + u.Lock() + defer u.Unlock() + + conn, found = u.udpConns[src] if !found { - egress := make(chan []byte, 16) + egress := make(chan *packet, 1024) conn = &udpConn{handler: u, egress: egress, src: src, dst: dst} u.udpConns[src] = conn go u.handleConnection(conn, dst) } - u.Unlock() // send packet data to the egress channel, if it has buffer, or discard select { - case conn.egress <- data: + case conn.egress <- &packet{ + data: data, + dest: &dst, + }: default: + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " to ", dst.NetAddr(), " original ", conn.dst.NetAddr(), " > queue full") } - - return true } func (u *udpConnectionHandler) connectionFinished(src net.Destination) { @@ -63,27 +90,64 @@ func (u *udpConnectionHandler) connectionFinished(src net.Destination) { // udp connection abstraction type udpConn struct { - net.Conn - buf.Writer - handler *udpConnectionHandler - egress chan []byte + egress chan *packet src net.Destination dst net.Destination } +func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { + for { + e, ok := <-c.egress + if !ok { + return nil, io.EOF + } + + b := buf.New() + + _, err := b.Write(e.data) + if err != nil { + errors.LogDebugInner(context.Background(), err, "drop udp with size ", len(e.data), " to ", e.dest.NetAddr(), " original ", c.dst.NetAddr()) + b.Release() + continue + } + + b.UDP = e.dest + + return buf.MultiBuffer{b}, nil + } +} + // Read packets from the connection func (c *udpConn) Read(p []byte) (int, error) { - data, ok := <-c.egress + e, ok := <-c.egress if !ok { return 0, io.EOF } - - n := copy(p, data) + n := copy(p, e.data) + if n != len(e.data) { + return 0, io.ErrShortBuffer + } return n, nil } +func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error { + for i, b := range mb { + dst := c.dst + if b.UDP != nil { + dst = *b.UDP + } + err := c.handler.writePacket(b.Bytes(), dst, c.src) + if err != nil { + buf.ReleaseMulti(mb[i:]) + return err + } + b.Release() + } + return nil +} + // Write returning packets back func (c *udpConn) Write(p []byte) (int, error) { // sending packets back mean sending payload with source/destination reversed @@ -102,33 +166,21 @@ func (c *udpConn) Close() error { } func (c *udpConn) LocalAddr() net.Addr { - return &net.UDPAddr{IP: c.dst.Address.IP(), Port: int(c.dst.Port.Value())} + return c.dst.RawNetAddr() } func (c *udpConn) RemoteAddr() net.Addr { - return &net.UDPAddr{IP: c.src.Address.IP(), Port: int(c.src.Port.Value())} + return c.src.RawNetAddr() } -// Write returning packets back -func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error { - for _, b := range mb { - dst := c.dst - if b.UDP != nil { - dst = *b.UDP - } - - // validate address family matches between buffer packet and the connection - if dst.Address.Family() != c.dst.Address.Family() { - continue - } - - // sending packets back mean sending payload with source/destination reversed - err := c.handler.writePacket(b.Bytes(), dst, c.src) - if err != nil { - // udp doesn't guarantee delivery, so in any failure we just continue to the next packet - continue - } - } - +func (c *udpConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *udpConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *udpConn) SetWriteDeadline(t time.Time) error { return nil }