diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 62902c60..44e7db6a 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -267,6 +267,13 @@ func (h *Handler) DestIpAddress() net.IP { return internet.DestIpAddress() } +func (h *Handler) SocketSettings() *internet.SocketConfig { + if h.streamSettings == nil { + return nil + } + return h.streamSettings.SocketSettings +} + // Dial implements internet.Dialer. func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) { if h.senderSettings != nil { diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 4cd48b80..ea608309 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -38,6 +38,11 @@ var defaultBlockAllRule *FinalRule func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { h := new(Handler) + if handler, ok := session.FullHandlerFromContext(ctx).(handlerWithSocketSettings); ok { + if sockopt := handler.SocketSettings(); sockopt != nil { + h.socketStrategy = sockopt.DomainStrategy + } + } if err := core.RequireFeatures(ctx, func(pm policy.Manager) error { return h.Init(config.(*Config), pm) }); err != nil { @@ -88,6 +93,10 @@ func init() { } } +type handlerWithSocketSettings interface { + SocketSettings() *internet.SocketConfig +} + type FinalRule struct { action RuleAction network [8]bool @@ -98,9 +107,10 @@ type FinalRule struct { // Handler handles Freedom connections. type Handler struct { - policyManager policy.Manager - config *Config - finalRules []*FinalRule + policyManager policy.Manager + config *Config + finalRules []*FinalRule + socketStrategy internet.DomainStrategy } func buildFinalRule(config *FinalRuleConfig) (*FinalRule, error) { @@ -246,6 +256,13 @@ func (h *Handler) blockDelay(rule *FinalRule) time.Duration { return time.Duration(min+uint64(dice.Roll(int(abs+1)))) * time.Second } +func (h *Handler) udpDomainStrategy() internet.DomainStrategy { + if h.config.DomainStrategy.HasStrategy() { + return h.config.DomainStrategy + } + return h.socketStrategy +} + func isValidAddress(addr *net.IPOrDomain) bool { if addr == nil { return false @@ -406,7 +423,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte writer = buf.NewWriter(conn) } } else { - writer = NewPacketWriter(conn, h, defaultRule, UDPOverride, destination) + writer = NewPacketWriter(conn, h, defaultRule, UDPOverride, destination, outGateway) if h.config.Noises != nil { errors.LogDebug(ctx, "NOISE", h.config.Noises) writer = &NoisePacketWriter{ @@ -535,7 +552,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { } // DialDest means the dial target used in the dialer when creating conn -func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverride net.Destination, DialDest net.Destination) buf.Writer { +func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverride net.Destination, DialDest net.Destination, outGateway net.Address) buf.Writer { iConn := conn statConn, ok := iConn.(*stat.CounterConnection) if ok { @@ -559,7 +576,7 @@ func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverr DefaultRule: defaultRule, UDPOverride: UDPOverride, ResolvedUDPAddr: resolvedUDPAddr, - LocalAddr: net.DestinationFromAddr(conn.LocalAddr()).Address, + OutGateway: outGateway, } } @@ -578,7 +595,7 @@ type PacketWriter struct { // Resulting in these packets being sent to many different IPs randomly // So, cache and keep the resolve result ResolvedUDPAddr *utils.TypedSyncMap[string, net.Address] - LocalAddr net.Address + OutGateway net.Address } func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { @@ -601,21 +618,21 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if ip, ok := w.ResolvedUDPAddr.Load(b.UDP.Address.Domain()); ok { b.UDP.Address = ip } else { - ShouldUseSystemResolver := true - if w.Handler.config.DomainStrategy.HasStrategy() { - ips, err := internet.LookupForIP(b.UDP.Address.Domain(), w.Handler.config.DomainStrategy, w.LocalAddr) + shouldUseSystemResolver := true + if resolveStrategy := w.Handler.udpDomainStrategy(); resolveStrategy.HasStrategy() { + ips, err := internet.LookupForIP(b.UDP.Address.Domain(), w.Handler.config.DomainStrategy, w.OutGateway) if err != nil { // drop packet if resolve failed when forceIP - if w.Handler.config.DomainStrategy.ForceIP() { + if resolveStrategy.ForceIP() { b.Release() continue } } else { ip = net.IPAddress(ips[dice.Roll(len(ips))]) - ShouldUseSystemResolver = false + shouldUseSystemResolver = false } } - if ShouldUseSystemResolver { + if shouldUseSystemResolver { udpAddr, err := net.ResolveUDPAddr("udp", b.UDP.NetAddr()) if err != nil { b.Release()