From d1db1d6a27733b1bdc16bcce926eb430ed5efff4 Mon Sep 17 00:00:00 2001 From: Meow <197331664+Meo597@users.noreply.github.com> Date: Sun, 26 Apr 2026 01:27:39 +0800 Subject: [PATCH] DNS outbound: Add `rules` (matches `qtype` and `domain`, then `action`) (#5981) https://github.com/XTLS/Xray-core/pull/5981#issuecomment-4279809648 Example: https://github.com/XTLS/Xray-core/pull/5981#issuecomment-4283200236 Closes https://github.com/XTLS/Xray-core/issues/5218 --- infra/conf/dns_proxy.go | 141 ++++++++++++++++++++++-- infra/conf/dns_proxy_test.go | 205 +++++++++++++++++++++++++++++++++++ proxy/dns/config.pb.go | 203 +++++++++++++++++++++++++++------- proxy/dns/config.proto | 23 +++- proxy/dns/dns.go | 148 ++++++++++++++++--------- proxy/dns/dns_test.go | 124 +++++++++++++++++++++ 6 files changed, 735 insertions(+), 109 deletions(-) diff --git a/infra/conf/dns_proxy.go b/infra/conf/dns_proxy.go index b223e502..50c535c8 100644 --- a/infra/conf/dns_proxy.go +++ b/infra/conf/dns_proxy.go @@ -1,19 +1,70 @@ package conf import ( + "strings" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/proxy/dns" "google.golang.org/protobuf/proto" ) +type DNSOutboundRuleConfig struct { + Action string `json:"action"` + QType *PortList `json:"qtype"` + Domain *StringList `json:"domain"` +} + +func (c *DNSOutboundRuleConfig) Build() (*dns.DNSRuleConfig, error) { + rule := &dns.DNSRuleConfig{} + + switch strings.ToLower(c.Action) { + case "direct": + rule.Action = dns.RuleAction_Direct + case "drop": + rule.Action = dns.RuleAction_Drop + case "reject": + rule.Action = dns.RuleAction_Reject + case "hijack": + rule.Action = dns.RuleAction_Hijack + default: + return nil, errors.New("unknown action: ", c.Action) + } + + if c.QType != nil { + for _, r := range c.QType.Range { + if r.From > r.To { + return nil, errors.New("invalid qtype range: ", r.String()) + } + if r.To > 65535 { + return nil, errors.New("dns rule qtype out of range: ", r.String()) + } + for qtype := r.From; qtype <= r.To; qtype++ { + rule.Qtype = append(rule.Qtype, int32(qtype)) + } + } + } + + if c.Domain != nil { + rules, err := geodata.ParseDomainRules(*c.Domain, geodata.Domain_Substr) + if err != nil { + return nil, err + } + rule.Domain = rules + } + + return rule, nil +} + type DNSOutboundConfig struct { - Network Network `json:"network"` - Address *Address `json:"address"` - Port uint16 `json:"port"` - UserLevel uint32 `json:"userLevel"` - NonIPQuery string `json:"nonIPQuery"` - BlockTypes []int32 `json:"blockTypes"` + Network Network `json:"network"` + Address *Address `json:"address"` + Port uint16 `json:"port"` + UserLevel uint32 `json:"userLevel"` + Rules []*DNSOutboundRuleConfig `json:"rules"` + NonIPQuery *string `json:"nonIPQuery"` // todo: remove legacy + BlockTypes *[]int32 `json:"blockTypes"` // todo: remove legacy } func (c *DNSOutboundConfig) Build() (proto.Message, error) { @@ -27,12 +78,78 @@ func (c *DNSOutboundConfig) Build() (proto.Message, error) { if c.Address != nil { config.Server.Address = c.Address.Build() } - switch c.NonIPQuery { - case "", "reject", "drop", "skip": - default: - return nil, errors.New(`unknown "nonIPQuery": `, c.NonIPQuery) + + // todo: remove legacy + if c.NonIPQuery != nil || c.BlockTypes != nil { + if c.Rules != nil { + return nil, errors.New("legacy nonIPQuery and blockTypes cannot be mixed with rules") + } + errors.PrintDeprecatedFeatureWarning(`"nonIPQuery" and "blockTypes" in DNS outbound`, `"rules"`) + rules, err := c.buildLegacyDNSPolicy() + if err != nil { + return nil, err + } + config.Rule = rules + return config, nil } - config.Non_IPQuery = c.NonIPQuery - config.BlockTypes = c.BlockTypes + + for _, r := range c.Rules { + rule, err := r.Build() + if err != nil { + return nil, err + } + config.Rule = append(config.Rule, rule) + } + return config, nil } + +// todo: remove legacy +func (c *DNSOutboundConfig) buildLegacyDNSPolicy() ([]*dns.DNSRuleConfig, error) { + rules := make([]*dns.DNSRuleConfig, 0, 3) + + mode := "reject" + if c.NonIPQuery != nil && *c.NonIPQuery != "" { + mode = *c.NonIPQuery + } + switch mode { + case "", "reject", "drop", "skip": + default: + return nil, errors.New("unknown nonIPQuery: ", mode) + } + + if c.BlockTypes != nil && len(*c.BlockTypes) > 0 { + rule := &dns.DNSRuleConfig{Action: dns.RuleAction_Drop} + if mode == "reject" { + rule.Action = dns.RuleAction_Reject + } + for _, qtype := range *c.BlockTypes { + if qtype < 0 || qtype > 65535 { + return nil, errors.New("legacy blockTypes qtype out of range: ", qtype) + } + rule.Qtype = append(rule.Qtype, qtype) + } + rules = append(rules, rule) + } + + { + rule := &dns.DNSRuleConfig{Action: dns.RuleAction_Hijack} + rule.Qtype = append(rule.Qtype, 1) + rule.Qtype = append(rule.Qtype, 28) + rules = append(rules, rule) + } + + { + rule := &dns.DNSRuleConfig{Action: dns.RuleAction_Reject} + if mode == "reject" { + rule.Action = dns.RuleAction_Reject + } else if mode == "drop" { + rule.Action = dns.RuleAction_Drop + } else if mode == "skip" { + rule.Action = dns.RuleAction_Direct + } + rules = append(rules, rule) + } + + return rules, nil +} diff --git a/infra/conf/dns_proxy_test.go b/infra/conf/dns_proxy_test.go index 805ac323..f5bcd3d5 100644 --- a/infra/conf/dns_proxy_test.go +++ b/infra/conf/dns_proxy_test.go @@ -1,8 +1,10 @@ package conf_test import ( + "strings" "testing" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" . "github.com/xtls/xray-core/infra/conf" "github.com/xtls/xray-core/proxy/dns" @@ -29,5 +31,208 @@ func TestDnsProxyConfig(t *testing.T) { }, }, }, + { + Input: `{ + "rules": [{ + "action": "direct", + "qtype": "1,3,23-24" + }, { + "action": "drop", + "qtype": 28, + "domain": ["domain:example.com", "full:example.com"] + }] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Direct, + Qtype: []int32{1, 3, 23, 24}, + }, + { + Action: dns.RuleAction_Drop, + Qtype: []int32{28}, + Domain: []*geodata.DomainRule{ + { + Value: &geodata.DomainRule_Custom{ + Custom: &geodata.Domain{ + Type: geodata.Domain_Domain, + Value: "example.com", + }, + }, + }, + { + Value: &geodata.DomainRule_Custom{ + Custom: &geodata.Domain{ + Type: geodata.Domain_Full, + Value: "example.com", + }, + }, + }, + }, + }, + }, + }, + }, + { + Input: `{ + "rules": [{ + "action": "reject", + "domain": "keyword:example" + }] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Reject, + Domain: []*geodata.DomainRule{ + { + Value: &geodata.DomainRule_Custom{ + Custom: &geodata.Domain{ + Type: geodata.Domain_Substr, + Value: "example", + }, + }, + }, + }, + }, + }, + }, + }, + { + Input: `{ + "rules": [{ + "action": "drop", + "qtype": 257 + }] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Drop, + Qtype: []int32{257}, + }, + }, + }, + }, }) } + +// todo: remove legacy +func TestDnsProxyConfigLegacyCompatibility(t *testing.T) { + creator := func() Buildable { + return new(DNSOutboundConfig) + } + + runMultiTestCase(t, []TestCase{ + { + Input: `{ + "blockTypes": [] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Hijack, + Qtype: []int32{1, 28}, + }, + { + Action: dns.RuleAction_Reject, + }, + }, + }, + }, + { + Input: `{ + "blockTypes": [1, 65] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Reject, + Qtype: []int32{1, 65}, + }, + { + Action: dns.RuleAction_Hijack, + Qtype: []int32{1, 28}, + }, + { + Action: dns.RuleAction_Reject, + }, + }, + }, + }, + { + Input: `{ + "nonIPQuery": "drop", + "blockTypes": [1] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Drop, + Qtype: []int32{1}, + }, + { + Action: dns.RuleAction_Hijack, + Qtype: []int32{1, 28}, + }, + { + Action: dns.RuleAction_Drop, + }, + }, + }, + }, + { + Input: `{ + "nonIPQuery": "skip", + "blockTypes": [65, 28] + }`, + Parser: loadJSON(creator), + Output: &dns.Config{ + Server: &net.Endpoint{}, + Rule: []*dns.DNSRuleConfig{ + { + Action: dns.RuleAction_Drop, + Qtype: []int32{65, 28}, + }, + { + Action: dns.RuleAction_Hijack, + Qtype: []int32{1, 28}, + }, + { + Action: dns.RuleAction_Direct, + }, + }, + }, + }, + }) +} + +// todo: remove legacy +func TestDnsProxyConfigRejectsMixedLegacyAndNewFields(t *testing.T) { + creator := func() Buildable { + return new(DNSOutboundConfig) + } + + _, err := loadJSON(creator)(`{ + "rules": [{ + "action": "direct", + "qtype": 65 + }], + "blockTypes": [65] + }`) + if err == nil || !strings.Contains(err.Error(), `legacy nonIPQuery and blockTypes cannot be mixed with rules`) { + t.Fatal("expected mixed legacy/new config error, but got ", err) + } +} diff --git a/proxy/dns/config.pb.go b/proxy/dns/config.pb.go index 436d3d38..836f0af9 100644 --- a/proxy/dns/config.pb.go +++ b/proxy/dns/config.pb.go @@ -7,6 +7,7 @@ package dns import ( + geodata "github.com/xtls/xray-core/common/geodata" net "github.com/xtls/xray-core/common/net" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -22,21 +23,130 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type RuleAction int32 + +const ( + RuleAction_Direct RuleAction = 0 + RuleAction_Drop RuleAction = 1 + RuleAction_Reject RuleAction = 2 + RuleAction_Hijack RuleAction = 3 +) + +// Enum value maps for RuleAction. +var ( + RuleAction_name = map[int32]string{ + 0: "Direct", + 1: "Drop", + 2: "Reject", + 3: "Hijack", + } + RuleAction_value = map[string]int32{ + "Direct": 0, + "Drop": 1, + "Reject": 2, + "Hijack": 3, + } +) + +func (x RuleAction) Enum() *RuleAction { + p := new(RuleAction) + *p = x + return p +} + +func (x RuleAction) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleAction) Descriptor() protoreflect.EnumDescriptor { + return file_proxy_dns_config_proto_enumTypes[0].Descriptor() +} + +func (RuleAction) Type() protoreflect.EnumType { + return &file_proxy_dns_config_proto_enumTypes[0] +} + +func (x RuleAction) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleAction.Descriptor instead. +func (RuleAction) EnumDescriptor() ([]byte, []int) { + return file_proxy_dns_config_proto_rawDescGZIP(), []int{0} +} + +type DNSRuleConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Action RuleAction `protobuf:"varint,1,opt,name=action,proto3,enum=xray.proxy.dns.RuleAction" json:"action,omitempty"` + Qtype []int32 `protobuf:"varint,2,rep,packed,name=qtype,proto3" json:"qtype,omitempty"` + Domain []*geodata.DomainRule `protobuf:"bytes,3,rep,name=domain,proto3" json:"domain,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DNSRuleConfig) Reset() { + *x = DNSRuleConfig{} + mi := &file_proxy_dns_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DNSRuleConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DNSRuleConfig) ProtoMessage() {} + +func (x *DNSRuleConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_dns_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DNSRuleConfig.ProtoReflect.Descriptor instead. +func (*DNSRuleConfig) Descriptor() ([]byte, []int) { + return file_proxy_dns_config_proto_rawDescGZIP(), []int{0} +} + +func (x *DNSRuleConfig) GetAction() RuleAction { + if x != nil { + return x.Action + } + return RuleAction_Direct +} + +func (x *DNSRuleConfig) GetQtype() []int32 { + if x != nil { + return x.Qtype + } + return nil +} + +func (x *DNSRuleConfig) GetDomain() []*geodata.DomainRule { + if x != nil { + return x.Domain + } + return nil +} + type Config struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Server is the DNS server address. If specified, this address overrides the - // original one. - Server *net.Endpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"` - UserLevel uint32 `protobuf:"varint,2,opt,name=user_level,json=userLevel,proto3" json:"user_level,omitempty"` - Non_IPQuery string `protobuf:"bytes,3,opt,name=non_IP_query,json=nonIPQuery,proto3" json:"non_IP_query,omitempty"` - BlockTypes []int32 `protobuf:"varint,4,rep,packed,name=block_types,json=blockTypes,proto3" json:"block_types,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + UserLevel uint32 `protobuf:"varint,1,opt,name=user_level,json=userLevel,proto3" json:"user_level,omitempty"` + Rule []*DNSRuleConfig `protobuf:"bytes,2,rep,name=rule,proto3" json:"rule,omitempty"` + Server *net.Endpoint `protobuf:"bytes,3,opt,name=server,proto3" json:"server,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *Config) Reset() { *x = Config{} - mi := &file_proxy_dns_config_proto_msgTypes[0] + mi := &file_proxy_dns_config_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -48,7 +158,7 @@ func (x *Config) String() string { func (*Config) ProtoMessage() {} func (x *Config) ProtoReflect() protoreflect.Message { - mi := &file_proxy_dns_config_proto_msgTypes[0] + mi := &file_proxy_dns_config_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -61,14 +171,7 @@ func (x *Config) ProtoReflect() protoreflect.Message { // Deprecated: Use Config.ProtoReflect.Descriptor instead. func (*Config) Descriptor() ([]byte, []int) { - return file_proxy_dns_config_proto_rawDescGZIP(), []int{0} -} - -func (x *Config) GetServer() *net.Endpoint { - if x != nil { - return x.Server - } - return nil + return file_proxy_dns_config_proto_rawDescGZIP(), []int{1} } func (x *Config) GetUserLevel() uint32 { @@ -78,16 +181,16 @@ func (x *Config) GetUserLevel() uint32 { return 0 } -func (x *Config) GetNon_IPQuery() string { +func (x *Config) GetRule() []*DNSRuleConfig { if x != nil { - return x.Non_IPQuery + return x.Rule } - return "" + return nil } -func (x *Config) GetBlockTypes() []int32 { +func (x *Config) GetServer() *net.Endpoint { if x != nil { - return x.BlockTypes + return x.Server } return nil } @@ -96,15 +199,25 @@ var File_proxy_dns_config_proto protoreflect.FileDescriptor const file_proxy_dns_config_proto_rawDesc = "" + "\n" + - "\x16proxy/dns/config.proto\x12\x0exray.proxy.dns\x1a\x1ccommon/net/destination.proto\"\x9d\x01\n" + - "\x06Config\x121\n" + - "\x06server\x18\x01 \x01(\v2\x19.xray.common.net.EndpointR\x06server\x12\x1d\n" + + "\x16proxy/dns/config.proto\x12\x0exray.proxy.dns\x1a\x1ccommon/net/destination.proto\x1a\x1bcommon/geodata/geodat.proto\"\x92\x01\n" + + "\rDNSRuleConfig\x122\n" + + "\x06action\x18\x01 \x01(\x0e2\x1a.xray.proxy.dns.RuleActionR\x06action\x12\x14\n" + + "\x05qtype\x18\x02 \x03(\x05R\x05qtype\x127\n" + + "\x06domain\x18\x03 \x03(\v2\x1f.xray.common.geodata.DomainRuleR\x06domain\"\x8d\x01\n" + + "\x06Config\x12\x1d\n" + "\n" + - "user_level\x18\x02 \x01(\rR\tuserLevel\x12 \n" + - "\fnon_IP_query\x18\x03 \x01(\tR\n" + - "nonIPQuery\x12\x1f\n" + - "\vblock_types\x18\x04 \x03(\x05R\n" + - "blockTypesBL\n" + + "user_level\x18\x01 \x01(\rR\tuserLevel\x121\n" + + "\x04rule\x18\x02 \x03(\v2\x1d.xray.proxy.dns.DNSRuleConfigR\x04rule\x121\n" + + "\x06server\x18\x03 \x01(\v2\x19.xray.common.net.EndpointR\x06server*:\n" + + "\n" + + "RuleAction\x12\n" + + "\n" + + "\x06Direct\x10\x00\x12\b\n" + + "\x04Drop\x10\x01\x12\n" + + "\n" + + "\x06Reject\x10\x02\x12\n" + + "\n" + + "\x06Hijack\x10\x03BL\n" + "\x12com.xray.proxy.dnsP\x01Z#github.com/xtls/xray-core/proxy/dns\xaa\x02\x0eXray.Proxy.Dnsb\x06proto3" var ( @@ -119,18 +232,25 @@ func file_proxy_dns_config_proto_rawDescGZIP() []byte { return file_proxy_dns_config_proto_rawDescData } -var file_proxy_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proxy_dns_config_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_proxy_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_proxy_dns_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.proxy.dns.Config - (*net.Endpoint)(nil), // 1: xray.common.net.Endpoint + (RuleAction)(0), // 0: xray.proxy.dns.RuleAction + (*DNSRuleConfig)(nil), // 1: xray.proxy.dns.DNSRuleConfig + (*Config)(nil), // 2: xray.proxy.dns.Config + (*geodata.DomainRule)(nil), // 3: xray.common.geodata.DomainRule + (*net.Endpoint)(nil), // 4: xray.common.net.Endpoint } var file_proxy_dns_config_proto_depIdxs = []int32{ - 1, // 0: xray.proxy.dns.Config.server:type_name -> xray.common.net.Endpoint - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 0, // 0: xray.proxy.dns.DNSRuleConfig.action:type_name -> xray.proxy.dns.RuleAction + 3, // 1: xray.proxy.dns.DNSRuleConfig.domain:type_name -> xray.common.geodata.DomainRule + 1, // 2: xray.proxy.dns.Config.rule:type_name -> xray.proxy.dns.DNSRuleConfig + 4, // 3: xray.proxy.dns.Config.server:type_name -> xray.common.net.Endpoint + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_proxy_dns_config_proto_init() } @@ -143,13 +263,14 @@ func file_proxy_dns_config_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_dns_config_proto_rawDesc), len(file_proxy_dns_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, + NumEnums: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, GoTypes: file_proxy_dns_config_proto_goTypes, DependencyIndexes: file_proxy_dns_config_proto_depIdxs, + EnumInfos: file_proxy_dns_config_proto_enumTypes, MessageInfos: file_proxy_dns_config_proto_msgTypes, }.Build() File_proxy_dns_config_proto = out.File diff --git a/proxy/dns/config.proto b/proxy/dns/config.proto index af2aad8c..866812d6 100644 --- a/proxy/dns/config.proto +++ b/proxy/dns/config.proto @@ -7,12 +7,23 @@ option java_package = "com.xray.proxy.dns"; option java_multiple_files = true; import "common/net/destination.proto"; +import "common/geodata/geodat.proto"; + +enum RuleAction { + Direct = 0; + Drop = 1; + Reject = 2; + Hijack = 3; +} + +message DNSRuleConfig { + RuleAction action = 1; + repeated int32 qtype = 2; + repeated xray.common.geodata.DomainRule domain = 3; +} message Config { - // Server is the DNS server address. If specified, this address overrides the - // original one. - xray.common.net.Endpoint server = 1; - uint32 user_level = 2; - string non_IP_query = 3; - repeated int32 block_types = 4; + uint32 user_level = 1; + repeated DNSRuleConfig rule = 2; + xray.common.net.Endpoint server = 3; } diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 9ae19cbe..1b240206 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -11,6 +11,7 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" dns_proto "github.com/xtls/xray-core/common/protocol/dns" "github.com/xtls/xray-core/common/session" @@ -40,6 +41,31 @@ func init() { })) } +type DNSRule struct { + action RuleAction + qTypes []uint16 + domains geodata.DomainMatcher +} + +func (r *DNSRule) matchQType(qType uint16) bool { + if len(r.qTypes) == 0 { + return true + } + for _, t := range r.qTypes { + if t == qType { + return true + } + } + return false +} + +func (r *DNSRule) Apply(qType uint16, domain string) bool { + if !r.matchQType(qType) { + return false + } + return r.domains == nil || r.domains.MatchAny(strings.TrimSuffix(strings.ToLower(domain), ".")) +} + type ownLinkVerifier interface { IsOwnLink(ctx context.Context) bool } @@ -50,8 +76,7 @@ type Handler struct { ownLinkVerifier ownLinkVerifier server net.Destination timeout time.Duration - nonIPQuery string - blockTypes []int32 + rules []*DNSRule } func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { @@ -65,11 +90,26 @@ func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager polic if config.Server != nil { h.server = config.Server.AsDestination() } - h.nonIPQuery = config.Non_IPQuery - if h.nonIPQuery == "" { - h.nonIPQuery = "reject" + + h.rules = make([]*DNSRule, 0, len(config.Rule)) + for _, r := range config.Rule { + rule := &DNSRule{ + action: r.Action, + qTypes: make([]uint16, 0, len(r.Qtype)), + } + for _, t := range r.Qtype { + rule.qTypes = append(rule.qTypes, uint16(t)) + } + if len(r.Domain) > 0 { + m, err := geodata.DomainReg.BuildDomainMatcher(r.Domain) + if err != nil { + return err + } + rule.domains = m + } + h.rules = append(h.rules, rule) } - h.blockTypes = config.BlockTypes + return nil } @@ -77,30 +117,38 @@ func (h *Handler) isOwnLink(ctx context.Context) bool { return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) } -func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { +func parseQuery(b []byte) (id uint16, qType dnsmessage.Type, domain string, ok bool) { var parser dnsmessage.Parser header, err := parser.Start(b) if err != nil { errors.LogInfoInner(context.Background(), err, "parser start") return } - id = header.ID q, err := parser.Question() if err != nil { errors.LogInfoInner(context.Background(), err, "question") return } - domain = q.Name.String() qType = q.Type - if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { - return - } - - r = true + domain = q.Name.String() + ok = true return } +func (h *Handler) applyRules(qType dnsmessage.Type, domain string) RuleAction { + qCode := uint16(qType) + for _, r := range h.rules { + if r.Apply(qCode, domain) { + return r.action + } + } + if qType == dnsmessage.TypeA || qType == dnsmessage.TypeAAAA { + return RuleAction_Hijack + } + return RuleAction_Reject +} + // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { outbounds := session.OutboundsFromContext(ctx) @@ -183,51 +231,51 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. if err == io.EOF { return nil } - if err != nil { return err } timer.Update() - if !h.isOwnLink(ctx) { - isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) - if len(h.blockTypes) > 0 { - for _, blocktype := range h.blockTypes { - if blocktype == int32(qType) { - b.Release() - errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain) - if h.nonIPQuery == "reject" { - err := h.rejectNonIPQuery(id, qType, domain, writer) - if err != nil { - return err - } - } - return nil - } - } - } - if isIPQuery { - b.Release() - go h.handleIPQuery(id, qType, domain, writer, timer) - continue - } - if h.nonIPQuery == "drop" { - b.Release() - continue - } - if h.nonIPQuery == "reject" { - b.Release() - err := h.rejectNonIPQuery(id, qType, domain, writer) - if err != nil { - return err - } - continue + if h.isOwnLink(ctx) { + if err := connWriter.WriteMessage(b); err != nil { + return err } + continue } - if err := connWriter.WriteMessage(b); err != nil { - return err + id, qType, domain, ok := parseQuery(b.Bytes()) + if !ok { + b.Release() + continue + } + + switch h.applyRules(qType, domain) { + case RuleAction_Drop: + b.Release() + errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain) + case RuleAction_Reject: + b.Release() + errors.LogInfo(ctx, "rejected type ", qType, " query for domain ", domain) + if err := h.rejectNonIPQuery(id, qType, domain, writer); err != nil { + return err + } + case RuleAction_Hijack: + b.Release() + if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { + errors.LogError(ctx, "can only hijack A/AAAA records") + if err := h.rejectNonIPQuery(id, qType, domain, writer); err != nil { + return err + } + } else { + go h.handleIPQuery(id, qType, domain, writer, timer) + } + case RuleAction_Direct: + if err := connWriter.WriteMessage(b); err != nil { + return err + } + default: + panic("unknown rule action") } } } diff --git a/proxy/dns/dns_test.go b/proxy/dns/dns_test.go index 2a005d68..5df58c7b 100644 --- a/proxy/dns/dns_test.go +++ b/proxy/dns/dns_test.go @@ -14,6 +14,7 @@ import ( _ "github.com/xtls/xray-core/app/proxyman/inbound" _ "github.com/xtls/xray-core/app/proxyman/outbound" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/core" @@ -368,3 +369,126 @@ func TestUDP2TCPDNSTunnel(t *testing.T) { t.Error(r) } } + +func TestDNSRules(t *testing.T) { + port := udp.PickPort() + + dnsServer := dns.Server{ + Addr: "127.0.0.1:" + port.String(), + Net: "udp", + Handler: &staticHandler{}, + } + defer dnsServer.Shutdown() + + go dnsServer.ListenAndServe() + time.Sleep(time.Second) + + serverPort := udp.PickPort() + config := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&dnsapp.Config{ + NameServer: []*dnsapp.NameServer{ + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + }, + }, + }), + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + serial.ToTypedMessage(&proxyman.InboundConfig{}), + serial.ToTypedMessage(&policy.Config{}), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(port), + Networks: []net.Network{net.Network_UDP}, + }), + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&dns_proxy.Config{ + Rule: []*dns_proxy.DNSRuleConfig{ + { + Qtype: []int32{int32(dns.TypeA)}, + Domain: []*geodata.DomainRule{ + { + Value: &geodata.DomainRule_Custom{ + Custom: &geodata.Domain{ + Type: geodata.Domain_Domain, + Value: "facebook.com", + }, + }, + }, + }, + Action: dns_proxy.RuleAction_Direct, + }, + { + Qtype: []int32{int32(dns.TypeA)}, + Domain: []*geodata.DomainRule{ + { + Value: &geodata.DomainRule_Custom{ + Custom: &geodata.Domain{ + Type: geodata.Domain_Full, + Value: "google.com", + }, + }, + }, + }, + Action: dns_proxy.RuleAction_Reject, + }, + }, + }), + }, + }, + } + + v, err := core.New(config) + common.Must(err) + common.Must(v.Start()) + defer v.Close() + + { + m1 := new(dns.Msg) + m1.Id = dns.Id() + m1.RecursionDesired = true + m1.Question = []dns.Question{{Name: "google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}} + + c := new(dns.Client) + in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) + common.Must(err) + + if in.Rcode != dns.RcodeRefused { + t.Fatal("expected Refused, but got ", in.Rcode) + } + } + + { + m1 := new(dns.Msg) + m1.Id = dns.Id() + m1.RecursionDesired = true + m1.Question = []dns.Question{{Name: "facebook.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}} + + c := new(dns.Client) + in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) + common.Must(err) + + if in.Rcode != dns.RcodeSuccess { + t.Fatal("expected Success, but got ", in.Rcode) + } + } +}