Fix WireGuard inbound domain-based routing by storing sniffing config at init time

The root cause of issue #4760 was that the sniffing configuration and inbound tag
were being stored during Process() instead of at initialization time. This caused
a race condition where concurrent connections would overwrite each other's session
information, breaking domain-based routing.

The fix follows the same pattern as the TUN handler:
- Save tag and sniffingRequest during NewServer (at creation time)
- Use these saved values in forwardConnection for each new connection

This ensures that each connection uses the correct sniffing configuration that was
set for the inbound, regardless of concurrent connections from multiple WireGuard peers.

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-16 16:18:23 +00:00
parent 81db73e320
commit 7792ca405e

View File

@@ -26,15 +26,15 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct { type Server struct {
bindServer *netBindServer bindServer *netBindServer
info routingInfo info routingInfo
policyManager policy.Manager policyManager policy.Manager
tag string
sniffingRequest session.SniffingRequest
} }
type routingInfo struct { type routingInfo struct {
ctx context.Context ctx context.Context
dispatcher routing.Dispatcher dispatcher routing.Dispatcher
inboundTag *session.Inbound
contentTag *session.Content
} }
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
@@ -58,6 +58,14 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
} }
// Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler)
if inbound := session.InboundFromContext(ctx); inbound != nil {
server.tag = inbound.Tag
}
if content := session.ContentFromContext(ctx); content != nil {
server.sniffingRequest = content.SniffingRequest
}
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -81,8 +89,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
s.info = routingInfo{ s.info = routingInfo{
ctx: ctx, ctx: ctx,
dispatcher: dispatcher, dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
} }
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
@@ -129,21 +135,21 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
sid := session.NewID() sid := session.NewID()
ctx = c.ContextWithID(ctx, sid) ctx = c.ContextWithID(ctx, sid)
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
if s.info.inboundTag != nil {
inbound = *s.info.inboundTag
}
inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3
// overwrite the source to use the tun address for each sub context. inbound := session.Inbound{
// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context Name: "wireguard",
// Currently we have no way to link to the original source address Tag: s.tag,
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) CanSpliceCopy: 3,
ctx = session.ContextWithInbound(ctx, &inbound) // overwrite the source to use the tun address for each sub context.
if s.info.contentTag != nil { // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
ctx = session.ContextWithContent(ctx, s.info.contentTag) // Currently we have no way to link to the original source address
Source: net.DestinationFromAddr(conn.RemoteAddr()),
} }
ctx = session.ContextWithInbound(ctx, &inbound)
ctx = session.ContextWithContent(ctx, &session.Content{
SniffingRequest: s.sniffingRequest,
})
ctx = session.SubContextFromMuxInbound(ctx) ctx = session.SubContextFromMuxInbound(ctx)
plcy := s.policyManager.ForLevel(0) plcy := s.policyManager.ForLevel(0)