diff --git a/proxy/tun/README.md b/proxy/tun/README.md index 51aaea37..2b9e0a42 100644 --- a/proxy/tun/README.md +++ b/proxy/tun/README.md @@ -41,10 +41,12 @@ Here is simple Xray config snippet to enable the inbound: - IPv4 and IPv6 - TCP and UDP +- ICMP Echo (ping) ## LIMITATION -- No ICMP support +- Only ICMP Echo request/reply is supported; other ICMP message types are ignored +- ICMP Echo replies are generated locally by the TUN stack; they do not validate real remote ICMP reachability - Connections are established to any host, as connection success is only a mark of successful accepting packet for proxying. Hosts that are not accepting connections or don't even exists, will look like they opened a connection (SYN-ACK), and never send back a single byte, closing connection (RST) after some time. This is the side effect of the whole process actually being a proxy, and not real network layer 3 vpn ## CONSIDERATIONS @@ -248,4 +250,4 @@ Set the environment variable `xray.tun.fd` (or `XRAY_TUN_FD`) to the fd number b Build using gomobile for iOS framework integration: ``` gomobile bind -target=ios -``` \ No newline at end of file +``` diff --git a/proxy/tun/icmp/packet.go b/proxy/tun/icmp/packet.go new file mode 100644 index 00000000..126e9d88 --- /dev/null +++ b/proxy/tun/icmp/packet.go @@ -0,0 +1,106 @@ +package icmp + +import ( + "github.com/xtls/xray-core/common/errors" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func ProtocolLabel(netProto tcpip.NetworkProtocolNumber) string { + switch netProto { + case header.IPv4ProtocolNumber: + return "ipv4" + case header.IPv6ProtocolNumber: + return "ipv6" + default: + return "unknown" + } +} + +func ParseEchoRequest(netProto tcpip.NetworkProtocolNumber, message []byte) (uint16, uint16, bool) { + switch netProto { + case header.IPv4ProtocolNumber: + if len(message) < header.ICMPv4MinimumSize { + return 0, 0, false + } + icmpHdr := header.ICMPv4(message) + if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != header.ICMPv4UnusedCode { + return 0, 0, false + } + return icmpHdr.Ident(), icmpHdr.Sequence(), true + case header.IPv6ProtocolNumber: + if len(message) < header.ICMPv6MinimumSize { + return 0, 0, false + } + icmpHdr := header.ICMPv6(message) + if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != header.ICMPv6UnusedCode { + return 0, 0, false + } + return icmpHdr.Ident(), icmpHdr.Sequence(), true + default: + return 0, 0, false + } +} + +func RewriteChecksum(netProto tcpip.NetworkProtocolNumber, message []byte, srcIP, dstIP tcpip.Address) error { + switch netProto { + case header.IPv4ProtocolNumber: + if len(message) < header.ICMPv4MinimumSize { + return errors.New("invalid icmpv4 packet") + } + icmpHdr := header.ICMPv4(message) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) + return nil + case header.IPv6ProtocolNumber: + if len(message) < header.ICMPv6MinimumSize { + return errors.New("invalid icmpv6 packet") + } + icmpHdr := header.ICMPv6(message) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr[:header.ICMPv6MinimumSize], + Src: srcIP, + Dst: dstIP, + PayloadCsum: checksum.Checksum(icmpHdr.Payload(), 0), + PayloadLen: len(icmpHdr.Payload()), + })) + return nil + default: + return errors.New("unsupported icmp network protocol") + } +} + +func BuildLocalEchoReply(netProto tcpip.NetworkProtocolNumber, request []byte, srcIP, dstIP tcpip.Address) ([]byte, error) { + reply := append([]byte(nil), request...) + + switch netProto { + case header.IPv4ProtocolNumber: + if len(reply) < header.ICMPv4MinimumSize { + return nil, errors.New("invalid icmpv4 echo packet") + } + icmpHdr := header.ICMPv4(reply) + if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != header.ICMPv4UnusedCode { + return nil, errors.New("not an icmpv4 echo request") + } + reply[0] = byte(header.ICMPv4EchoReply) + case header.IPv6ProtocolNumber: + if len(reply) < header.ICMPv6MinimumSize { + return nil, errors.New("invalid icmpv6 echo packet") + } + icmpHdr := header.ICMPv6(reply) + if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != header.ICMPv6UnusedCode { + return nil, errors.New("not an icmpv6 echo request") + } + reply[0] = byte(header.ICMPv6EchoReply) + default: + return nil, errors.New("unsupported icmp network protocol") + } + + if err := RewriteChecksum(netProto, reply, srcIP, dstIP); err != nil { + return nil, err + } + + return reply, nil +} diff --git a/proxy/tun/icmp/packet_test.go b/proxy/tun/icmp/packet_test.go new file mode 100644 index 00000000..bb759868 --- /dev/null +++ b/proxy/tun/icmp/packet_test.go @@ -0,0 +1,174 @@ +package icmp + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestParseEchoRequest(t *testing.T) { + t.Run("ipv4 echo", func(t *testing.T) { + var zero tcpip.Address + packet := []byte{ + byte(header.ICMPv4Echo), 0, + 0, 0, + 0x12, 0x34, + 0x56, 0x78, + 0xaa, 0xbb, + } + if err := RewriteChecksum(header.IPv4ProtocolNumber, packet, zero, zero); err != nil { + t.Fatal(err) + } + + ident, sequence, ok := ParseEchoRequest(header.IPv4ProtocolNumber, packet) + if !ok { + t.Fatal("expected ipv4 echo request to parse") + } + if ident != 0x1234 || sequence != 0x5678 { + t.Fatalf("unexpected ident/sequence: %x/%x", ident, sequence) + } + }) + + t.Run("ipv6 echo", func(t *testing.T) { + packet := []byte{ + byte(header.ICMPv6EchoRequest), 0, + 0, 0, + 0xab, 0xcd, + 0xef, 0x01, + 0xaa, 0xbb, + } + src := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + dst := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}) + if err := RewriteChecksum(header.IPv6ProtocolNumber, packet, src, dst); err != nil { + t.Fatal(err) + } + + ident, sequence, ok := ParseEchoRequest(header.IPv6ProtocolNumber, packet) + if !ok { + t.Fatal("expected ipv6 echo request to parse") + } + if ident != 0xabcd || sequence != 0xef01 { + t.Fatalf("unexpected ident/sequence: %x/%x", ident, sequence) + } + }) +} + +func TestRewriteChecksum(t *testing.T) { + t.Run("ipv4", func(t *testing.T) { + var zero tcpip.Address + packet := []byte{ + byte(header.ICMPv4Echo), 0, + 0xff, 0xff, + 0x12, 0x34, + 0x56, 0x78, + 0xaa, 0xbb, 0xcc, + } + if err := RewriteChecksum(header.IPv4ProtocolNumber, packet, zero, zero); err != nil { + t.Fatal(err) + } + + icmpHdr := header.ICMPv4(packet) + if got, want := icmpHdr.Checksum(), header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksumPayloadV4(icmpHdr.Payload())); got != want { + t.Fatalf("unexpected ipv4 checksum: got %x want %x", got, want) + } + }) + + t.Run("ipv6", func(t *testing.T) { + packet := []byte{ + byte(header.ICMPv6EchoReply), 0, + 0xff, 0xff, + 0x12, 0x34, + 0x56, 0x78, + 0xaa, 0xbb, 0xcc, + } + src := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + dst := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}) + if err := RewriteChecksum(header.IPv6ProtocolNumber, packet, src, dst); err != nil { + t.Fatal(err) + } + + icmpHdr := header.ICMPv6(packet) + want := header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr[:header.ICMPv6MinimumSize], + Src: src, + Dst: dst, + PayloadLen: len(icmpHdr.Payload()), + PayloadCsum: checksumPayloadV6(icmpHdr.Payload()), + }) + if got := icmpHdr.Checksum(); got != want { + t.Fatalf("unexpected ipv6 checksum: got %x want %x", got, want) + } + }) +} + +func TestBuildLocalEchoReply(t *testing.T) { + t.Run("ipv4", func(t *testing.T) { + request := []byte{ + byte(header.ICMPv4Echo), 0, + 0, 0, + 0x12, 0x34, + 0x56, 0x78, + 0xaa, 0xbb, 0xcc, + } + src := tcpip.Address{} + dst := tcpip.Address{} + if err := RewriteChecksum(header.IPv4ProtocolNumber, request, src, dst); err != nil { + t.Fatal(err) + } + + reply, err := BuildLocalEchoReply(header.IPv4ProtocolNumber, request, dst, src) + if err != nil { + t.Fatal(err) + } + if request[0] != byte(header.ICMPv4Echo) { + t.Fatal("request mutated") + } + icmpHdr := header.ICMPv4(reply) + if icmpHdr.Type() != header.ICMPv4EchoReply || icmpHdr.Code() != header.ICMPv4UnusedCode { + t.Fatalf("unexpected ipv4 reply type/code: %d/%d", icmpHdr.Type(), icmpHdr.Code()) + } + if icmpHdr.Ident() != 0x1234 || icmpHdr.Sequence() != 0x5678 { + t.Fatalf("unexpected ipv4 ident/sequence: %x/%x", icmpHdr.Ident(), icmpHdr.Sequence()) + } + }) + + t.Run("ipv6", func(t *testing.T) { + request := []byte{ + byte(header.ICMPv6EchoRequest), 0, + 0, 0, + 0xab, 0xcd, + 0xef, 0x01, + 0xaa, 0xbb, 0xcc, + } + src := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + dst := tcpip.AddrFromSlice([]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}) + if err := RewriteChecksum(header.IPv6ProtocolNumber, request, src, dst); err != nil { + t.Fatal(err) + } + + reply, err := BuildLocalEchoReply(header.IPv6ProtocolNumber, request, dst, src) + if err != nil { + t.Fatal(err) + } + if request[0] != byte(header.ICMPv6EchoRequest) { + t.Fatal("request mutated") + } + icmpHdr := header.ICMPv6(reply) + if icmpHdr.Type() != header.ICMPv6EchoReply || icmpHdr.Code() != header.ICMPv6UnusedCode { + t.Fatalf("unexpected ipv6 reply type/code: %d/%d", icmpHdr.Type(), icmpHdr.Code()) + } + if icmpHdr.Ident() != 0xabcd || icmpHdr.Sequence() != 0xef01 { + t.Fatalf("unexpected ipv6 ident/sequence: %x/%x", icmpHdr.Ident(), icmpHdr.Sequence()) + } + }) +} + +func checksumPayloadV4(payload []byte) uint16 { + return checksum.Checksum(payload, 0) +} + +func checksumPayloadV6(payload []byte) uint16 { + return checksum.Checksum(payload, 0) +} diff --git a/proxy/tun/stack_gvisor.go b/proxy/tun/stack_gvisor.go index b8ee3591..73f65f19 100644 --- a/proxy/tun/stack_gvisor.go +++ b/proxy/tun/stack_gvisor.go @@ -14,6 +14,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -117,6 +118,8 @@ func (t *stackGVisor) Start() error { udpForwarder.HandlePacket(src, dst, data) return true }) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, t.handleICMPv4Packet) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, t.handleICMPv6Packet) t.stack = ipStack t.endpoint = linkEndpoint @@ -205,7 +208,7 @@ func (t *stackGVisor) Close() error { func createStack(ep stack.LinkEndpoint) (*stack.Stack, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, HandleLocal: false, } gStack := stack.New(opts) diff --git a/proxy/tun/stack_gvisor_icmp_handler.go b/proxy/tun/stack_gvisor_icmp_handler.go new file mode 100644 index 00000000..17ba692b --- /dev/null +++ b/proxy/tun/stack_gvisor_icmp_handler.go @@ -0,0 +1,98 @@ +package tun + +import ( + "github.com/xtls/xray-core/common/errors" + tunicmp "github.com/xtls/xray-core/proxy/tun/icmp" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func (t *stackGVisor) handleICMPv4Packet(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + return t.handleICMPEchoPacket(header.IPv4ProtocolNumber, id, pkt) +} + +func (t *stackGVisor) handleICMPv6Packet(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + return t.handleICMPEchoPacket(header.IPv6ProtocolNumber, id, pkt) +} + +func (t *stackGVisor) handleICMPEchoPacket(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + srcIP := id.RemoteAddress + dstIP := id.LocalAddress + if srcIP.Len() == 0 || dstIP.Len() == 0 { + return true + } + + message := transportPacketBytes(pkt) + ident, sequence, ok := tunicmp.ParseEchoRequest(netProto, message) + if !ok { + return true + } + + reply, err := tunicmp.BuildLocalEchoReply(netProto, message, dstIP, srcIP) + if err != nil { + errors.LogInfoInner(t.ctx, err, "[tun] failed to build local icmp echo reply") + return true + } + + errors.LogDebug(t.ctx, "[tun][icmp] ", tunicmp.ProtocolLabel(netProto), " local echo reply ", dstIP, " -> ", srcIP, " id=", ident, " seq=", sequence) + if err := t.writeRawICMPPacket(netProto, reply, dstIP, srcIP); err != nil { + errors.LogInfoInner(t.ctx, err, "[tun] failed to write local icmp echo reply") + } + + return true +} + +func (t *stackGVisor) writeRawICMPPacket(netProto tcpip.NetworkProtocolNumber, message []byte, srcIP, dstIP tcpip.Address) error { + ipHeaderSize := header.IPv6MinimumSize + ipProtocol := header.IPv6ProtocolNumber + transportProtocol := header.ICMPv6ProtocolNumber + if netProto == header.IPv4ProtocolNumber { + ipHeaderSize = header.IPv4MinimumSize + ipProtocol = header.IPv4ProtocolNumber + transportProtocol = header.ICMPv4ProtocolNumber + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: ipHeaderSize, + Payload: buffer.MakeWithData(message), + }) + defer pkt.DecRef() + + if netProto == header.IPv4ProtocolNumber { + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + len(message)), + TTL: 64, + Protocol: uint8(transportProtocol), + SrcAddr: srcIP, + DstAddr: dstIP, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + } else { + ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(message)), + TransportProtocol: transportProtocol, + HopLimit: 64, + SrcAddr: srcIP, + DstAddr: dstIP, + }) + } + + if err := t.stack.WriteRawPacket(defaultNIC, ipProtocol, buffer.MakeWithView(pkt.ToView())); err != nil { + return errors.New("failed to write raw icmp packet back to stack", err) + } + + return nil +} + +func transportPacketBytes(pkt *stack.PacketBuffer) []byte { + headerBytes := pkt.TransportHeader().Slice() + payloadBytes := pkt.Data().AsRange().ToSlice() + message := make([]byte, len(headerBytes)+len(payloadBytes)) + copy(message, headerBytes) + copy(message[len(headerBytes):], payloadBytes) + return message +}