From 82624bcaf01cc67326bd8ede7599e8c64f5f71e4 Mon Sep 17 00:00:00 2001 From: Meow Date: Tue, 14 Apr 2026 00:42:29 +0800 Subject: [PATCH] Xray-core: Refactor geodata (#5814) https://github.com/XTLS/Xray-core/issues/4422#issuecomment-3533007890 Breaking changes https://github.com/XTLS/Xray-core/pull/5569 Reverts https://github.com/XTLS/Xray-core/pull/5505 Closes https://github.com/XTLS/Xray-core/pull/643 --- app/dns/config.go | 46 +- app/dns/config.pb.go | 332 ++----- app/dns/config.proto | 31 +- app/dns/dns.go | 88 +- app/dns/dns_test.go | 138 +-- app/dns/hosts.go | 115 +-- app/dns/hosts_test.go | 75 +- app/dns/nameserver.go | 116 +-- app/router/command/command_test.go | 7 +- app/router/condition.go | 72 +- app/router/condition_geoip_test.go | 266 ----- app/router/condition_serialize_test.go | 167 ---- app/router/condition_test.go | 120 +-- app/router/config.go | 63 +- app/router/config.pb.go | 704 ++------------ app/router/config.proto | 90 +- app/router/geosite_compact.go | 100 -- app/router/router_test.go | 19 +- app/stats/command/command.go | 8 +- common/geodata/domain_matcher.go | 66 ++ common/geodata/domain_registry.go | 13 + common/geodata/geodat.pb.go | 908 ++++++++++++++++++ common/geodata/geodat.proto | 90 ++ common/geodata/geodat_loader.go | 207 ++++ .../geodata/ip_matcher.go | 314 +++--- common/geodata/ip_matcher_test.go | 325 +++++++ common/geodata/ip_registry.go | 17 + common/geodata/rule_parser.go | 254 +++++ common/geodata/rule_parser_test.go | 51 + .../strmatcher/benchmark_indexmatcher_test.go | 58 ++ .../strmatcher/benchmark_matchers_test.go | 149 +++ .../geodata/strmatcher/indexmatcher_linear.go | 96 ++ .../strmatcher/indexmatcher_linear_test.go | 95 ++ common/geodata/strmatcher/indexmatcher_mph.go | 100 ++ .../strmatcher/indexmatcher_mph_test.go | 94 ++ .../strmatcher/matchergroup_ac_automation.go | 282 ++++++ .../matchergroup_ac_automation_test.go | 365 +++++++ .../geodata/strmatcher/matchergroup_domain.go | 109 +++ .../strmatcher/matchergroup_domain_test.go} | 51 +- .../geodata/strmatcher/matchergroup_full.go | 30 + .../strmatcher/matchergroup_full_test.go} | 41 +- common/geodata/strmatcher/matchergroup_mph.go | 198 ++++ .../strmatcher/matchergroup_mph_test.go} | 261 ++--- .../geodata/strmatcher/matchergroup_simple.go | 41 + .../strmatcher/matchergroup_simple_test.go | 69 ++ .../geodata/strmatcher/matchergroup_substr.go | 61 ++ .../strmatcher/matchergroup_substr_test.go | 65 ++ common/geodata/strmatcher/matchers.go | 290 ++++++ common/geodata/strmatcher/matchers_test.go | 149 +++ common/geodata/strmatcher/strmatcher.go | 101 ++ .../geodata/strmatcher/valuematcher_linear.go | 85 ++ common/geodata/strmatcher/valuematcher_mph.go | 89 ++ common/platform/platform.go | 2 - common/strmatcher/ac_automaton_matcher.go | 247 ----- common/strmatcher/benchmark_test.go | 62 -- common/strmatcher/domain_matcher.go | 98 -- common/strmatcher/full_matcher.go | 25 - common/strmatcher/matchers.go | 56 -- common/strmatcher/matchers_test.go | 73 -- common/strmatcher/mph_matcher.go | 308 ------ common/strmatcher/mph_matcher_compact.go | 47 - common/strmatcher/strmatcher.go | 141 --- infra/conf/dns.go | 241 ++--- infra/conf/dns_test.go | 49 +- infra/conf/router.go | 455 +-------- infra/conf/router_test.go | 111 +-- infra/conf/xray.go | 186 ---- infra/conf/xray_test.go | 7 +- main/commands/all/buildmphcache.go | 52 - main/commands/all/commands.go | 1 - main/run.go | 14 +- testing/scenarios/dns_test.go | 10 +- testing/scenarios/reverse_test.go | 21 +- 73 files changed, 5432 insertions(+), 4455 deletions(-) delete mode 100644 app/router/condition_geoip_test.go delete mode 100644 app/router/condition_serialize_test.go delete mode 100644 app/router/geosite_compact.go create mode 100644 common/geodata/domain_matcher.go create mode 100644 common/geodata/domain_registry.go create mode 100644 common/geodata/geodat.pb.go create mode 100644 common/geodata/geodat.proto create mode 100644 common/geodata/geodat_loader.go rename app/router/condition_geoip.go => common/geodata/ip_matcher.go (71%) create mode 100644 common/geodata/ip_matcher_test.go create mode 100644 common/geodata/ip_registry.go create mode 100644 common/geodata/rule_parser.go create mode 100644 common/geodata/rule_parser_test.go create mode 100644 common/geodata/strmatcher/benchmark_indexmatcher_test.go create mode 100644 common/geodata/strmatcher/benchmark_matchers_test.go create mode 100644 common/geodata/strmatcher/indexmatcher_linear.go create mode 100644 common/geodata/strmatcher/indexmatcher_linear_test.go create mode 100644 common/geodata/strmatcher/indexmatcher_mph.go create mode 100644 common/geodata/strmatcher/indexmatcher_mph_test.go create mode 100644 common/geodata/strmatcher/matchergroup_ac_automation.go create mode 100644 common/geodata/strmatcher/matchergroup_ac_automation_test.go create mode 100644 common/geodata/strmatcher/matchergroup_domain.go rename common/{strmatcher/domain_matcher_test.go => geodata/strmatcher/matchergroup_domain_test.go} (61%) create mode 100644 common/geodata/strmatcher/matchergroup_full.go rename common/{strmatcher/full_matcher_test.go => geodata/strmatcher/matchergroup_full_test.go} (56%) create mode 100644 common/geodata/strmatcher/matchergroup_mph.go rename common/{strmatcher/strmatcher_test.go => geodata/strmatcher/matchergroup_mph_test.go} (56%) create mode 100644 common/geodata/strmatcher/matchergroup_simple.go create mode 100644 common/geodata/strmatcher/matchergroup_simple_test.go create mode 100644 common/geodata/strmatcher/matchergroup_substr.go create mode 100644 common/geodata/strmatcher/matchergroup_substr_test.go create mode 100644 common/geodata/strmatcher/matchers.go create mode 100644 common/geodata/strmatcher/matchers_test.go create mode 100644 common/geodata/strmatcher/strmatcher.go create mode 100644 common/geodata/strmatcher/valuematcher_linear.go create mode 100644 common/geodata/strmatcher/valuematcher_mph.go delete mode 100644 common/strmatcher/ac_automaton_matcher.go delete mode 100644 common/strmatcher/benchmark_test.go delete mode 100644 common/strmatcher/domain_matcher.go delete mode 100644 common/strmatcher/full_matcher.go delete mode 100644 common/strmatcher/matchers.go delete mode 100644 common/strmatcher/matchers_test.go delete mode 100644 common/strmatcher/mph_matcher.go delete mode 100644 common/strmatcher/mph_matcher_compact.go delete mode 100644 common/strmatcher/strmatcher.go delete mode 100644 main/commands/all/buildmphcache.go diff --git a/app/dns/config.go b/app/dns/config.go index ab547e14..1f9ac015 100644 --- a/app/dns/config.go +++ b/app/dns/config.go @@ -2,48 +2,24 @@ package dns import ( "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/common/uuid" ) -var typeMap = map[DomainMatchingType]strmatcher.Type{ - DomainMatchingType_Full: strmatcher.Full, - DomainMatchingType_Subdomain: strmatcher.Domain, - DomainMatchingType_Keyword: strmatcher.Substr, - DomainMatchingType_Regex: strmatcher.Regex, -} - // References: // https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml // https://unix.stackexchange.com/questions/92441/whats-the-difference-between-local-home-and-lan -var localTLDsAndDotlessDomains = []*NameServer_PriorityDomain{ - {Type: DomainMatchingType_Regex, Domain: "^[^.]+$"}, // This will only match domains without any dot - {Type: DomainMatchingType_Subdomain, Domain: "local"}, - {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, - {Type: DomainMatchingType_Subdomain, Domain: "localhost"}, - {Type: DomainMatchingType_Subdomain, Domain: "lan"}, - {Type: DomainMatchingType_Subdomain, Domain: "home.arpa"}, - {Type: DomainMatchingType_Subdomain, Domain: "example"}, - {Type: DomainMatchingType_Subdomain, Domain: "invalid"}, - {Type: DomainMatchingType_Subdomain, Domain: "test"}, -} - -var localTLDsAndDotlessDomainsRule = &NameServer_OriginalRule{ - Rule: "geosite:private", - Size: uint32(len(localTLDsAndDotlessDomains)), -} - -func toStrMatcher(t DomainMatchingType, domain string) (strmatcher.Matcher, error) { - strMType, f := typeMap[t] - if !f { - return nil, errors.New("unknown mapping type", t).AtWarning() - } - matcher, err := strMType.New(domain) - if err != nil { - return nil, errors.New("failed to create str matcher").Base(err) - } - return matcher, nil +var localTLDsAndDotlessDomainsRules = []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Regex, Value: "^[^.]+$"}}}, // This will only match domains without any dot + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "local"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "localdomain"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "localhost"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "lan"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "home.arpa"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "example"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "invalid"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "test"}}}, } func toNetIP(addrs []net.Address) ([]net.IP, error) { diff --git a/app/dns/config.pb.go b/app/dns/config.pb.go index fb351afb..c0737a0d 100644 --- a/app/dns/config.pb.go +++ b/app/dns/config.pb.go @@ -7,7 +7,7 @@ package dns import ( - router "github.com/xtls/xray-core/app/router" + geodata "github.com/xtls/xray-core/common/geodata" net "github.com/xtls/xray-core/common/net" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -23,58 +23,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type DomainMatchingType int32 - -const ( - DomainMatchingType_Full DomainMatchingType = 0 - DomainMatchingType_Subdomain DomainMatchingType = 1 - DomainMatchingType_Keyword DomainMatchingType = 2 - DomainMatchingType_Regex DomainMatchingType = 3 -) - -// Enum value maps for DomainMatchingType. -var ( - DomainMatchingType_name = map[int32]string{ - 0: "Full", - 1: "Subdomain", - 2: "Keyword", - 3: "Regex", - } - DomainMatchingType_value = map[string]int32{ - "Full": 0, - "Subdomain": 1, - "Keyword": 2, - "Regex": 3, - } -) - -func (x DomainMatchingType) Enum() *DomainMatchingType { - p := new(DomainMatchingType) - *p = x - return p -} - -func (x DomainMatchingType) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (DomainMatchingType) Descriptor() protoreflect.EnumDescriptor { - return file_app_dns_config_proto_enumTypes[0].Descriptor() -} - -func (DomainMatchingType) Type() protoreflect.EnumType { - return &file_app_dns_config_proto_enumTypes[0] -} - -func (x DomainMatchingType) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use DomainMatchingType.Descriptor instead. -func (DomainMatchingType) EnumDescriptor() ([]byte, []int) { - return file_app_dns_config_proto_rawDescGZIP(), []int{0} -} - type QueryStrategy int32 const ( @@ -111,11 +59,11 @@ func (x QueryStrategy) String() string { } func (QueryStrategy) Descriptor() protoreflect.EnumDescriptor { - return file_app_dns_config_proto_enumTypes[1].Descriptor() + return file_app_dns_config_proto_enumTypes[0].Descriptor() } func (QueryStrategy) Type() protoreflect.EnumType { - return &file_app_dns_config_proto_enumTypes[1] + return &file_app_dns_config_proto_enumTypes[0] } func (x QueryStrategy) Number() protoreflect.EnumNumber { @@ -124,30 +72,29 @@ func (x QueryStrategy) Number() protoreflect.EnumNumber { // Deprecated: Use QueryStrategy.Descriptor instead. func (QueryStrategy) EnumDescriptor() ([]byte, []int) { - return file_app_dns_config_proto_rawDescGZIP(), []int{1} + return file_app_dns_config_proto_rawDescGZIP(), []int{0} } type NameServer struct { - state protoimpl.MessageState `protogen:"open.v1"` - Address *net.Endpoint `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` - ClientIp []byte `protobuf:"bytes,5,opt,name=client_ip,json=clientIp,proto3" json:"client_ip,omitempty"` - SkipFallback bool `protobuf:"varint,6,opt,name=skipFallback,proto3" json:"skipFallback,omitempty"` - PrioritizedDomain []*NameServer_PriorityDomain `protobuf:"bytes,2,rep,name=prioritized_domain,json=prioritizedDomain,proto3" json:"prioritized_domain,omitempty"` - ExpectedGeoip []*router.GeoIP `protobuf:"bytes,3,rep,name=expected_geoip,json=expectedGeoip,proto3" json:"expected_geoip,omitempty"` - OriginalRules []*NameServer_OriginalRule `protobuf:"bytes,4,rep,name=original_rules,json=originalRules,proto3" json:"original_rules,omitempty"` - QueryStrategy QueryStrategy `protobuf:"varint,7,opt,name=query_strategy,json=queryStrategy,proto3,enum=xray.app.dns.QueryStrategy" json:"query_strategy,omitempty"` - ActPrior bool `protobuf:"varint,8,opt,name=actPrior,proto3" json:"actPrior,omitempty"` - Tag string `protobuf:"bytes,9,opt,name=tag,proto3" json:"tag,omitempty"` - TimeoutMs uint64 `protobuf:"varint,10,opt,name=timeoutMs,proto3" json:"timeoutMs,omitempty"` - DisableCache *bool `protobuf:"varint,11,opt,name=disableCache,proto3,oneof" json:"disableCache,omitempty"` - ServeStale *bool `protobuf:"varint,15,opt,name=serveStale,proto3,oneof" json:"serveStale,omitempty"` - ServeExpiredTTL *uint32 `protobuf:"varint,16,opt,name=serveExpiredTTL,proto3,oneof" json:"serveExpiredTTL,omitempty"` - FinalQuery bool `protobuf:"varint,12,opt,name=finalQuery,proto3" json:"finalQuery,omitempty"` - UnexpectedGeoip []*router.GeoIP `protobuf:"bytes,13,rep,name=unexpected_geoip,json=unexpectedGeoip,proto3" json:"unexpected_geoip,omitempty"` - ActUnprior bool `protobuf:"varint,14,opt,name=actUnprior,proto3" json:"actUnprior,omitempty"` - PolicyID uint32 `protobuf:"varint,17,opt,name=policyID,proto3" json:"policyID,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Address *net.Endpoint `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` + ClientIp []byte `protobuf:"bytes,5,opt,name=client_ip,json=clientIp,proto3" json:"client_ip,omitempty"` + SkipFallback bool `protobuf:"varint,6,opt,name=skipFallback,proto3" json:"skipFallback,omitempty"` + Domain []*geodata.DomainRule `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` + ExpectedIp []*geodata.IPRule `protobuf:"bytes,3,rep,name=expected_ip,json=expectedIp,proto3" json:"expected_ip,omitempty"` + QueryStrategy QueryStrategy `protobuf:"varint,7,opt,name=query_strategy,json=queryStrategy,proto3,enum=xray.app.dns.QueryStrategy" json:"query_strategy,omitempty"` + ActPrior bool `protobuf:"varint,8,opt,name=actPrior,proto3" json:"actPrior,omitempty"` + Tag string `protobuf:"bytes,9,opt,name=tag,proto3" json:"tag,omitempty"` + TimeoutMs uint64 `protobuf:"varint,10,opt,name=timeoutMs,proto3" json:"timeoutMs,omitempty"` + DisableCache *bool `protobuf:"varint,11,opt,name=disableCache,proto3,oneof" json:"disableCache,omitempty"` + ServeStale *bool `protobuf:"varint,15,opt,name=serveStale,proto3,oneof" json:"serveStale,omitempty"` + ServeExpiredTTL *uint32 `protobuf:"varint,16,opt,name=serveExpiredTTL,proto3,oneof" json:"serveExpiredTTL,omitempty"` + FinalQuery bool `protobuf:"varint,12,opt,name=finalQuery,proto3" json:"finalQuery,omitempty"` + UnexpectedIp []*geodata.IPRule `protobuf:"bytes,13,rep,name=unexpected_ip,json=unexpectedIp,proto3" json:"unexpected_ip,omitempty"` + ActUnprior bool `protobuf:"varint,14,opt,name=actUnprior,proto3" json:"actUnprior,omitempty"` + PolicyID uint32 `protobuf:"varint,17,opt,name=policyID,proto3" json:"policyID,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *NameServer) Reset() { @@ -201,23 +148,16 @@ func (x *NameServer) GetSkipFallback() bool { return false } -func (x *NameServer) GetPrioritizedDomain() []*NameServer_PriorityDomain { +func (x *NameServer) GetDomain() []*geodata.DomainRule { if x != nil { - return x.PrioritizedDomain + return x.Domain } return nil } -func (x *NameServer) GetExpectedGeoip() []*router.GeoIP { +func (x *NameServer) GetExpectedIp() []*geodata.IPRule { if x != nil { - return x.ExpectedGeoip - } - return nil -} - -func (x *NameServer) GetOriginalRules() []*NameServer_OriginalRule { - if x != nil { - return x.OriginalRules + return x.ExpectedIp } return nil } @@ -278,9 +218,9 @@ func (x *NameServer) GetFinalQuery() bool { return false } -func (x *NameServer) GetUnexpectedGeoip() []*router.GeoIP { +func (x *NameServer) GetUnexpectedIp() []*geodata.IPRule { if x != nil { - return x.UnexpectedGeoip + return x.UnexpectedIp } return nil } @@ -429,114 +369,9 @@ func (x *Config) GetEnableParallelQuery() bool { return false } -type NameServer_PriorityDomain struct { - state protoimpl.MessageState `protogen:"open.v1"` - Type DomainMatchingType `protobuf:"varint,1,opt,name=type,proto3,enum=xray.app.dns.DomainMatchingType" json:"type,omitempty"` - Domain string `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *NameServer_PriorityDomain) Reset() { - *x = NameServer_PriorityDomain{} - mi := &file_app_dns_config_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *NameServer_PriorityDomain) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*NameServer_PriorityDomain) ProtoMessage() {} - -func (x *NameServer_PriorityDomain) ProtoReflect() protoreflect.Message { - mi := &file_app_dns_config_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use NameServer_PriorityDomain.ProtoReflect.Descriptor instead. -func (*NameServer_PriorityDomain) Descriptor() ([]byte, []int) { - return file_app_dns_config_proto_rawDescGZIP(), []int{0, 0} -} - -func (x *NameServer_PriorityDomain) GetType() DomainMatchingType { - if x != nil { - return x.Type - } - return DomainMatchingType_Full -} - -func (x *NameServer_PriorityDomain) GetDomain() string { - if x != nil { - return x.Domain - } - return "" -} - -type NameServer_OriginalRule struct { - state protoimpl.MessageState `protogen:"open.v1"` - Rule string `protobuf:"bytes,1,opt,name=rule,proto3" json:"rule,omitempty"` - Size uint32 `protobuf:"varint,2,opt,name=size,proto3" json:"size,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *NameServer_OriginalRule) Reset() { - *x = NameServer_OriginalRule{} - mi := &file_app_dns_config_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *NameServer_OriginalRule) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*NameServer_OriginalRule) ProtoMessage() {} - -func (x *NameServer_OriginalRule) ProtoReflect() protoreflect.Message { - mi := &file_app_dns_config_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use NameServer_OriginalRule.ProtoReflect.Descriptor instead. -func (*NameServer_OriginalRule) Descriptor() ([]byte, []int) { - return file_app_dns_config_proto_rawDescGZIP(), []int{0, 1} -} - -func (x *NameServer_OriginalRule) GetRule() string { - if x != nil { - return x.Rule - } - return "" -} - -func (x *NameServer_OriginalRule) GetSize() uint32 { - if x != nil { - return x.Size - } - return 0 -} - type Config_HostMapping struct { state protoimpl.MessageState `protogen:"open.v1"` - Type DomainMatchingType `protobuf:"varint,1,opt,name=type,proto3,enum=xray.app.dns.DomainMatchingType" json:"type,omitempty"` - Domain string `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"` + Domain *geodata.DomainRule `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"` Ip [][]byte `protobuf:"bytes,3,rep,name=ip,proto3" json:"ip,omitempty"` // ProxiedDomain indicates the mapped domain has the same IP address on this // domain. Xray will use this domain for IP queries. @@ -547,7 +382,7 @@ type Config_HostMapping struct { func (x *Config_HostMapping) Reset() { *x = Config_HostMapping{} - mi := &file_app_dns_config_proto_msgTypes[4] + mi := &file_app_dns_config_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -559,7 +394,7 @@ func (x *Config_HostMapping) String() string { func (*Config_HostMapping) ProtoMessage() {} func (x *Config_HostMapping) ProtoReflect() protoreflect.Message { - mi := &file_app_dns_config_proto_msgTypes[4] + mi := &file_app_dns_config_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -575,18 +410,11 @@ func (*Config_HostMapping) Descriptor() ([]byte, []int) { return file_app_dns_config_proto_rawDescGZIP(), []int{1, 0} } -func (x *Config_HostMapping) GetType() DomainMatchingType { - if x != nil { - return x.Type - } - return DomainMatchingType_Full -} - -func (x *Config_HostMapping) GetDomain() string { +func (x *Config_HostMapping) GetDomain() *geodata.DomainRule { if x != nil { return x.Domain } - return "" + return nil } func (x *Config_HostMapping) GetIp() [][]byte { @@ -607,15 +435,15 @@ var File_app_dns_config_proto protoreflect.FileDescriptor const file_app_dns_config_proto_rawDesc = "" + "\n" + - "\x14app/dns/config.proto\x12\fxray.app.dns\x1a\x1ccommon/net/destination.proto\x1a\x17app/router/config.proto\"\xdf\a\n" + + "\x14app/dns/config.proto\x12\fxray.app.dns\x1a\x1ccommon/net/destination.proto\x1a\x1bcommon/geodata/geodat.proto\"\xde\x05\n" + "\n" + "NameServer\x123\n" + "\aaddress\x18\x01 \x01(\v2\x19.xray.common.net.EndpointR\aaddress\x12\x1b\n" + "\tclient_ip\x18\x05 \x01(\fR\bclientIp\x12\"\n" + - "\fskipFallback\x18\x06 \x01(\bR\fskipFallback\x12V\n" + - "\x12prioritized_domain\x18\x02 \x03(\v2'.xray.app.dns.NameServer.PriorityDomainR\x11prioritizedDomain\x12=\n" + - "\x0eexpected_geoip\x18\x03 \x03(\v2\x16.xray.app.router.GeoIPR\rexpectedGeoip\x12L\n" + - "\x0eoriginal_rules\x18\x04 \x03(\v2%.xray.app.dns.NameServer.OriginalRuleR\roriginalRules\x12B\n" + + "\fskipFallback\x18\x06 \x01(\bR\fskipFallback\x127\n" + + "\x06domain\x18\x02 \x03(\v2\x1f.xray.common.geodata.DomainRuleR\x06domain\x12<\n" + + "\vexpected_ip\x18\x03 \x03(\v2\x1b.xray.common.geodata.IPRuleR\n" + + "expectedIp\x12B\n" + "\x0equery_strategy\x18\a \x01(\x0e2\x1b.xray.app.dns.QueryStrategyR\rqueryStrategy\x12\x1a\n" + "\bactPrior\x18\b \x01(\bR\bactPrior\x12\x10\n" + "\x03tag\x18\t \x01(\tR\x03tag\x12\x1c\n" + @@ -628,21 +456,15 @@ const file_app_dns_config_proto_rawDesc = "" + "\x0fserveExpiredTTL\x18\x10 \x01(\rH\x02R\x0fserveExpiredTTL\x88\x01\x01\x12\x1e\n" + "\n" + "finalQuery\x18\f \x01(\bR\n" + - "finalQuery\x12A\n" + - "\x10unexpected_geoip\x18\r \x03(\v2\x16.xray.app.router.GeoIPR\x0funexpectedGeoip\x12\x1e\n" + + "finalQuery\x12@\n" + + "\runexpected_ip\x18\r \x03(\v2\x1b.xray.common.geodata.IPRuleR\funexpectedIp\x12\x1e\n" + "\n" + "actUnprior\x18\x0e \x01(\bR\n" + "actUnprior\x12\x1a\n" + - "\bpolicyID\x18\x11 \x01(\rR\bpolicyID\x1a^\n" + - "\x0ePriorityDomain\x124\n" + - "\x04type\x18\x01 \x01(\x0e2 .xray.app.dns.DomainMatchingTypeR\x04type\x12\x16\n" + - "\x06domain\x18\x02 \x01(\tR\x06domain\x1a6\n" + - "\fOriginalRule\x12\x12\n" + - "\x04rule\x18\x01 \x01(\tR\x04rule\x12\x12\n" + - "\x04size\x18\x02 \x01(\rR\x04sizeB\x0f\n" + + "\bpolicyID\x18\x11 \x01(\rR\bpolicyIDB\x0f\n" + "\r_disableCacheB\r\n" + "\v_serveStaleB\x12\n" + - "\x10_serveExpiredTTL\"\x98\x05\n" + + "\x10_serveExpiredTTLJ\x04\b\x04\x10\x05\"\x82\x05\n" + "\x06Config\x129\n" + "\vname_server\x18\x05 \x03(\v2\x18.xray.app.dns.NameServerR\n" + "nameServer\x12\x1b\n" + @@ -658,17 +480,11 @@ const file_app_dns_config_proto_rawDesc = "" + "\x0fdisableFallback\x18\n" + " \x01(\bR\x0fdisableFallback\x126\n" + "\x16disableFallbackIfMatch\x18\v \x01(\bR\x16disableFallbackIfMatch\x120\n" + - "\x13enableParallelQuery\x18\x0e \x01(\bR\x13enableParallelQuery\x1a\x92\x01\n" + - "\vHostMapping\x124\n" + - "\x04type\x18\x01 \x01(\x0e2 .xray.app.dns.DomainMatchingTypeR\x04type\x12\x16\n" + - "\x06domain\x18\x02 \x01(\tR\x06domain\x12\x0e\n" + + "\x13enableParallelQuery\x18\x0e \x01(\bR\x13enableParallelQuery\x1a}\n" + + "\vHostMapping\x127\n" + + "\x06domain\x18\x02 \x01(\v2\x1f.xray.common.geodata.DomainRuleR\x06domain\x12\x0e\n" + "\x02ip\x18\x03 \x03(\fR\x02ip\x12%\n" + - "\x0eproxied_domain\x18\x04 \x01(\tR\rproxiedDomainJ\x04\b\a\x10\b*E\n" + - "\x12DomainMatchingType\x12\b\n" + - "\x04Full\x10\x00\x12\r\n" + - "\tSubdomain\x10\x01\x12\v\n" + - "\aKeyword\x10\x02\x12\t\n" + - "\x05Regex\x10\x03*B\n" + + "\x0eproxied_domain\x18\x04 \x01(\tR\rproxiedDomainJ\x04\b\a\x10\b*B\n" + "\rQueryStrategy\x12\n" + "\n" + "\x06USE_IP\x10\x00\x12\v\n" + @@ -689,36 +505,32 @@ func file_app_dns_config_proto_rawDescGZIP() []byte { return file_app_dns_config_proto_rawDescData } -var file_app_dns_config_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_app_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_app_dns_config_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_app_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_app_dns_config_proto_goTypes = []any{ - (DomainMatchingType)(0), // 0: xray.app.dns.DomainMatchingType - (QueryStrategy)(0), // 1: xray.app.dns.QueryStrategy - (*NameServer)(nil), // 2: xray.app.dns.NameServer - (*Config)(nil), // 3: xray.app.dns.Config - (*NameServer_PriorityDomain)(nil), // 4: xray.app.dns.NameServer.PriorityDomain - (*NameServer_OriginalRule)(nil), // 5: xray.app.dns.NameServer.OriginalRule - (*Config_HostMapping)(nil), // 6: xray.app.dns.Config.HostMapping - (*net.Endpoint)(nil), // 7: xray.common.net.Endpoint - (*router.GeoIP)(nil), // 8: xray.app.router.GeoIP + (QueryStrategy)(0), // 0: xray.app.dns.QueryStrategy + (*NameServer)(nil), // 1: xray.app.dns.NameServer + (*Config)(nil), // 2: xray.app.dns.Config + (*Config_HostMapping)(nil), // 3: xray.app.dns.Config.HostMapping + (*net.Endpoint)(nil), // 4: xray.common.net.Endpoint + (*geodata.DomainRule)(nil), // 5: xray.common.geodata.DomainRule + (*geodata.IPRule)(nil), // 6: xray.common.geodata.IPRule } var file_app_dns_config_proto_depIdxs = []int32{ - 7, // 0: xray.app.dns.NameServer.address:type_name -> xray.common.net.Endpoint - 4, // 1: xray.app.dns.NameServer.prioritized_domain:type_name -> xray.app.dns.NameServer.PriorityDomain - 8, // 2: xray.app.dns.NameServer.expected_geoip:type_name -> xray.app.router.GeoIP - 5, // 3: xray.app.dns.NameServer.original_rules:type_name -> xray.app.dns.NameServer.OriginalRule - 1, // 4: xray.app.dns.NameServer.query_strategy:type_name -> xray.app.dns.QueryStrategy - 8, // 5: xray.app.dns.NameServer.unexpected_geoip:type_name -> xray.app.router.GeoIP - 2, // 6: xray.app.dns.Config.name_server:type_name -> xray.app.dns.NameServer - 6, // 7: xray.app.dns.Config.static_hosts:type_name -> xray.app.dns.Config.HostMapping - 1, // 8: xray.app.dns.Config.query_strategy:type_name -> xray.app.dns.QueryStrategy - 0, // 9: xray.app.dns.NameServer.PriorityDomain.type:type_name -> xray.app.dns.DomainMatchingType - 0, // 10: xray.app.dns.Config.HostMapping.type:type_name -> xray.app.dns.DomainMatchingType - 11, // [11:11] is the sub-list for method output_type - 11, // [11:11] is the sub-list for method input_type - 11, // [11:11] is the sub-list for extension type_name - 11, // [11:11] is the sub-list for extension extendee - 0, // [0:11] is the sub-list for field type_name + 4, // 0: xray.app.dns.NameServer.address:type_name -> xray.common.net.Endpoint + 5, // 1: xray.app.dns.NameServer.domain:type_name -> xray.common.geodata.DomainRule + 6, // 2: xray.app.dns.NameServer.expected_ip:type_name -> xray.common.geodata.IPRule + 0, // 3: xray.app.dns.NameServer.query_strategy:type_name -> xray.app.dns.QueryStrategy + 6, // 4: xray.app.dns.NameServer.unexpected_ip:type_name -> xray.common.geodata.IPRule + 1, // 5: xray.app.dns.Config.name_server:type_name -> xray.app.dns.NameServer + 3, // 6: xray.app.dns.Config.static_hosts:type_name -> xray.app.dns.Config.HostMapping + 0, // 7: xray.app.dns.Config.query_strategy:type_name -> xray.app.dns.QueryStrategy + 5, // 8: xray.app.dns.Config.HostMapping.domain:type_name -> xray.common.geodata.DomainRule + 9, // [9:9] is the sub-list for method output_type + 9, // [9:9] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_app_dns_config_proto_init() } @@ -732,8 +544,8 @@ func file_app_dns_config_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_app_dns_config_proto_rawDesc), len(file_app_dns_config_proto_rawDesc)), - NumEnums: 2, - NumMessages: 5, + NumEnums: 1, + NumMessages: 3, NumExtensions: 0, NumServices: 0, }, diff --git a/app/dns/config.proto b/app/dns/config.proto index 3ce312bc..ddc19dc7 100644 --- a/app/dns/config.proto +++ b/app/dns/config.proto @@ -7,26 +7,15 @@ option java_package = "com.xray.app.dns"; option java_multiple_files = true; import "common/net/destination.proto"; -import "app/router/config.proto"; +import "common/geodata/geodat.proto"; message NameServer { xray.common.net.Endpoint address = 1; bytes client_ip = 5; bool skipFallback = 6; - - message PriorityDomain { - DomainMatchingType type = 1; - string domain = 2; - } - - message OriginalRule { - string rule = 1; - uint32 size = 2; - } - - repeated PriorityDomain prioritized_domain = 2; - repeated xray.app.router.GeoIP expected_geoip = 3; - repeated OriginalRule original_rules = 4; + repeated xray.common.geodata.DomainRule domain = 2; + repeated xray.common.geodata.IPRule expected_ip = 3; + reserved 4; QueryStrategy query_strategy = 7; bool actPrior = 8; string tag = 9; @@ -35,18 +24,11 @@ message NameServer { optional bool serveStale = 15; optional uint32 serveExpiredTTL = 16; bool finalQuery = 12; - repeated xray.app.router.GeoIP unexpected_geoip = 13; + repeated xray.common.geodata.IPRule unexpected_ip = 13; bool actUnprior = 14; uint32 policyID = 17; } -enum DomainMatchingType { - Full = 0; - Subdomain = 1; - Keyword = 2; - Regex = 3; -} - enum QueryStrategy { USE_IP = 0; USE_IP4 = 1; @@ -64,8 +46,7 @@ message Config { bytes client_ip = 3; message HostMapping { - DomainMatchingType type = 1; - string domain = 2; + xray.common.geodata.DomainRule domain = 2; repeated bytes ip = 3; diff --git a/app/dns/dns.go b/app/dns/dns.go index c1082083..85f82521 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -12,13 +12,11 @@ import ( "sync" "time" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/dns" ) @@ -32,15 +30,15 @@ type DNS struct { hosts *StaticHosts clients []*Client ctx context.Context - domainMatcher strmatcher.IndexMatcher + domainMatcher geodata.DomainMatcher matcherInfos []*DomainMatcherInfo checkSystem bool } -// DomainMatcherInfo contains information attached to index returned by Server.domainMatcher +// DomainMatcherInfo contains information attached to index returned by Server.domainMatcher. type DomainMatcherInfo struct { - clientIdx uint16 - domainRuleIdx uint16 + clientIdx uint16 + domainRule string } // New creates a new DNS server with given configuration. @@ -85,56 +83,40 @@ func New(ctx context.Context, config *Config) (*DNS, error) { return nil, errors.New("unexpected query strategy ", config.QueryStrategy) } - var hosts *StaticHosts - mphLoaded := false - domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) - if domainMatcherPath != "" { - if f, err := os.Open(domainMatcherPath); err == nil { - defer f.Close() - if m, err := router.LoadGeoSiteMatcher(f, "HOSTS"); err == nil { - f.Seek(0, 0) - if hostIPs, err := router.LoadGeoSiteHosts(f); err == nil { - if sh, err := NewStaticHostsFromCache(m, hostIPs); err == nil { - hosts = sh - mphLoaded = true - errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for DNS hosts, size: ", sh.matchers.Size()) - } - } - } - } + hosts, err := NewStaticHosts(config.StaticHosts) + if err != nil { + return nil, errors.New("failed to create hosts").Base(err) } - if !mphLoaded { - sh, err := NewStaticHosts(config.StaticHosts) - if err != nil { - return nil, errors.New("failed to create hosts").Base(err) - } - hosts = sh - } - - var clients []*Client - domainRuleCount := 0 - var defaultTag = config.Tag if len(config.Tag) == 0 { defaultTag = generateRandomTag() } - for _, ns := range config.NameServer { - domainRuleCount += len(ns.PrioritizedDomain) - } - - // MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1 - matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1) - domainMatcher := &strmatcher.MatcherGroup{} + clients := make([]*Client, 0, len(config.NameServer)) + matcherInfos := make([]*DomainMatcherInfo, 0) + effectiveRules := make([]*geodata.DomainRule, 0) for _, ns := range config.NameServer { clientIdx := len(clients) - updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) { - midx := domainMatcher.Add(domainRule) - matcherInfos[midx] = &DomainMatcherInfo{ - clientIdx: uint16(clientIdx), - domainRuleIdx: uint16(originalRuleIdx), + updateRules := func(isLocalNameServer bool) { + // Prioritize local domains with specific TLDs or those without any dot for the local DNS + if isLocalNameServer { + effectiveRules = append(effectiveRules, localTLDsAndDotlessDomainsRules...) + for _, rule := range localTLDsAndDotlessDomainsRules { + matcherInfos = append(matcherInfos, &DomainMatcherInfo{ + clientIdx: uint16(clientIdx), + domainRule: rule.String(), + }) + } + } + + effectiveRules = append(effectiveRules, ns.Domain...) + for _, rule := range ns.Domain { + matcherInfos = append(matcherInfos, &DomainMatcherInfo{ + clientIdx: uint16(clientIdx), + domainRule: rule.String(), + }) } } @@ -163,18 +145,24 @@ func New(ctx context.Context, config *Config) (*DNS, error) { if len(ns.Tag) > 0 { tag = ns.Tag } + clientIPOption := ResolveIpOptionOverride(ns.QueryStrategy, ipOption) if !clientIPOption.IPv4Enable && !clientIPOption.IPv6Enable { return nil, errors.New("no QueryStrategy available for ", ns.Address) } - client, err := NewClient(ctx, ns, myClientIP, disableCache, serveStale, serveExpiredTTL, tag, clientIPOption, &matcherInfos, updateDomain) + client, err := NewClient(ctx, ns, myClientIP, disableCache, serveStale, serveExpiredTTL, tag, clientIPOption, updateRules) if err != nil { return nil, errors.New("failed to create client").Base(err) } clients = append(clients, client) } + domainMatcher, err := geodata.DomainReg.BuildDomainMatcher(effectiveRules) + if err != nil { + return nil, err + } + // If there is no DNS client in config, add a `localhost` DNS client if len(clients) == 0 { clients = append(clients, NewLocalDNSClient(ipOption)) @@ -283,14 +271,14 @@ func (s *DNS) sortClients(domain string) []*Client { // Priority domain matching hasMatch := false - MatchSlice := s.domainMatcher.Match(domain) + MatchSlice := s.domainMatcher.Match(strings.ToLower(domain)) sort.Slice(MatchSlice, func(i, j int) bool { return MatchSlice[i] < MatchSlice[j] }) for _, match := range MatchSlice { info := s.matcherInfos[match] client := s.clients[info.clientIdx] - domainRule := client.domains[info.domainRuleIdx] + domainRule := info.domainRule domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx)) if clientUsed[info.clientIdx] { continue diff --git a/app/dns/dns_test.go b/app/dns/dns_test.go index cb70b0b3..d18c7686 100644 --- a/app/dns/dns_test.go +++ b/app/dns/dns_test.go @@ -11,9 +11,9 @@ import ( "github.com/xtls/xray-core/app/policy" "github.com/xtls/xray-core/app/proxyman" _ "github.com/xtls/xray-core/app/proxyman/outbound" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/core" @@ -331,10 +331,9 @@ func TestPrioritizedDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: DomainMatchingType_Full, - Domain: "google.com", + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "google.com"}}, }, }, }, @@ -471,8 +470,7 @@ func TestStaticHostDomain(t *testing.T) { }, StaticHosts: []*Config_HostMapping{ { - Type: DomainMatchingType_Full, - Domain: "example.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "example.com"}}}, ProxiedDomain: "google.com", }, }, @@ -539,11 +537,10 @@ func TestIPMatch(t *testing.T) { }, Port: uint32(port), }, - ExpectedGeoip: []*router.GeoIP{ + ExpectedIp: []*geodata.IPRule{ { - CountryCode: "local", - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ // inner ip, will not match Ip: []byte{192, 168, 11, 1}, Prefix: 32, @@ -563,20 +560,18 @@ func TestIPMatch(t *testing.T) { }, Port: uint32(port), }, - ExpectedGeoip: []*router.GeoIP{ + ExpectedIp: []*geodata.IPRule{ { - CountryCode: "test", - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{8, 8, 8, 8}, Prefix: 32, }, }, }, { - CountryCode: "test", - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{8, 8, 8, 4}, Prefix: 32, }, @@ -663,19 +658,15 @@ func TestLocalDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ // Equivalent of dotless:localhost - {Type: DomainMatchingType_Regex, Domain: "^[^.]*localhost[^.]*$"}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Regex, Value: "^[^.]*localhost[^.]*$"}}}, }, - ExpectedGeoip: []*router.GeoIP{ - { // Will match localhost, localhost-a and localhost-b, - CountryCode: "local", - Cidr: []*router.CIDR{ - {Ip: []byte{127, 0, 0, 2}, Prefix: 32}, - {Ip: []byte{127, 0, 0, 3}, Prefix: 32}, - {Ip: []byte{127, 0, 0, 4}, Prefix: 32}, - }, - }, + ExpectedIp: []*geodata.IPRule{ + // Will match localhost, localhost-a and localhost-b, + {Value: &geodata.IPRule_Custom{Custom: &geodata.CIDR{Ip: []byte{127, 0, 0, 2}, Prefix: 32}}}, + {Value: &geodata.IPRule_Custom{Custom: &geodata.CIDR{Ip: []byte{127, 0, 0, 3}, Prefix: 32}}}, + {Value: &geodata.IPRule_Custom{Custom: &geodata.CIDR{Ip: []byte{127, 0, 0, 4}, Prefix: 32}}}, }, }, { @@ -688,23 +679,21 @@ func TestLocalDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ // Equivalent of dotless: and domain:local - {Type: DomainMatchingType_Regex, Domain: "^[^.]*$"}, - {Type: DomainMatchingType_Subdomain, Domain: "local"}, - {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Regex, Value: "^[^.]*$"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "local"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "localdomain"}}}, }, }, }, StaticHosts: []*Config_HostMapping{ { - Type: DomainMatchingType_Full, - Domain: "hostnamestatic", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "hostnamestatic"}}}, Ip: [][]byte{{127, 0, 0, 53}}, }, { - Type: DomainMatchingType_Full, - Domain: "hostnamealias", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "hostnamealias"}}}, ProxiedDomain: "hostname.localdomain", }, }, @@ -891,17 +880,27 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: DomainMatchingType_Subdomain, - Domain: "google.com", + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "google.com"}}, }, }, - ExpectedGeoip: []*router.GeoIP{ - { // Will only match 8.8.8.8 and 8.8.4.4 - Cidr: []*router.CIDR{ - {Ip: []byte{8, 8, 8, 8}, Prefix: 32}, - {Ip: []byte{8, 8, 4, 4}, Prefix: 32}, + ExpectedIp: []*geodata.IPRule{ + // Will only match 8.8.8.8 and 8.8.4.4 + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ + Ip: []byte{8, 8, 8, 8}, + Prefix: 32, + }, + }, + }, + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ + Ip: []byte{8, 8, 4, 4}, + Prefix: 32, + }, }, }, }, @@ -916,16 +915,19 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: DomainMatchingType_Subdomain, - Domain: "google.com", + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "google.com"}}, }, }, - ExpectedGeoip: []*router.GeoIP{ - { // Will match 8.8.8.8 and 8.8.8.7, etc - Cidr: []*router.CIDR{ - {Ip: []byte{8, 8, 8, 7}, Prefix: 24}, + ExpectedIp: []*geodata.IPRule{ + // Will match 8.8.8.8 and 8.8.8.7, etc + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ + Ip: []byte{8, 8, 8, 7}, + Prefix: 24, + }, }, }, }, @@ -940,16 +942,19 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: DomainMatchingType_Subdomain, - Domain: "api.google.com", + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "api.google.com"}}, }, }, - ExpectedGeoip: []*router.GeoIP{ - { // Will only match 8.8.7.7 (api.google.com) - Cidr: []*router.CIDR{ - {Ip: []byte{8, 8, 7, 7}, Prefix: 32}, + ExpectedIp: []*geodata.IPRule{ + // Will only match 8.8.7.7 (api.google.com) + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ + Ip: []byte{8, 8, 7, 7}, + Prefix: 32, + }, }, }, }, @@ -964,16 +969,19 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { }, Port: uint32(port), }, - PrioritizedDomain: []*NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: DomainMatchingType_Full, - Domain: "v2.api.google.com", + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "v2.api.google.com"}}, }, }, - ExpectedGeoip: []*router.GeoIP{ - { // Will only match 8.8.7.8 (v2.api.google.com) - Cidr: []*router.CIDR{ - {Ip: []byte{8, 8, 7, 8}, Prefix: 32}, + ExpectedIp: []*geodata.IPRule{ + // Will only match 8.8.7.8 (v2.api.google.com) + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ + Ip: []byte{8, 8, 7, 8}, + Prefix: 32, + }, }, }, }, diff --git a/app/dns/hosts.go b/app/dns/hosts.go index fab08d54..c3546967 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -2,39 +2,28 @@ package dns import ( "context" - "runtime" "strconv" + "strings" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/dns" ) // StaticHosts represents static domain-ip mapping in DNS server. type StaticHosts struct { - ips [][]net.Address - matchers strmatcher.IndexMatcher + reps [][]net.Address + matcher geodata.DomainMatcher } // NewStaticHosts creates a new StaticHosts instance. func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) { - g := new(strmatcher.MatcherGroup) - sh := &StaticHosts{ - ips: make([][]net.Address, len(hosts)+16), - matchers: g, - } + reps := make([][]net.Address, 0, len(hosts)) + rules := make([]*geodata.DomainRule, 0, len(hosts)) - defer runtime.GC() - for i, mapping := range hosts { - hosts[i] = nil - matcher, err := toStrMatcher(mapping.Type, mapping.Domain) - if err != nil { - errors.LogErrorInner(context.Background(), err, "failed to create domain matcher, ignore domain rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]") - continue - } - id := g.Add(matcher) - ips := make([]net.Address, 0, len(mapping.Ip)+1) + for _, mapping := range hosts { + rep := make([]net.Address, 0, len(mapping.Ip)) switch { case len(mapping.ProxiedDomain) > 0: if mapping.ProxiedDomain[0] == '#' { @@ -42,28 +31,36 @@ func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) { if err != nil { return nil, err } - ips = append(ips, dns.RCodeError(rcode)) + rep = append(rep, dns.RCodeError(rcode)) } else { - ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) + rep = append(rep, net.DomainAddress(mapping.ProxiedDomain)) } case len(mapping.Ip) > 0: for _, ip := range mapping.Ip { addr := net.IPAddress(ip) if addr == nil { - errors.LogError(context.Background(), "invalid IP address in static hosts: ", ip, ", ignore this ip for rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]") + errors.LogError(context.Background(), "invalid IP address in static hosts: ", ip, ", ignore this ip for rule: ", mapping.Domain) continue } - ips = append(ips, addr) - } - if len(ips) == 0 { - continue + rep = append(rep, addr) } } - - sh.ips[id] = ips + // if len(rep) == 0 { + // errors.LogError(context.Background(), "empty value in static hosts, ignore this rule: ", mapping.Domain) + // continue + // } + reps = append(reps, rep) + rules = append(rules, mapping.Domain) } - return sh, nil + matcher, err := geodata.DomainReg.BuildDomainMatcher(rules) + if err != nil { + return nil, err + } + return &StaticHosts{ + reps: reps, + matcher: matcher, + }, nil } func filterIP(ips []net.Address, option dns.IPOption) []net.Address { @@ -79,16 +76,16 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address { func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) { ips := make([]net.Address, 0) found := false - for _, id := range h.matchers.Match(domain) { - for _, v := range h.ips[id] { - if err, ok := v.(dns.RCodeError); ok { + for _, ruleIdx := range h.matcher.Match(domain) { + for _, rep := range h.reps[ruleIdx] { + if err, ok := rep.(dns.RCodeError); ok { if uint16(err) == 0 { return nil, dns.ErrEmptyResponse } return nil, err } } - ips = append(ips, h.ips[id]...) + ips = append(ips, h.reps[ruleIdx]...) found = true } if !found { @@ -98,10 +95,13 @@ func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) { } func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) { + domain = strings.ToLower(domain) switch addrs, err := h.lookupInternal(domain); { case err != nil: return nil, err - case len(addrs) == 0: // Not recorded in static hosts, return nil + case addrs == nil: // Not recorded in static hosts, return nil + return nil, nil + case len(addrs) == 0: // Domain recorded, but no valid IP returned return addrs, nil case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it") @@ -124,50 +124,3 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ( func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) { return h.lookup(domain, option, 5) } -func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) { - sh := &StaticHosts{ - ips: make([][]net.Address, matcher.Size()+1), - matchers: matcher, - } - - order := hostIPs["_ORDER"] - var offset uint32 - - img, ok := matcher.(*strmatcher.IndexMatcherGroup) - if !ok { - // Single matcher (e.g. only manual or only one geosite) - if len(order) > 0 { - pattern := order[0] - ips := parseIPs(hostIPs[pattern]) - for i := uint32(1); i <= matcher.Size(); i++ { - sh.ips[i] = ips - } - } - return sh, nil - } - - for i, m := range img.Matchers { - if i < len(order) { - pattern := order[i] - ips := parseIPs(hostIPs[pattern]) - for j := uint32(1); j <= m.Size(); j++ { - sh.ips[offset+j] = ips - } - offset += m.Size() - } - } - return sh, nil -} - -func parseIPs(raw []string) []net.Address { - addrs := make([]net.Address, 0, len(raw)) - for _, s := range raw { - if len(s) > 1 && s[0] == '#' { - rcode, _ := strconv.Atoi(s[1:]) - addrs = append(addrs, dns.RCodeError(rcode)) - } else { - addrs = append(addrs, net.ParseAddress(s)) - } - } - return addrs -} diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 2b9c24d8..7ebc3992 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -1,13 +1,12 @@ package dns_test import ( - "bytes" "testing" "github.com/google/go-cmp/cmp" . "github.com/xtls/xray-core/app/dns" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" ) @@ -15,20 +14,17 @@ import ( func TestStaticHosts(t *testing.T) { pb := []*Config_HostMapping{ { - Type: DomainMatchingType_Subdomain, - Domain: "lan", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "lan"}}}, ProxiedDomain: "#3", }, { - Type: DomainMatchingType_Full, - Domain: "example.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "example.com"}}}, Ip: [][]byte{ {1, 1, 1, 1}, }, }, { - Type: DomainMatchingType_Full, - Domain: "proxy.xray.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "proxy.xray.com"}}}, Ip: [][]byte{ {1, 2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, @@ -36,20 +32,17 @@ func TestStaticHosts(t *testing.T) { ProxiedDomain: "another-proxy.xray.com", }, { - Type: DomainMatchingType_Full, - Domain: "proxy2.xray.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "proxy2.xray.com"}}}, ProxiedDomain: "proxy.xray.com", }, { - Type: DomainMatchingType_Subdomain, - Domain: "example.cn", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "example.cn"}}}, Ip: [][]byte{ {2, 2, 2, 2}, }, }, { - Type: DomainMatchingType_Subdomain, - Domain: "baidu.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "baidu.com"}}}, Ip: [][]byte{ {127, 0, 0, 1}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, @@ -132,57 +125,3 @@ func TestStaticHosts(t *testing.T) { } } } -func TestStaticHostsFromCache(t *testing.T) { - sites := []*router.GeoSite{ - { - CountryCode: "cloudflare-dns.com", - Domain: []*router.Domain{ - {Type: router.Domain_Full, Value: "example.com"}, - }, - }, - { - CountryCode: "geosite:cn", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "baidu.cn"}, - }, - }, - } - deps := map[string][]string{ - "HOSTS": {"cloudflare-dns.com", "geosite:cn"}, - } - hostIPs := map[string][]string{ - "cloudflare-dns.com": {"1.1.1.1"}, - "geosite:cn": {"2.2.2.2"}, - "_ORDER": {"cloudflare-dns.com", "geosite:cn"}, - } - - var buf bytes.Buffer - err := router.SerializeGeoSiteList(sites, deps, hostIPs, &buf) - common.Must(err) - - // Load matcher - m, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "HOSTS") - common.Must(err) - - // Load hostIPs - f := bytes.NewReader(buf.Bytes()) - hips, err := router.LoadGeoSiteHosts(f) - common.Must(err) - - hosts, err := NewStaticHostsFromCache(m, hips) - common.Must(err) - - { - ips, _ := hosts.Lookup("example.com", dns.IPOption{IPv4Enable: true}) - if len(ips) != 1 || ips[0].String() != "1.1.1.1" { - t.Error("failed to lookup example.com from cache") - } - } - - { - ips, _ := hosts.Lookup("baidu.cn", dns.IPOption{IPv4Enable: true}) - if len(ips) != 1 || ips[0].String() != "2.2.2.2" { - t.Error("failed to lookup baidu.cn from cache deps") - } - } -} diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index 00d435b5..8b003eaa 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -3,34 +3,18 @@ package dns import ( "context" "net/url" - "runtime" "strings" "time" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/routing" ) -type mphMatcherWrapper struct { - m strmatcher.IndexMatcher -} - -func (w *mphMatcherWrapper) Match(s string) bool { - return w.m.Match(s) != nil -} - -func (w *mphMatcherWrapper) String() string { - return "mph-matcher" -} - // Server is the interface for Name Server. type Server interface { // Name of the Client. @@ -46,9 +30,8 @@ type Server interface { type Client struct { server Server skipFallback bool - domains []string - expectedIPs router.GeoIPMatcher - unexpectedIPs router.GeoIPMatcher + expectedIPs geodata.IPMatcher + unexpectedIPs geodata.IPMatcher actPrior bool actUnprior bool tag string @@ -111,11 +94,9 @@ func NewClient( disableCache bool, serveStale bool, serveExpiredTTL uint32, tag string, ipOption dns.IPOption, - matcherInfos *[]*DomainMatcherInfo, - updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo), + updateRules func(bool), ) (*Client, error) { client := &Client{} - err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { // Create a new server for each client for now server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, serveStale, serveExpiredTTL, clientIP) @@ -123,97 +104,25 @@ func NewClient( return errors.New("failed to create nameserver").Base(err).AtWarning() } - // Prioritize local domains with specific TLDs or those without any dot for the local DNS - if _, isLocalDNS := server.(*LocalNameServer); isLocalDNS { - ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...) - ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule) - // The following lines is a solution to avoid core panics(rule index out of range) when setting `localhost` DNS client in config. - // Because the `localhost` DNS client will append len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule. - // But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range). - // To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification. - // Related issues: - // https://github.com/v2fly/v2ray-core/issues/529 - // https://github.com/v2fly/v2ray-core/issues/719 - for i := 0; i < len(localTLDsAndDotlessDomains); i++ { - *matcherInfos = append(*matcherInfos, &DomainMatcherInfo{ - clientIdx: uint16(0), - domainRuleIdx: uint16(0), - }) - } - } - - // Establish domain rules - var rules []string - ruleCurr := 0 - ruleIter := 0 - - // Check if domain matcher cache is provided via environment - domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) - var mphLoaded bool - - if domainMatcherPath != "" && ns.Tag != "" { - f, err := filesystem.NewFileReader(domainMatcherPath) - if err == nil { - defer f.Close() - g, err := router.LoadGeoSiteMatcher(f, ns.Tag) - if err == nil { - errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for ", ns.Tag, " dns tag)") - updateDomainRule(&mphMatcherWrapper{m: g}, 0, *matcherInfos) - rules = append(rules, "[MPH Cache]") - mphLoaded = true - } - } - } - - if !mphLoaded { - for i, domain := range ns.PrioritizedDomain { - ns.PrioritizedDomain[i] = nil - domainRule, err := toStrMatcher(domain.Type, domain.Domain) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]") - domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule") - } - originalRuleIdx := ruleCurr - if ruleCurr < len(ns.OriginalRules) { - rule := ns.OriginalRules[ruleCurr] - if ruleCurr >= len(rules) { - rules = append(rules, rule.Rule) - } - ruleIter++ - if ruleIter >= int(rule.Size) { - ruleIter = 0 - ruleCurr++ - } - } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) - rules = append(rules, domainRule.String()) - ruleCurr++ - } - updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) - } - } - ns.PrioritizedDomain = nil - runtime.GC() + _, isLocalDNS := server.(*LocalNameServer) + updateRules(isLocalDNS) // Establish expected IPs - var expectedMatcher router.GeoIPMatcher - if len(ns.ExpectedGeoip) > 0 { - expectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.ExpectedGeoip...) + var expectedMatcher geodata.IPMatcher + if len(ns.ExpectedIp) > 0 { + expectedMatcher, err = geodata.IPReg.BuildIPMatcher(ns.ExpectedIp) if err != nil { return errors.New("failed to create expected ip matcher").Base(err).AtWarning() } - ns.ExpectedGeoip = nil - runtime.GC() } // Establish unexpected IPs - var unexpectedMatcher router.GeoIPMatcher - if len(ns.UnexpectedGeoip) > 0 { - unexpectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.UnexpectedGeoip...) + var unexpectedMatcher geodata.IPMatcher + if len(ns.UnexpectedIp) > 0 { + unexpectedMatcher, err = geodata.IPReg.BuildIPMatcher(ns.UnexpectedIp) if err != nil { return errors.New("failed to create unexpected ip matcher").Base(err).AtWarning() } - ns.UnexpectedGeoip = nil - runtime.GC() } if len(clientIP) > 0 { @@ -234,7 +143,6 @@ func NewClient( client.server = server client.skipFallback = ns.SkipFallback - client.domains = rules client.expectedIPs = expectedMatcher client.unexpectedIPs = unexpectedMatcher client.actPrior = ns.ActPrior diff --git a/app/router/command/command_test.go b/app/router/command/command_test.go index e8329695..e6706df8 100644 --- a/app/router/command/command_test.go +++ b/app/router/command/command_test.go @@ -12,6 +12,7 @@ import ( . "github.com/xtls/xray-core/app/router/command" "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/testing/mocks" @@ -303,12 +304,12 @@ func TestServiceTestRoute(t *testing.T) { TargetTag: &router.RoutingRule_Tag{Tag: "out"}, }, { - Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}}, + Domain: []*geodata.DomainRule{{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "com"}}}}, TargetTag: &router.RoutingRule_Tag{Tag: "out"}, }, { - SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}}, - TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + SourceIp: []*geodata.IPRule{{Value: &geodata.IPRule_Custom{Custom: &geodata.CIDR{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, }, { UserEmail: []string{"example@example.com"}, diff --git a/app/router/condition.go b/app/router/condition.go index a9889a6e..1aae8323 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -2,7 +2,6 @@ package router import ( "context" - "io" "os" "path/filepath" "regexp" @@ -10,8 +9,8 @@ import ( "strings" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/routing" ) @@ -45,67 +44,18 @@ func (v *ConditionChan) Len() int { return len(*v) } -var matcherTypeMap = map[Domain_Type]strmatcher.Type{ - Domain_Plain: strmatcher.Substr, - Domain_Regex: strmatcher.Regex, - Domain_Domain: strmatcher.Domain, - Domain_Full: strmatcher.Full, -} +type DomainMatcher struct{ geodata.DomainMatcher } -type DomainMatcher struct { - Matchers strmatcher.IndexMatcher -} - -func SerializeDomainMatcher(domains []*Domain, w io.Writer) error { - - g := strmatcher.NewMphMatcherGroup() - for _, d := range domains { - matcherType, f := matcherTypeMap[d.Type] - if !f { - continue - } - - _, err := g.AddPattern(d.Value, matcherType) - if err != nil { - return err - } - } - g.Build() - // serialize - return g.Serialize(w) -} - -func NewDomainMatcherFromBuffer(data []byte) (*strmatcher.MphMatcherGroup, error) { - matcher, err := strmatcher.NewMphMatcherGroupFromBuffer(data) +func NewDomainMatcher(rules []*geodata.DomainRule) (*DomainMatcher, error) { + m, err := geodata.DomainReg.BuildDomainMatcher(rules) if err != nil { return nil, err } - return matcher, nil -} - -func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { - g := strmatcher.NewMphMatcherGroup() - for i, d := range domains { - domains[i] = nil - matcherType, f := matcherTypeMap[d.Type] - if !f { - errors.LogError(context.Background(), "ignore unsupported domain type ", d.Type, " of rule ", d.Value) - continue - } - _, err := g.AddPattern(d.Value, matcherType) - if err != nil { - errors.LogErrorInner(context.Background(), err, "ignore domain rule ", d.Type, " ", d.Value) - continue - } - } - g.Build() - return &DomainMatcher{ - Matchers: g, - }, nil + return &DomainMatcher{DomainMatcher: m}, nil } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return len(m.Matchers.Match(strings.ToLower(domain))) > 0 + return m.DomainMatcher.MatchAny(strings.ToLower(domain)) } // Apply implements Condition. @@ -114,7 +64,7 @@ func (m *DomainMatcher) Apply(ctx routing.Context) bool { if len(domain) == 0 { return false } - return m.ApplyDomain(domain) + return m.DomainMatcher.MatchAny(strings.ToLower(domain)) } type MatcherAsType byte @@ -127,16 +77,16 @@ const ( ) type IPMatcher struct { - matcher GeoIPMatcher + matcher geodata.IPMatcher asType MatcherAsType } -func NewIPMatcher(geoips []*GeoIP, asType MatcherAsType) (*IPMatcher, error) { - matcher, err := BuildOptimizedGeoIPMatcher(geoips...) +func NewIPMatcher(rules []*geodata.IPRule, asType MatcherAsType) (*IPMatcher, error) { + m, err := geodata.IPReg.BuildIPMatcher(rules) if err != nil { return nil, err } - return &IPMatcher{matcher: matcher, asType: asType}, nil + return &IPMatcher{matcher: m, asType: asType}, nil } // Apply implements Condition. diff --git a/app/router/condition_geoip_test.go b/app/router/condition_geoip_test.go deleted file mode 100644 index b712db9e..00000000 --- a/app/router/condition_geoip_test.go +++ /dev/null @@ -1,266 +0,0 @@ -package router_test - -import ( - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/xtls/xray-core/app/router" - "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" - "google.golang.org/protobuf/proto" -) - -func getAssetPath(file string) (string, error) { - path := platform.GetAssetLocation(file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - path := filepath.Join("..", "..", "resources", file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file) - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - return path, nil - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - - return path, nil -} - -func TestGeoIPMatcher(t *testing.T) { - cidrList := []*router.CIDR{ - {Ip: []byte{0, 0, 0, 0}, Prefix: 8}, - {Ip: []byte{10, 0, 0, 0}, Prefix: 8}, - {Ip: []byte{100, 64, 0, 0}, Prefix: 10}, - {Ip: []byte{127, 0, 0, 0}, Prefix: 8}, - {Ip: []byte{169, 254, 0, 0}, Prefix: 16}, - {Ip: []byte{172, 16, 0, 0}, Prefix: 12}, - {Ip: []byte{192, 0, 0, 0}, Prefix: 24}, - {Ip: []byte{192, 0, 2, 0}, Prefix: 24}, - {Ip: []byte{192, 168, 0, 0}, Prefix: 16}, - {Ip: []byte{192, 18, 0, 0}, Prefix: 15}, - {Ip: []byte{198, 51, 100, 0}, Prefix: 24}, - {Ip: []byte{203, 0, 113, 0}, Prefix: 24}, - {Ip: []byte{8, 8, 8, 8}, Prefix: 32}, - {Ip: []byte{91, 108, 4, 0}, Prefix: 16}, - } - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: cidrList, - }) - common.Must(err) - - testCases := []struct { - Input string - Output bool - }{ - { - Input: "192.168.1.1", - Output: true, - }, - { - Input: "192.0.0.0", - Output: true, - }, - { - Input: "192.0.1.0", - Output: false, - }, - { - Input: "0.1.0.0", - Output: true, - }, - { - Input: "1.0.0.1", - Output: false, - }, - { - Input: "8.8.8.7", - Output: false, - }, - { - Input: "8.8.8.8", - Output: true, - }, - { - Input: "2001:cdba::3257:9652", - Output: false, - }, - { - Input: "91.108.255.254", - Output: true, - }, - } - - for _, testCase := range testCases { - ip := net.ParseAddress(testCase.Input).IP() - actual := matcher.Match(ip) - if actual != testCase.Output { - t.Error("expect input", testCase.Input, "to be", testCase.Output, ", but actually", actual) - } - } -} - -func TestGeoIPMatcherRegression(t *testing.T) { - cidrList := []*router.CIDR{ - {Ip: []byte{98, 108, 20, 0}, Prefix: 22}, - {Ip: []byte{98, 108, 20, 0}, Prefix: 23}, - } - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: cidrList, - }) - common.Must(err) - - testCases := []struct { - Input string - Output bool - }{ - { - Input: "98.108.22.11", - Output: true, - }, - { - Input: "98.108.25.0", - Output: false, - }, - } - - for _, testCase := range testCases { - ip := net.ParseAddress(testCase.Input).IP() - actual := matcher.Match(ip) - if actual != testCase.Output { - t.Error("expect input", testCase.Input, "to be", testCase.Output, ", but actually", actual) - } - } -} - -func TestGeoIPReverseMatcher(t *testing.T) { - cidrList := []*router.CIDR{ - {Ip: []byte{8, 8, 8, 8}, Prefix: 32}, - {Ip: []byte{91, 108, 4, 0}, Prefix: 16}, - } - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: cidrList, - }) - common.Must(err) - matcher.SetReverse(true) // Reverse match - - testCases := []struct { - Input string - Output bool - }{ - { - Input: "8.8.8.8", - Output: false, - }, - { - Input: "2001:cdba::3257:9652", - Output: true, - }, - { - Input: "91.108.255.254", - Output: false, - }, - } - - for _, testCase := range testCases { - ip := net.ParseAddress(testCase.Input).IP() - actual := matcher.Match(ip) - if actual != testCase.Output { - t.Error("expect input", testCase.Input, "to be", testCase.Output, ", but actually", actual) - } - } -} - -func TestGeoIPMatcher4CN(t *testing.T) { - ips, err := loadGeoIP("CN") - common.Must(err) - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) - common.Must(err) - - if matcher.Match([]byte{8, 8, 8, 8}) { - t.Error("expect CN geoip doesn't contain 8.8.8.8, but actually does") - } -} - -func TestGeoIPMatcher6US(t *testing.T) { - ips, err := loadGeoIP("US") - common.Must(err) - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) - common.Must(err) - - if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) { - t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not") - } -} - -func loadGeoIP(country string) ([]*router.CIDR, error) { - path, err := getAssetPath("geoip.dat") - if err != nil { - return nil, err - } - geoipBytes, err := filesystem.ReadFile(path) - if err != nil { - return nil, err - } - - var geoipList router.GeoIPList - if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil { - return nil, err - } - - for _, geoip := range geoipList.Entry { - if geoip.CountryCode == country { - return geoip.Cidr, nil - } - } - - panic("country not found: " + country) -} - -func BenchmarkGeoIPMatcher4CN(b *testing.B) { - ips, err := loadGeoIP("CN") - common.Must(err) - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) - common.Must(err) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = matcher.Match([]byte{8, 8, 8, 8}) - } -} - -func BenchmarkGeoIPMatcher6US(b *testing.B) { - ips, err := loadGeoIP("US") - common.Must(err) - - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) - common.Must(err) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) - } -} diff --git a/app/router/condition_serialize_test.go b/app/router/condition_serialize_test.go deleted file mode 100644 index 4c6ff464..00000000 --- a/app/router/condition_serialize_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package router_test - -import ( - "bytes" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "github.com/xtls/xray-core/app/router" - "github.com/xtls/xray-core/common/platform/filesystem" -) - -func TestDomainMatcherSerialization(t *testing.T) { - - domains := []*router.Domain{ - {Type: router.Domain_Domain, Value: "google.com"}, - {Type: router.Domain_Domain, Value: "v2ray.com"}, - {Type: router.Domain_Full, Value: "full.example.com"}, - } - - var buf bytes.Buffer - if err := router.SerializeDomainMatcher(domains, &buf); err != nil { - t.Fatalf("Serialize failed: %v", err) - } - - matcher, err := router.NewDomainMatcherFromBuffer(buf.Bytes()) - if err != nil { - t.Fatalf("Deserialize failed: %v", err) - } - - dMatcher := &router.DomainMatcher{ - Matchers: matcher, - } - testCases := []struct { - Input string - Match bool - }{ - {"google.com", true}, - {"maps.google.com", true}, - {"v2ray.com", true}, - {"full.example.com", true}, - - {"example.com", false}, - } - - for _, tc := range testCases { - if res := dMatcher.ApplyDomain(tc.Input); res != tc.Match { - t.Errorf("Match(%s) = %v, want %v", tc.Input, res, tc.Match) - } - } -} - -func TestGeoSiteSerialization(t *testing.T) { - sites := []*router.GeoSite{ - { - CountryCode: "CN", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "baidu.cn"}, - {Type: router.Domain_Domain, Value: "qq.com"}, - }, - }, - { - CountryCode: "US", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "google.com"}, - {Type: router.Domain_Domain, Value: "facebook.com"}, - }, - }, - } - - var buf bytes.Buffer - if err := router.SerializeGeoSiteList(sites, nil, nil, &buf); err != nil { - t.Fatalf("SerializeGeoSiteList failed: %v", err) - } - - tmp := t.TempDir() - path := filepath.Join(tmp, "matcher.cache") - - f, err := os.Create(path) - require.NoError(t, err) - _, err = f.Write(buf.Bytes()) - require.NoError(t, err) - f.Close() - - f, err = os.Open(path) - require.NoError(t, err) - defer f.Close() - - require.NoError(t, err) - data, _ := filesystem.ReadFile(path) - - // cn - gp, err := router.LoadGeoSiteMatcher(bytes.NewReader(data), "CN") - if err != nil { - t.Fatalf("LoadGeoSiteMatcher(CN) failed: %v", err) - } - - cnMatcher := &router.DomainMatcher{ - Matchers: gp, - } - - if !cnMatcher.ApplyDomain("baidu.cn") { - t.Error("CN matcher should match baidu.cn") - } - if cnMatcher.ApplyDomain("google.com") { - t.Error("CN matcher should NOT match google.com") - } - - // us - gp, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "US") - if err != nil { - t.Fatalf("LoadGeoSiteMatcher(US) failed: %v", err) - } - - usMatcher := &router.DomainMatcher{ - Matchers: gp, - } - if !usMatcher.ApplyDomain("google.com") { - t.Error("US matcher should match google.com") - } - if usMatcher.ApplyDomain("baidu.cn") { - t.Error("US matcher should NOT match baidu.cn") - } - - // unknown - _, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "unknown") - if err == nil { - t.Error("LoadGeoSiteMatcher(unknown) should fail") - } -} -func TestGeoSiteSerializationWithDeps(t *testing.T) { - sites := []*router.GeoSite{ - { - CountryCode: "geosite:cn", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "baidu.cn"}, - }, - }, - { - CountryCode: "geosite:google@cn", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "google.cn"}, - }, - }, - { - CountryCode: "rule-1", - Domain: []*router.Domain{ - {Type: router.Domain_Domain, Value: "google.com"}, - }, - }, - } - deps := map[string][]string{ - "rule-1": {"geosite:cn", "geosite:google@cn"}, - } - - var buf bytes.Buffer - err := router.SerializeGeoSiteList(sites, deps, nil, &buf) - require.NoError(t, err) - - matcher, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "rule-1") - require.NoError(t, err) - - require.True(t, matcher.Match("google.com") != nil) - require.True(t, matcher.Match("baidu.cn") != nil) - require.True(t, matcher.Match("google.cn") != nil) -} diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 1272aef6..9e57aa91 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -1,20 +1,19 @@ package router_test import ( + "path/filepath" "strconv" "testing" . "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/features/routing" routing_session "github.com/xtls/xray-core/features/routing/session" - "google.golang.org/protobuf/proto" ) func withBackground() routing.Context { @@ -45,18 +44,15 @@ func TestRoutingRule(t *testing.T) { }{ { rule: &RoutingRule{ - Domain: []*Domain{ + Domain: []*geodata.DomainRule{ { - Value: "example.com", - Type: Domain_Plain, + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "example.com"}}, }, { - Value: "google.com", - Type: Domain_Domain, + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "google.com"}}, }, { - Value: "^facebook\\.com$", - Type: Domain_Regex, + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Regex, Value: "^facebook\\.com$"}}, }, }, }, @@ -93,18 +89,26 @@ func TestRoutingRule(t *testing.T) { }, { rule: &RoutingRule{ - Geoip: []*GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{8, 8, 8, 8}, Prefix: 32, }, - { + }, + }, + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{8, 8, 8, 8}, Prefix: 32, }, - { + }, + }, + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: net.ParseAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334").IP(), Prefix: 128, }, @@ -133,10 +137,10 @@ func TestRoutingRule(t *testing.T) { }, { rule: &RoutingRule{ - SourceGeoip: []*GeoIP{ + SourceIp: []*geodata.IPRule{ { - Cidr: []*CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{192, 168, 0, 0}, Prefix: 16, }, @@ -300,35 +304,12 @@ func TestRoutingRule(t *testing.T) { } } -func loadGeoSite(country string) ([]*Domain, error) { - path, err := getAssetPath("geosite.dat") - if err != nil { - return nil, err - } - geositeBytes, err := filesystem.ReadFile(path) - if err != nil { - return nil, err - } - - var geositeList GeoSiteList - if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil { - return nil, err - } - - for _, site := range geositeList.Entry { - if site.CountryCode == country { - return site.Domain, nil - } - } - - return nil, errors.New("country not found: " + country) -} - func TestChinaSites(t *testing.T) { - domains, err := loadGeoSite("CN") + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + rules, err := geodata.ParseDomainRules([]string{"geosite:cn"}, geodata.Domain_Substr) common.Must(err) - acMatcher, err := NewMphMatcherGroup(domains) + matcher, err := NewDomainMatcher(rules) common.Must(err) type TestCase struct { @@ -359,18 +340,19 @@ func TestChinaSites(t *testing.T) { } for _, testCase := range testCases { - r := acMatcher.ApplyDomain(testCase.Domain) + r := matcher.ApplyDomain(testCase.Domain) if r != testCase.Output { - t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r) + t.Error("DomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r) } } } func BenchmarkMphDomainMatcher(b *testing.B) { - domains, err := loadGeoSite("CN") + b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + rules, err := geodata.ParseDomainRules([]string{"geosite:cn"}, geodata.Domain_Substr) common.Must(err) - matcher, err := NewMphMatcherGroup(domains) + matcher, err := NewDomainMatcher(rules) common.Must(err) type TestCase struct { @@ -409,45 +391,11 @@ func BenchmarkMphDomainMatcher(b *testing.B) { } func BenchmarkMultiGeoIPMatcher(b *testing.B) { - var geoips []*GeoIP + b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + rules, err := geodata.ParseIPRules([]string{"geoip:cn", "geoip:jp", "geoip:ca", "geoip:us"}) + common.Must(err) - { - ips, err := loadGeoIP("CN") - common.Must(err) - geoips = append(geoips, &GeoIP{ - CountryCode: "CN", - Cidr: ips, - }) - } - - { - ips, err := loadGeoIP("JP") - common.Must(err) - geoips = append(geoips, &GeoIP{ - CountryCode: "JP", - Cidr: ips, - }) - } - - { - ips, err := loadGeoIP("CA") - common.Must(err) - geoips = append(geoips, &GeoIP{ - CountryCode: "CA", - Cidr: ips, - }) - } - - { - ips, err := loadGeoIP("US") - common.Must(err) - geoips = append(geoips, &GeoIP{ - CountryCode: "US", - Cidr: ips, - }) - } - - matcher, err := NewIPMatcher(geoips, MatcherAsType_Target) + matcher, err := NewIPMatcher(rules, MatcherAsType_Target) common.Must(err) ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("8.8.8.8"), 80)}) diff --git a/app/router/config.go b/app/router/config.go index 4acbaf41..0a76905c 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -3,12 +3,9 @@ package router import ( "context" "regexp" - "runtime" "strings" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/routing" ) @@ -76,60 +73,37 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { conds.Add(&AttributeMatcher{configuredKeys}) } - if len(rr.Geoip) > 0 { - cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target) + if len(rr.Ip) > 0 { + cond, err := NewIPMatcher(rr.Ip, MatcherAsType_Target) if err != nil { return nil, err } conds.Add(cond) - rr.Geoip = nil - runtime.GC() } - if len(rr.SourceGeoip) > 0 { - cond, err := NewIPMatcher(rr.SourceGeoip, MatcherAsType_Source) + if len(rr.SourceIp) > 0 { + cond, err := NewIPMatcher(rr.SourceIp, MatcherAsType_Source) if err != nil { return nil, err } conds.Add(cond) - rr.SourceGeoip = nil - runtime.GC() } - if len(rr.LocalGeoip) > 0 { - cond, err := NewIPMatcher(rr.LocalGeoip, MatcherAsType_Local) + if len(rr.LocalIp) > 0 { + cond, err := NewIPMatcher(rr.LocalIp, MatcherAsType_Local) if err != nil { return nil, err } conds.Add(cond) errors.LogWarning(context.Background(), "Due to some limitations, in UDP connections, localIP is always equal to listen interface IP, so \"localIP\" rule condition does not work properly on UDP inbound connections that listen on all interfaces") - rr.LocalGeoip = nil - runtime.GC() } if len(rr.Domain) > 0 { - var matcher *DomainMatcher - var err error - // Check if domain matcher cache is provided via environment - domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) - - if domainMatcherPath != "" { - matcher, err = GetDomainMatcherWithRuleTag(domainMatcherPath, rr.RuleTag) - if err != nil { - return nil, errors.New("failed to build domain condition from cached MphDomainMatcher").Base(err) - } - errors.LogDebug(context.Background(), "MphDomainMatcher loaded from cache for ", rr.RuleTag, " rule tag)") - - } else { - matcher, err = NewMphMatcherGroup(rr.Domain) - if err != nil { - return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) - } - errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") + cond, err := NewDomainMatcher(rr.Domain) + if err != nil { + return nil, err } - conds.Add(matcher) - rr.Domain = nil - runtime.GC() + conds.Add(cond) } if len(rr.Process) > 0 { @@ -189,20 +163,3 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch return nil, errors.New("unrecognized balancer type") } } - -func GetDomainMatcherWithRuleTag(domainMatcherPath string, ruleTag string) (*DomainMatcher, error) { - f, err := filesystem.NewFileReader(domainMatcherPath) - if err != nil { - return nil, errors.New("failed to load file: ", domainMatcherPath).Base(err) - } - defer f.Close() - - g, err := LoadGeoSiteMatcher(f, ruleTag) - if err != nil { - return nil, errors.New("failed to load file:", domainMatcherPath).Base(err) - } - return &DomainMatcher{ - Matchers: g, - }, nil - -} diff --git a/app/router/config.pb.go b/app/router/config.pb.go index 40676024..6c1e2750 100644 --- a/app/router/config.pb.go +++ b/app/router/config.pb.go @@ -7,6 +7,7 @@ package router import ( + geodata "github.com/xtls/xray-core/common/geodata" net "github.com/xtls/xray-core/common/net" serial "github.com/xtls/xray-core/common/serial" protoreflect "google.golang.org/protobuf/reflect/protoreflect" @@ -23,63 +24,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// Type of domain value. -type Domain_Type int32 - -const ( - // The value is used as is. - Domain_Plain Domain_Type = 0 - // The value is used as a regular expression. - Domain_Regex Domain_Type = 1 - // The value is a root domain. - Domain_Domain Domain_Type = 2 - // The value is a domain. - Domain_Full Domain_Type = 3 -) - -// Enum value maps for Domain_Type. -var ( - Domain_Type_name = map[int32]string{ - 0: "Plain", - 1: "Regex", - 2: "Domain", - 3: "Full", - } - Domain_Type_value = map[string]int32{ - "Plain": 0, - "Regex": 1, - "Domain": 2, - "Full": 3, - } -) - -func (x Domain_Type) Enum() *Domain_Type { - p := new(Domain_Type) - *p = x - return p -} - -func (x Domain_Type) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (Domain_Type) Descriptor() protoreflect.EnumDescriptor { - return file_app_router_config_proto_enumTypes[0].Descriptor() -} - -func (Domain_Type) Type() protoreflect.EnumType { - return &file_app_router_config_proto_enumTypes[0] -} - -func (x Domain_Type) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use Domain_Type.Descriptor instead. -func (Domain_Type) EnumDescriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{0, 0} -} - type Config_DomainStrategy int32 const ( @@ -116,11 +60,11 @@ func (x Config_DomainStrategy) String() string { } func (Config_DomainStrategy) Descriptor() protoreflect.EnumDescriptor { - return file_app_router_config_proto_enumTypes[1].Descriptor() + return file_app_router_config_proto_enumTypes[0].Descriptor() } func (Config_DomainStrategy) Type() protoreflect.EnumType { - return &file_app_router_config_proto_enumTypes[1] + return &file_app_router_config_proto_enumTypes[0] } func (x Config_DomainStrategy) Number() protoreflect.EnumNumber { @@ -129,326 +73,7 @@ func (x Config_DomainStrategy) Number() protoreflect.EnumNumber { // Deprecated: Use Config_DomainStrategy.Descriptor instead. func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{11, 0} -} - -// Domain for routing decision. -type Domain struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Domain matching type. - Type Domain_Type `protobuf:"varint,1,opt,name=type,proto3,enum=xray.app.router.Domain_Type" json:"type,omitempty"` - // Domain value. - Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` - // Attributes of this domain. May be used for filtering. - Attribute []*Domain_Attribute `protobuf:"bytes,3,rep,name=attribute,proto3" json:"attribute,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Domain) Reset() { - *x = Domain{} - mi := &file_app_router_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Domain) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Domain) ProtoMessage() {} - -func (x *Domain) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Domain.ProtoReflect.Descriptor instead. -func (*Domain) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{0} -} - -func (x *Domain) GetType() Domain_Type { - if x != nil { - return x.Type - } - return Domain_Plain -} - -func (x *Domain) GetValue() string { - if x != nil { - return x.Value - } - return "" -} - -func (x *Domain) GetAttribute() []*Domain_Attribute { - if x != nil { - return x.Attribute - } - return nil -} - -// IP for routing decision, in CIDR form. -type CIDR struct { - state protoimpl.MessageState `protogen:"open.v1"` - // IP address, should be either 4 or 16 bytes. - Ip []byte `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` - // Number of leading ones in the network mask. - Prefix uint32 `protobuf:"varint,2,opt,name=prefix,proto3" json:"prefix,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *CIDR) Reset() { - *x = CIDR{} - mi := &file_app_router_config_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *CIDR) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*CIDR) ProtoMessage() {} - -func (x *CIDR) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use CIDR.ProtoReflect.Descriptor instead. -func (*CIDR) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{1} -} - -func (x *CIDR) GetIp() []byte { - if x != nil { - return x.Ip - } - return nil -} - -func (x *CIDR) GetPrefix() uint32 { - if x != nil { - return x.Prefix - } - return 0 -} - -type GeoIP struct { - state protoimpl.MessageState `protogen:"open.v1"` - CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode,proto3" json:"country_code,omitempty"` - Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr,proto3" json:"cidr,omitempty"` - ReverseMatch bool `protobuf:"varint,3,opt,name=reverse_match,json=reverseMatch,proto3" json:"reverse_match,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GeoIP) Reset() { - *x = GeoIP{} - mi := &file_app_router_config_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GeoIP) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GeoIP) ProtoMessage() {} - -func (x *GeoIP) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GeoIP.ProtoReflect.Descriptor instead. -func (*GeoIP) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{2} -} - -func (x *GeoIP) GetCountryCode() string { - if x != nil { - return x.CountryCode - } - return "" -} - -func (x *GeoIP) GetCidr() []*CIDR { - if x != nil { - return x.Cidr - } - return nil -} - -func (x *GeoIP) GetReverseMatch() bool { - if x != nil { - return x.ReverseMatch - } - return false -} - -type GeoIPList struct { - state protoimpl.MessageState `protogen:"open.v1"` - Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GeoIPList) Reset() { - *x = GeoIPList{} - mi := &file_app_router_config_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GeoIPList) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GeoIPList) ProtoMessage() {} - -func (x *GeoIPList) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GeoIPList.ProtoReflect.Descriptor instead. -func (*GeoIPList) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{3} -} - -func (x *GeoIPList) GetEntry() []*GeoIP { - if x != nil { - return x.Entry - } - return nil -} - -type GeoSite struct { - state protoimpl.MessageState `protogen:"open.v1"` - CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode,proto3" json:"country_code,omitempty"` - Domain []*Domain `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GeoSite) Reset() { - *x = GeoSite{} - mi := &file_app_router_config_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GeoSite) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GeoSite) ProtoMessage() {} - -func (x *GeoSite) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[4] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GeoSite.ProtoReflect.Descriptor instead. -func (*GeoSite) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{4} -} - -func (x *GeoSite) GetCountryCode() string { - if x != nil { - return x.CountryCode - } - return "" -} - -func (x *GeoSite) GetDomain() []*Domain { - if x != nil { - return x.Domain - } - return nil -} - -type GeoSiteList struct { - state protoimpl.MessageState `protogen:"open.v1"` - Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GeoSiteList) Reset() { - *x = GeoSiteList{} - mi := &file_app_router_config_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GeoSiteList) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GeoSiteList) ProtoMessage() {} - -func (x *GeoSiteList) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[5] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GeoSiteList.ProtoReflect.Descriptor instead. -func (*GeoSiteList) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{5} -} - -func (x *GeoSiteList) GetEntry() []*GeoSite { - if x != nil { - return x.Entry - } - return nil + return file_app_router_config_proto_rawDescGZIP(), []int{5, 0} } type RoutingRule struct { @@ -460,37 +85,35 @@ type RoutingRule struct { TargetTag isRoutingRule_TargetTag `protobuf_oneof:"target_tag"` RuleTag string `protobuf:"bytes,19,opt,name=rule_tag,json=ruleTag,proto3" json:"rule_tag,omitempty"` // List of domains for target domain matching. - Domain []*Domain `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` - // List of GeoIPs for target IP address matching. If this entry exists, the - // cidr above will have no effect. GeoIP fields with the same country code are - // supposed to contain exactly same content. They will be merged during - // runtime. For customized GeoIPs, please leave country code empty. - Geoip []*GeoIP `protobuf:"bytes,10,rep,name=geoip,proto3" json:"geoip,omitempty"` - // List of ports. + Domain []*geodata.DomainRule `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` + // List of IPs for target IP address matching. + Ip []*geodata.IPRule `protobuf:"bytes,10,rep,name=ip,proto3" json:"ip,omitempty"` + // List of ports for target port matching. PortList *net.PortList `protobuf:"bytes,14,opt,name=port_list,json=portList,proto3" json:"port_list,omitempty"` // List of networks for matching. Networks []net.Network `protobuf:"varint,13,rep,packed,name=networks,proto3,enum=xray.common.net.Network" json:"networks,omitempty"` - // List of GeoIPs for source IP address matching. If this entry exists, the - // source_cidr above will have no effect. - SourceGeoip []*GeoIP `protobuf:"bytes,11,rep,name=source_geoip,json=sourceGeoip,proto3" json:"source_geoip,omitempty"` + // List of IPs for source IP address matching. + SourceIp []*geodata.IPRule `protobuf:"bytes,11,rep,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"` // List of ports for source port matching. SourcePortList *net.PortList `protobuf:"bytes,16,opt,name=source_port_list,json=sourcePortList,proto3" json:"source_port_list,omitempty"` UserEmail []string `protobuf:"bytes,7,rep,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` InboundTag []string `protobuf:"bytes,8,rep,name=inbound_tag,json=inboundTag,proto3" json:"inbound_tag,omitempty"` Protocol []string `protobuf:"bytes,9,rep,name=protocol,proto3" json:"protocol,omitempty"` Attributes map[string]string `protobuf:"bytes,15,rep,name=attributes,proto3" json:"attributes,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` - LocalGeoip []*GeoIP `protobuf:"bytes,17,rep,name=local_geoip,json=localGeoip,proto3" json:"local_geoip,omitempty"` - LocalPortList *net.PortList `protobuf:"bytes,18,opt,name=local_port_list,json=localPortList,proto3" json:"local_port_list,omitempty"` - VlessRouteList *net.PortList `protobuf:"bytes,20,opt,name=vless_route_list,json=vlessRouteList,proto3" json:"vless_route_list,omitempty"` - Process []string `protobuf:"bytes,21,rep,name=process,proto3" json:"process,omitempty"` - Webhook *WebhookConfig `protobuf:"bytes,22,opt,name=webhook,proto3" json:"webhook,omitempty"` + // List of IPs for local IP address matching. + LocalIp []*geodata.IPRule `protobuf:"bytes,17,rep,name=local_ip,json=localIp,proto3" json:"local_ip,omitempty"` + // List of ports for local port matching. + LocalPortList *net.PortList `protobuf:"bytes,18,opt,name=local_port_list,json=localPortList,proto3" json:"local_port_list,omitempty"` + VlessRouteList *net.PortList `protobuf:"bytes,20,opt,name=vless_route_list,json=vlessRouteList,proto3" json:"vless_route_list,omitempty"` + Process []string `protobuf:"bytes,21,rep,name=process,proto3" json:"process,omitempty"` + Webhook *WebhookConfig `protobuf:"bytes,22,opt,name=webhook,proto3" json:"webhook,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *RoutingRule) Reset() { *x = RoutingRule{} - mi := &file_app_router_config_proto_msgTypes[6] + mi := &file_app_router_config_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -502,7 +125,7 @@ func (x *RoutingRule) String() string { func (*RoutingRule) ProtoMessage() {} func (x *RoutingRule) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[6] + mi := &file_app_router_config_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -515,7 +138,7 @@ func (x *RoutingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RoutingRule.ProtoReflect.Descriptor instead. func (*RoutingRule) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{6} + return file_app_router_config_proto_rawDescGZIP(), []int{0} } func (x *RoutingRule) GetTargetTag() isRoutingRule_TargetTag { @@ -550,16 +173,16 @@ func (x *RoutingRule) GetRuleTag() string { return "" } -func (x *RoutingRule) GetDomain() []*Domain { +func (x *RoutingRule) GetDomain() []*geodata.DomainRule { if x != nil { return x.Domain } return nil } -func (x *RoutingRule) GetGeoip() []*GeoIP { +func (x *RoutingRule) GetIp() []*geodata.IPRule { if x != nil { - return x.Geoip + return x.Ip } return nil } @@ -578,9 +201,9 @@ func (x *RoutingRule) GetNetworks() []net.Network { return nil } -func (x *RoutingRule) GetSourceGeoip() []*GeoIP { +func (x *RoutingRule) GetSourceIp() []*geodata.IPRule { if x != nil { - return x.SourceGeoip + return x.SourceIp } return nil } @@ -620,9 +243,9 @@ func (x *RoutingRule) GetAttributes() map[string]string { return nil } -func (x *RoutingRule) GetLocalGeoip() []*GeoIP { +func (x *RoutingRule) GetLocalIp() []*geodata.IPRule { if x != nil { - return x.LocalGeoip + return x.LocalIp } return nil } @@ -684,7 +307,7 @@ type WebhookConfig struct { func (x *WebhookConfig) Reset() { *x = WebhookConfig{} - mi := &file_app_router_config_proto_msgTypes[7] + mi := &file_app_router_config_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -696,7 +319,7 @@ func (x *WebhookConfig) String() string { func (*WebhookConfig) ProtoMessage() {} func (x *WebhookConfig) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[7] + mi := &file_app_router_config_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -709,7 +332,7 @@ func (x *WebhookConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use WebhookConfig.ProtoReflect.Descriptor instead. func (*WebhookConfig) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{7} + return file_app_router_config_proto_rawDescGZIP(), []int{1} } func (x *WebhookConfig) GetUrl() string { @@ -746,7 +369,7 @@ type BalancingRule struct { func (x *BalancingRule) Reset() { *x = BalancingRule{} - mi := &file_app_router_config_proto_msgTypes[8] + mi := &file_app_router_config_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -758,7 +381,7 @@ func (x *BalancingRule) String() string { func (*BalancingRule) ProtoMessage() {} func (x *BalancingRule) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[8] + mi := &file_app_router_config_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -771,7 +394,7 @@ func (x *BalancingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use BalancingRule.ProtoReflect.Descriptor instead. func (*BalancingRule) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{8} + return file_app_router_config_proto_rawDescGZIP(), []int{2} } func (x *BalancingRule) GetTag() string { @@ -820,7 +443,7 @@ type StrategyWeight struct { func (x *StrategyWeight) Reset() { *x = StrategyWeight{} - mi := &file_app_router_config_proto_msgTypes[9] + mi := &file_app_router_config_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -832,7 +455,7 @@ func (x *StrategyWeight) String() string { func (*StrategyWeight) ProtoMessage() {} func (x *StrategyWeight) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[9] + mi := &file_app_router_config_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -845,7 +468,7 @@ func (x *StrategyWeight) ProtoReflect() protoreflect.Message { // Deprecated: Use StrategyWeight.ProtoReflect.Descriptor instead. func (*StrategyWeight) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{9} + return file_app_router_config_proto_rawDescGZIP(), []int{3} } func (x *StrategyWeight) GetRegexp() bool { @@ -887,7 +510,7 @@ type StrategyLeastLoadConfig struct { func (x *StrategyLeastLoadConfig) Reset() { *x = StrategyLeastLoadConfig{} - mi := &file_app_router_config_proto_msgTypes[10] + mi := &file_app_router_config_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -899,7 +522,7 @@ func (x *StrategyLeastLoadConfig) String() string { func (*StrategyLeastLoadConfig) ProtoMessage() {} func (x *StrategyLeastLoadConfig) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[10] + mi := &file_app_router_config_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -912,7 +535,7 @@ func (x *StrategyLeastLoadConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use StrategyLeastLoadConfig.ProtoReflect.Descriptor instead. func (*StrategyLeastLoadConfig) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{10} + return file_app_router_config_proto_rawDescGZIP(), []int{4} } func (x *StrategyLeastLoadConfig) GetCosts() []*StrategyWeight { @@ -961,7 +584,7 @@ type Config struct { func (x *Config) Reset() { *x = Config{} - mi := &file_app_router_config_proto_msgTypes[11] + mi := &file_app_router_config_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -973,7 +596,7 @@ func (x *Config) String() string { func (*Config) ProtoMessage() {} func (x *Config) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[11] + mi := &file_app_router_config_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -986,7 +609,7 @@ func (x *Config) ProtoReflect() protoreflect.Message { // Deprecated: Use Config.ProtoReflect.Descriptor instead. func (*Config) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{11} + return file_app_router_config_proto_rawDescGZIP(), []int{5} } func (x *Config) GetDomainStrategy() Config_DomainStrategy { @@ -1010,141 +633,21 @@ func (x *Config) GetBalancingRule() []*BalancingRule { return nil } -type Domain_Attribute struct { - state protoimpl.MessageState `protogen:"open.v1"` - Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` - // Types that are valid to be assigned to TypedValue: - // - // *Domain_Attribute_BoolValue - // *Domain_Attribute_IntValue - TypedValue isDomain_Attribute_TypedValue `protobuf_oneof:"typed_value"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Domain_Attribute) Reset() { - *x = Domain_Attribute{} - mi := &file_app_router_config_proto_msgTypes[12] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Domain_Attribute) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Domain_Attribute) ProtoMessage() {} - -func (x *Domain_Attribute) ProtoReflect() protoreflect.Message { - mi := &file_app_router_config_proto_msgTypes[12] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Domain_Attribute.ProtoReflect.Descriptor instead. -func (*Domain_Attribute) Descriptor() ([]byte, []int) { - return file_app_router_config_proto_rawDescGZIP(), []int{0, 0} -} - -func (x *Domain_Attribute) GetKey() string { - if x != nil { - return x.Key - } - return "" -} - -func (x *Domain_Attribute) GetTypedValue() isDomain_Attribute_TypedValue { - if x != nil { - return x.TypedValue - } - return nil -} - -func (x *Domain_Attribute) GetBoolValue() bool { - if x != nil { - if x, ok := x.TypedValue.(*Domain_Attribute_BoolValue); ok { - return x.BoolValue - } - } - return false -} - -func (x *Domain_Attribute) GetIntValue() int64 { - if x != nil { - if x, ok := x.TypedValue.(*Domain_Attribute_IntValue); ok { - return x.IntValue - } - } - return 0 -} - -type isDomain_Attribute_TypedValue interface { - isDomain_Attribute_TypedValue() -} - -type Domain_Attribute_BoolValue struct { - BoolValue bool `protobuf:"varint,2,opt,name=bool_value,json=boolValue,proto3,oneof"` -} - -type Domain_Attribute_IntValue struct { - IntValue int64 `protobuf:"varint,3,opt,name=int_value,json=intValue,proto3,oneof"` -} - -func (*Domain_Attribute_BoolValue) isDomain_Attribute_TypedValue() {} - -func (*Domain_Attribute_IntValue) isDomain_Attribute_TypedValue() {} - var File_app_router_config_proto protoreflect.FileDescriptor const file_app_router_config_proto_rawDesc = "" + "\n" + - "\x17app/router/config.proto\x12\x0fxray.app.router\x1a!common/serial/typed_message.proto\x1a\x15common/net/port.proto\x1a\x18common/net/network.proto\"\xb3\x02\n" + - "\x06Domain\x120\n" + - "\x04type\x18\x01 \x01(\x0e2\x1c.xray.app.router.Domain.TypeR\x04type\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value\x12?\n" + - "\tattribute\x18\x03 \x03(\v2!.xray.app.router.Domain.AttributeR\tattribute\x1al\n" + - "\tAttribute\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12\x1f\n" + - "\n" + - "bool_value\x18\x02 \x01(\bH\x00R\tboolValue\x12\x1d\n" + - "\tint_value\x18\x03 \x01(\x03H\x00R\bintValueB\r\n" + - "\vtyped_value\"2\n" + - "\x04Type\x12\t\n" + - "\x05Plain\x10\x00\x12\t\n" + - "\x05Regex\x10\x01\x12\n" + - "\n" + - "\x06Domain\x10\x02\x12\b\n" + - "\x04Full\x10\x03\".\n" + - "\x04CIDR\x12\x0e\n" + - "\x02ip\x18\x01 \x01(\fR\x02ip\x12\x16\n" + - "\x06prefix\x18\x02 \x01(\rR\x06prefix\"z\n" + - "\x05GeoIP\x12!\n" + - "\fcountry_code\x18\x01 \x01(\tR\vcountryCode\x12)\n" + - "\x04cidr\x18\x02 \x03(\v2\x15.xray.app.router.CIDRR\x04cidr\x12#\n" + - "\rreverse_match\x18\x03 \x01(\bR\freverseMatch\"9\n" + - "\tGeoIPList\x12,\n" + - "\x05entry\x18\x01 \x03(\v2\x16.xray.app.router.GeoIPR\x05entry\"]\n" + - "\aGeoSite\x12!\n" + - "\fcountry_code\x18\x01 \x01(\tR\vcountryCode\x12/\n" + - "\x06domain\x18\x02 \x03(\v2\x17.xray.app.router.DomainR\x06domain\"=\n" + - "\vGeoSiteList\x12.\n" + - "\x05entry\x18\x01 \x03(\v2\x18.xray.app.router.GeoSiteR\x05entry\"\xbc\a\n" + + "\x17app/router/config.proto\x12\x0fxray.app.router\x1a!common/serial/typed_message.proto\x1a\x15common/net/port.proto\x1a\x18common/net/network.proto\x1a\x1bcommon/geodata/geodat.proto\"\xc1\a\n" + "\vRoutingRule\x12\x12\n" + "\x03tag\x18\x01 \x01(\tH\x00R\x03tag\x12%\n" + "\rbalancing_tag\x18\f \x01(\tH\x00R\fbalancingTag\x12\x19\n" + - "\brule_tag\x18\x13 \x01(\tR\aruleTag\x12/\n" + - "\x06domain\x18\x02 \x03(\v2\x17.xray.app.router.DomainR\x06domain\x12,\n" + - "\x05geoip\x18\n" + - " \x03(\v2\x16.xray.app.router.GeoIPR\x05geoip\x126\n" + + "\brule_tag\x18\x13 \x01(\tR\aruleTag\x127\n" + + "\x06domain\x18\x02 \x03(\v2\x1f.xray.common.geodata.DomainRuleR\x06domain\x12+\n" + + "\x02ip\x18\n" + + " \x03(\v2\x1b.xray.common.geodata.IPRuleR\x02ip\x126\n" + "\tport_list\x18\x0e \x01(\v2\x19.xray.common.net.PortListR\bportList\x124\n" + - "\bnetworks\x18\r \x03(\x0e2\x18.xray.common.net.NetworkR\bnetworks\x129\n" + - "\fsource_geoip\x18\v \x03(\v2\x16.xray.app.router.GeoIPR\vsourceGeoip\x12C\n" + + "\bnetworks\x18\r \x03(\x0e2\x18.xray.common.net.NetworkR\bnetworks\x128\n" + + "\tsource_ip\x18\v \x03(\v2\x1b.xray.common.geodata.IPRuleR\bsourceIp\x12C\n" + "\x10source_port_list\x18\x10 \x01(\v2\x19.xray.common.net.PortListR\x0esourcePortList\x12\x1d\n" + "\n" + "user_email\x18\a \x03(\tR\tuserEmail\x12\x1f\n" + @@ -1153,9 +656,8 @@ const file_app_router_config_proto_rawDesc = "" + "\bprotocol\x18\t \x03(\tR\bprotocol\x12L\n" + "\n" + "attributes\x18\x0f \x03(\v2,.xray.app.router.RoutingRule.AttributesEntryR\n" + - "attributes\x127\n" + - "\vlocal_geoip\x18\x11 \x03(\v2\x16.xray.app.router.GeoIPR\n" + - "localGeoip\x12A\n" + + "attributes\x126\n" + + "\blocal_ip\x18\x11 \x03(\v2\x1b.xray.common.geodata.IPRuleR\alocalIp\x12A\n" + "\x0flocal_port_list\x18\x12 \x01(\v2\x19.xray.common.net.PortListR\rlocalPortList\x12C\n" + "\x10vless_route_list\x18\x14 \x01(\v2\x19.xray.common.net.PortListR\x0evlessRouteList\x12\x18\n" + "\aprocess\x18\x15 \x03(\tR\aprocess\x128\n" + @@ -1187,16 +689,16 @@ const file_app_router_config_proto_rawDesc = "" + "\tbaselines\x18\x03 \x03(\x03R\tbaselines\x12\x1a\n" + "\bexpected\x18\x04 \x01(\x05R\bexpected\x12\x16\n" + "\x06maxRTT\x18\x05 \x01(\x03R\x06maxRTT\x12\x1c\n" + - "\ttolerance\x18\x06 \x01(\x02R\ttolerance\"\x90\x02\n" + + "\ttolerance\x18\x06 \x01(\x02R\ttolerance\"\x96\x02\n" + "\x06Config\x12O\n" + "\x0fdomain_strategy\x18\x01 \x01(\x0e2&.xray.app.router.Config.DomainStrategyR\x0edomainStrategy\x120\n" + "\x04rule\x18\x02 \x03(\v2\x1c.xray.app.router.RoutingRuleR\x04rule\x12E\n" + - "\x0ebalancing_rule\x18\x03 \x03(\v2\x1e.xray.app.router.BalancingRuleR\rbalancingRule\"<\n" + + "\x0ebalancing_rule\x18\x03 \x03(\v2\x1e.xray.app.router.BalancingRuleR\rbalancingRule\"B\n" + "\x0eDomainStrategy\x12\b\n" + "\x04AsIs\x10\x00\x12\x10\n" + "\fIpIfNonMatch\x10\x02\x12\x0e\n" + "\n" + - "IpOnDemand\x10\x03BO\n" + + "IpOnDemand\x10\x03\"\x04\b\x01\x10\x01BO\n" + "\x13com.xray.app.routerP\x01Z$github.com/xtls/xray-core/app/router\xaa\x02\x0fXray.App.Routerb\x06proto3" var ( @@ -1211,59 +713,47 @@ func file_app_router_config_proto_rawDescGZIP() []byte { return file_app_router_config_proto_rawDescData } -var file_app_router_config_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_app_router_config_proto_msgTypes = make([]protoimpl.MessageInfo, 15) +var file_app_router_config_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_app_router_config_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_app_router_config_proto_goTypes = []any{ - (Domain_Type)(0), // 0: xray.app.router.Domain.Type - (Config_DomainStrategy)(0), // 1: xray.app.router.Config.DomainStrategy - (*Domain)(nil), // 2: xray.app.router.Domain - (*CIDR)(nil), // 3: xray.app.router.CIDR - (*GeoIP)(nil), // 4: xray.app.router.GeoIP - (*GeoIPList)(nil), // 5: xray.app.router.GeoIPList - (*GeoSite)(nil), // 6: xray.app.router.GeoSite - (*GeoSiteList)(nil), // 7: xray.app.router.GeoSiteList - (*RoutingRule)(nil), // 8: xray.app.router.RoutingRule - (*WebhookConfig)(nil), // 9: xray.app.router.WebhookConfig - (*BalancingRule)(nil), // 10: xray.app.router.BalancingRule - (*StrategyWeight)(nil), // 11: xray.app.router.StrategyWeight - (*StrategyLeastLoadConfig)(nil), // 12: xray.app.router.StrategyLeastLoadConfig - (*Config)(nil), // 13: xray.app.router.Config - (*Domain_Attribute)(nil), // 14: xray.app.router.Domain.Attribute - nil, // 15: xray.app.router.RoutingRule.AttributesEntry - nil, // 16: xray.app.router.WebhookConfig.HeadersEntry - (*net.PortList)(nil), // 17: xray.common.net.PortList - (net.Network)(0), // 18: xray.common.net.Network - (*serial.TypedMessage)(nil), // 19: xray.common.serial.TypedMessage + (Config_DomainStrategy)(0), // 0: xray.app.router.Config.DomainStrategy + (*RoutingRule)(nil), // 1: xray.app.router.RoutingRule + (*WebhookConfig)(nil), // 2: xray.app.router.WebhookConfig + (*BalancingRule)(nil), // 3: xray.app.router.BalancingRule + (*StrategyWeight)(nil), // 4: xray.app.router.StrategyWeight + (*StrategyLeastLoadConfig)(nil), // 5: xray.app.router.StrategyLeastLoadConfig + (*Config)(nil), // 6: xray.app.router.Config + nil, // 7: xray.app.router.RoutingRule.AttributesEntry + nil, // 8: xray.app.router.WebhookConfig.HeadersEntry + (*geodata.DomainRule)(nil), // 9: xray.common.geodata.DomainRule + (*geodata.IPRule)(nil), // 10: xray.common.geodata.IPRule + (*net.PortList)(nil), // 11: xray.common.net.PortList + (net.Network)(0), // 12: xray.common.net.Network + (*serial.TypedMessage)(nil), // 13: xray.common.serial.TypedMessage } var file_app_router_config_proto_depIdxs = []int32{ - 0, // 0: xray.app.router.Domain.type:type_name -> xray.app.router.Domain.Type - 14, // 1: xray.app.router.Domain.attribute:type_name -> xray.app.router.Domain.Attribute - 3, // 2: xray.app.router.GeoIP.cidr:type_name -> xray.app.router.CIDR - 4, // 3: xray.app.router.GeoIPList.entry:type_name -> xray.app.router.GeoIP - 2, // 4: xray.app.router.GeoSite.domain:type_name -> xray.app.router.Domain - 6, // 5: xray.app.router.GeoSiteList.entry:type_name -> xray.app.router.GeoSite - 2, // 6: xray.app.router.RoutingRule.domain:type_name -> xray.app.router.Domain - 4, // 7: xray.app.router.RoutingRule.geoip:type_name -> xray.app.router.GeoIP - 17, // 8: xray.app.router.RoutingRule.port_list:type_name -> xray.common.net.PortList - 18, // 9: xray.app.router.RoutingRule.networks:type_name -> xray.common.net.Network - 4, // 10: xray.app.router.RoutingRule.source_geoip:type_name -> xray.app.router.GeoIP - 17, // 11: xray.app.router.RoutingRule.source_port_list:type_name -> xray.common.net.PortList - 15, // 12: xray.app.router.RoutingRule.attributes:type_name -> xray.app.router.RoutingRule.AttributesEntry - 4, // 13: xray.app.router.RoutingRule.local_geoip:type_name -> xray.app.router.GeoIP - 17, // 14: xray.app.router.RoutingRule.local_port_list:type_name -> xray.common.net.PortList - 17, // 15: xray.app.router.RoutingRule.vless_route_list:type_name -> xray.common.net.PortList - 9, // 16: xray.app.router.RoutingRule.webhook:type_name -> xray.app.router.WebhookConfig - 16, // 17: xray.app.router.WebhookConfig.headers:type_name -> xray.app.router.WebhookConfig.HeadersEntry - 19, // 18: xray.app.router.BalancingRule.strategy_settings:type_name -> xray.common.serial.TypedMessage - 11, // 19: xray.app.router.StrategyLeastLoadConfig.costs:type_name -> xray.app.router.StrategyWeight - 1, // 20: xray.app.router.Config.domain_strategy:type_name -> xray.app.router.Config.DomainStrategy - 8, // 21: xray.app.router.Config.rule:type_name -> xray.app.router.RoutingRule - 10, // 22: xray.app.router.Config.balancing_rule:type_name -> xray.app.router.BalancingRule - 23, // [23:23] is the sub-list for method output_type - 23, // [23:23] is the sub-list for method input_type - 23, // [23:23] is the sub-list for extension type_name - 23, // [23:23] is the sub-list for extension extendee - 0, // [0:23] is the sub-list for field type_name + 9, // 0: xray.app.router.RoutingRule.domain:type_name -> xray.common.geodata.DomainRule + 10, // 1: xray.app.router.RoutingRule.ip:type_name -> xray.common.geodata.IPRule + 11, // 2: xray.app.router.RoutingRule.port_list:type_name -> xray.common.net.PortList + 12, // 3: xray.app.router.RoutingRule.networks:type_name -> xray.common.net.Network + 10, // 4: xray.app.router.RoutingRule.source_ip:type_name -> xray.common.geodata.IPRule + 11, // 5: xray.app.router.RoutingRule.source_port_list:type_name -> xray.common.net.PortList + 7, // 6: xray.app.router.RoutingRule.attributes:type_name -> xray.app.router.RoutingRule.AttributesEntry + 10, // 7: xray.app.router.RoutingRule.local_ip:type_name -> xray.common.geodata.IPRule + 11, // 8: xray.app.router.RoutingRule.local_port_list:type_name -> xray.common.net.PortList + 11, // 9: xray.app.router.RoutingRule.vless_route_list:type_name -> xray.common.net.PortList + 2, // 10: xray.app.router.RoutingRule.webhook:type_name -> xray.app.router.WebhookConfig + 8, // 11: xray.app.router.WebhookConfig.headers:type_name -> xray.app.router.WebhookConfig.HeadersEntry + 13, // 12: xray.app.router.BalancingRule.strategy_settings:type_name -> xray.common.serial.TypedMessage + 4, // 13: xray.app.router.StrategyLeastLoadConfig.costs:type_name -> xray.app.router.StrategyWeight + 0, // 14: xray.app.router.Config.domain_strategy:type_name -> xray.app.router.Config.DomainStrategy + 1, // 15: xray.app.router.Config.rule:type_name -> xray.app.router.RoutingRule + 3, // 16: xray.app.router.Config.balancing_rule:type_name -> xray.app.router.BalancingRule + 17, // [17:17] is the sub-list for method output_type + 17, // [17:17] is the sub-list for method input_type + 17, // [17:17] is the sub-list for extension type_name + 17, // [17:17] is the sub-list for extension extendee + 0, // [0:17] is the sub-list for field type_name } func init() { file_app_router_config_proto_init() } @@ -1271,21 +761,17 @@ func file_app_router_config_proto_init() { if File_app_router_config_proto != nil { return } - file_app_router_config_proto_msgTypes[6].OneofWrappers = []any{ + file_app_router_config_proto_msgTypes[0].OneofWrappers = []any{ (*RoutingRule_Tag)(nil), (*RoutingRule_BalancingTag)(nil), } - file_app_router_config_proto_msgTypes[12].OneofWrappers = []any{ - (*Domain_Attribute_BoolValue)(nil), - (*Domain_Attribute_IntValue)(nil), - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_app_router_config_proto_rawDesc), len(file_app_router_config_proto_rawDesc)), - NumEnums: 2, - NumMessages: 15, + NumEnums: 1, + NumMessages: 8, NumExtensions: 0, NumServices: 0, }, diff --git a/app/router/config.proto b/app/router/config.proto index 07fe4c51..60f565c3 100644 --- a/app/router/config.proto +++ b/app/router/config.proto @@ -9,67 +9,7 @@ option java_multiple_files = true; import "common/serial/typed_message.proto"; import "common/net/port.proto"; import "common/net/network.proto"; - -// Domain for routing decision. -message Domain { - // Type of domain value. - enum Type { - // The value is used as is. - Plain = 0; - // The value is used as a regular expression. - Regex = 1; - // The value is a root domain. - Domain = 2; - // The value is a domain. - Full = 3; - } - - // Domain matching type. - Type type = 1; - - // Domain value. - string value = 2; - - message Attribute { - string key = 1; - - oneof typed_value { - bool bool_value = 2; - int64 int_value = 3; - } - } - - // Attributes of this domain. May be used for filtering. - repeated Attribute attribute = 3; -} - -// IP for routing decision, in CIDR form. -message CIDR { - // IP address, should be either 4 or 16 bytes. - bytes ip = 1; - - // Number of leading ones in the network mask. - uint32 prefix = 2; -} - -message GeoIP { - string country_code = 1; - repeated CIDR cidr = 2; - bool reverse_match = 3; -} - -message GeoIPList { - repeated GeoIP entry = 1; -} - -message GeoSite { - string country_code = 1; - repeated Domain domain = 2; -} - -message GeoSiteList { - repeated GeoSite entry = 1; -} +import "common/geodata/geodat.proto"; message RoutingRule { oneof target_tag { @@ -79,26 +19,23 @@ message RoutingRule { // Tag of routing balancer. string balancing_tag = 12; } - string rule_tag = 19; + + string rule_tag = 19; // List of domains for target domain matching. - repeated Domain domain = 2; + repeated xray.common.geodata.DomainRule domain = 2; - // List of GeoIPs for target IP address matching. If this entry exists, the - // cidr above will have no effect. GeoIP fields with the same country code are - // supposed to contain exactly same content. They will be merged during - // runtime. For customized GeoIPs, please leave country code empty. - repeated GeoIP geoip = 10; + // List of IPs for target IP address matching. + repeated xray.common.geodata.IPRule ip = 10; - // List of ports. + // List of ports for target port matching. xray.common.net.PortList port_list = 14; // List of networks for matching. repeated xray.common.net.Network networks = 13; - // List of GeoIPs for source IP address matching. If this entry exists, the - // source_cidr above will have no effect. - repeated GeoIP source_geoip = 11; + // List of IPs for source IP address matching. + repeated xray.common.geodata.IPRule source_ip = 11; // List of ports for source port matching. xray.common.net.PortList source_port_list = 16; @@ -109,10 +46,14 @@ message RoutingRule { map attributes = 15; - repeated GeoIP local_geoip = 17; + // List of IPs for local IP address matching. + repeated xray.common.geodata.IPRule local_ip = 17; + + // List of ports for local port matching. xray.common.net.PortList local_port_list = 18; xray.common.net.PortList vless_route_list = 20; + repeated string process = 21; WebhookConfig webhook = 22; } @@ -155,8 +96,7 @@ message Config { // Use domain as is. AsIs = 0; - // [Deprecated] Always resolve IP for domains. - // UseIp = 1; + reserved 1; // Resolve to IP if the domain doesn't match any rules. IpIfNonMatch = 2; diff --git a/app/router/geosite_compact.go b/app/router/geosite_compact.go deleted file mode 100644 index 50fee83f..00000000 --- a/app/router/geosite_compact.go +++ /dev/null @@ -1,100 +0,0 @@ -package router - -import ( - "encoding/gob" - "errors" - "io" - "runtime" - - "github.com/xtls/xray-core/common/strmatcher" -) - -type geoSiteListGob struct { - Sites map[string][]byte - Deps map[string][]string - Hosts map[string][]string -} - -func SerializeGeoSiteList(sites []*GeoSite, deps map[string][]string, hosts map[string][]string, w io.Writer) error { - data := geoSiteListGob{ - Sites: make(map[string][]byte), - Deps: deps, - Hosts: hosts, - } - - for _, site := range sites { - if site == nil { - continue - } - var buf bytesWriter - if err := SerializeDomainMatcher(site.Domain, &buf); err != nil { - return err - } - data.Sites[site.CountryCode] = buf.Bytes() - } - - return gob.NewEncoder(w).Encode(data) -} - -type bytesWriter struct { - data []byte -} - -func (w *bytesWriter) Write(p []byte) (n int, err error) { - w.data = append(w.data, p...) - return len(p), nil -} - -func (w *bytesWriter) Bytes() []byte { - return w.data -} - -func LoadGeoSiteMatcher(r io.Reader, countryCode string) (strmatcher.IndexMatcher, error) { - var data geoSiteListGob - if err := gob.NewDecoder(r).Decode(&data); err != nil { - return nil, err - } - - return loadWithDeps(&data, countryCode, make(map[string]bool)) -} - -func loadWithDeps(data *geoSiteListGob, code string, visited map[string]bool) (strmatcher.IndexMatcher, error) { - if visited[code] { - return nil, errors.New("cyclic dependency") - } - visited[code] = true - - var matchers []strmatcher.IndexMatcher - - if siteData, ok := data.Sites[code]; ok { - m, err := NewDomainMatcherFromBuffer(siteData) - if err == nil { - matchers = append(matchers, m) - } - } - - if deps, ok := data.Deps[code]; ok { - for _, dep := range deps { - m, err := loadWithDeps(data, dep, visited) - if err == nil { - matchers = append(matchers, m) - } - } - } - - if len(matchers) == 0 { - return nil, errors.New("matcher not found for: " + code) - } - if len(matchers) == 1 { - return matchers[0], nil - } - runtime.GC() - return &strmatcher.IndexMatcherGroup{Matchers: matchers}, nil -} -func LoadGeoSiteHosts(r io.Reader) (map[string][]string, error) { - var data geoSiteListGob - if err := gob.NewDecoder(r).Decode(&data); err != nil { - return nil, err - } - return data.Hosts, nil -} diff --git a/app/router/router_test.go b/app/router/router_test.go index a0516e05..f038937f 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -7,6 +7,7 @@ import ( "github.com/golang/mock/gomock" . "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/features/dns" @@ -155,10 +156,10 @@ func TestIPOnDemand(t *testing.T) { TargetTag: &RoutingRule_Tag{ Tag: "test", }, - Geoip: []*GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{192, 168, 0, 0}, Prefix: 16, }, @@ -200,10 +201,10 @@ func TestIPIfNonMatchDomain(t *testing.T) { TargetTag: &RoutingRule_Tag{ Tag: "test", }, - Geoip: []*GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{192, 168, 0, 0}, Prefix: 16, }, @@ -245,10 +246,10 @@ func TestIPIfNonMatchIP(t *testing.T) { TargetTag: &RoutingRule_Tag{ Tag: "test", }, - Geoip: []*GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{127, 0, 0, 0}, Prefix: 8, }, diff --git a/app/stats/command/command.go b/app/stats/command/command.go index 535cbac5..aa7d6600 100644 --- a/app/stats/command/command.go +++ b/app/stats/command/command.go @@ -7,7 +7,6 @@ import ( "time" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/core" feature_stats "github.com/xtls/xray-core/features/stats" grpc "google.golang.org/grpc" @@ -163,15 +162,10 @@ func (s *statsServer) GetUsersStats(ctx context.Context, request *GetUsersStatsR } func (s *statsServer) QueryStats(ctx context.Context, request *QueryStatsRequest) (*QueryStatsResponse, error) { - matcher, err := strmatcher.Substr.New(request.Pattern) - if err != nil { - return nil, err - } - response := &QueryStatsResponse{} s.stats.VisitCounters(func(name string, c feature_stats.Counter) bool { - if matcher.Match(name) { + if strings.Contains(name, request.Pattern) { var value int64 if request.Reset_ { value = c.Set(0) diff --git a/common/geodata/domain_matcher.go b/common/geodata/domain_matcher.go new file mode 100644 index 00000000..121f50bd --- /dev/null +++ b/common/geodata/domain_matcher.go @@ -0,0 +1,66 @@ +package geodata + +import ( + "context" + "strings" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +type DomainMatcher interface { + Match(input string) []uint32 + MatchAny(input string) bool +} + +func buildDomainMatcher(rules []*DomainRule) (DomainMatcher, error) { + g := strmatcher.NewMphValueMatcher() + for i, r := range rules { + switch v := r.Value.(type) { + case *DomainRule_Custom: + m, err := parseDomain(v.Custom) + if err != nil { + return nil, err + } + g.Add(m, uint32(i)) + case *DomainRule_Geosite: + domains, err := loadSiteWithAttrs(v.Geosite.File, v.Geosite.Code, v.Geosite.Attrs) + if err != nil { + return nil, err + } + for j, d := range domains { + domains[j] = nil // peak mem + m, err := parseDomain(d) + if err != nil { + errors.LogError(context.Background(), "ignore invalid geosite entry in ", v.Geosite.File, ":", v.Geosite.Code, " at index ", j, ", ", err) + continue + } + g.Add(m, uint32(i)) + } + default: + panic("unknown domain rule type") + } + } + if err := g.Build(); err != nil { + return nil, err + } + return g, nil +} + +func parseDomain(d *Domain) (strmatcher.Matcher, error) { + if d == nil { + return nil, errors.New("domain must not be nil") + } + switch d.Type { + case Domain_Substr: + return strmatcher.Substr.New(strings.ToLower(d.Value)) + case Domain_Regex: + return strmatcher.Regex.New(d.Value) + case Domain_Domain: + return strmatcher.Domain.New(d.Value) + case Domain_Full: + return strmatcher.Full.New(strings.ToLower(d.Value)) + default: + return nil, errors.New("unknown domain type: ", d.Type) + } +} diff --git a/common/geodata/domain_registry.go b/common/geodata/domain_registry.go new file mode 100644 index 00000000..774bf82e --- /dev/null +++ b/common/geodata/domain_registry.go @@ -0,0 +1,13 @@ +package geodata + +type DomainRegistry struct{} + +func (r *DomainRegistry) BuildDomainMatcher(rules []*DomainRule) (DomainMatcher, error) { + return buildDomainMatcher(rules) +} + +func newDomainRegistry() *DomainRegistry { + return &DomainRegistry{} +} + +var DomainReg = newDomainRegistry() diff --git a/common/geodata/geodat.pb.go b/common/geodata/geodat.pb.go new file mode 100644 index 00000000..b5cc2797 --- /dev/null +++ b/common/geodata/geodat.pb.go @@ -0,0 +1,908 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: common/geodata/geodat.proto + +package geodata + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Type of domain value. +type Domain_Type int32 + +const ( + // The value is used as a sub string. + Domain_Substr Domain_Type = 0 + // The value is used as a regular expression. + Domain_Regex Domain_Type = 1 + // The value is a domain. + Domain_Domain Domain_Type = 2 + // The value is a full domain. + Domain_Full Domain_Type = 3 +) + +// Enum value maps for Domain_Type. +var ( + Domain_Type_name = map[int32]string{ + 0: "Substr", + 1: "Regex", + 2: "Domain", + 3: "Full", + } + Domain_Type_value = map[string]int32{ + "Substr": 0, + "Regex": 1, + "Domain": 2, + "Full": 3, + } +) + +func (x Domain_Type) Enum() *Domain_Type { + p := new(Domain_Type) + *p = x + return p +} + +func (x Domain_Type) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Domain_Type) Descriptor() protoreflect.EnumDescriptor { + return file_common_geodata_geodat_proto_enumTypes[0].Descriptor() +} + +func (Domain_Type) Type() protoreflect.EnumType { + return &file_common_geodata_geodat_proto_enumTypes[0] +} + +func (x Domain_Type) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Domain_Type.Descriptor instead. +func (Domain_Type) EnumDescriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0, 0} +} + +type Domain struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Domain matching type. + Type Domain_Type `protobuf:"varint,1,opt,name=type,proto3,enum=xray.common.geodata.Domain_Type" json:"type,omitempty"` + // Domain value. + Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + // Attributes of this domain. May be used for filtering. + Attribute []*Domain_Attribute `protobuf:"bytes,3,rep,name=attribute,proto3" json:"attribute,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Domain) Reset() { + *x = Domain{} + mi := &file_common_geodata_geodat_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Domain) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Domain) ProtoMessage() {} + +func (x *Domain) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Domain.ProtoReflect.Descriptor instead. +func (*Domain) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0} +} + +func (x *Domain) GetType() Domain_Type { + if x != nil { + return x.Type + } + return Domain_Substr +} + +func (x *Domain) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +func (x *Domain) GetAttribute() []*Domain_Attribute { + if x != nil { + return x.Attribute + } + return nil +} + +type GeoSite struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` + Domain []*Domain `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoSite) Reset() { + *x = GeoSite{} + mi := &file_common_geodata_geodat_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoSite) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoSite) ProtoMessage() {} + +func (x *GeoSite) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoSite.ProtoReflect.Descriptor instead. +func (*GeoSite) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{1} +} + +func (x *GeoSite) GetCode() string { + if x != nil { + return x.Code + } + return "" +} + +func (x *GeoSite) GetDomain() []*Domain { + if x != nil { + return x.Domain + } + return nil +} + +type GeoSiteList struct { + state protoimpl.MessageState `protogen:"open.v1"` + Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoSiteList) Reset() { + *x = GeoSiteList{} + mi := &file_common_geodata_geodat_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoSiteList) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoSiteList) ProtoMessage() {} + +func (x *GeoSiteList) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoSiteList.ProtoReflect.Descriptor instead. +func (*GeoSiteList) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{2} +} + +func (x *GeoSiteList) GetEntry() []*GeoSite { + if x != nil { + return x.Entry + } + return nil +} + +type GeoSiteRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + File string `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"` + Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` + Attrs string `protobuf:"bytes,3,opt,name=attrs,proto3" json:"attrs,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoSiteRule) Reset() { + *x = GeoSiteRule{} + mi := &file_common_geodata_geodat_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoSiteRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoSiteRule) ProtoMessage() {} + +func (x *GeoSiteRule) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoSiteRule.ProtoReflect.Descriptor instead. +func (*GeoSiteRule) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{3} +} + +func (x *GeoSiteRule) GetFile() string { + if x != nil { + return x.File + } + return "" +} + +func (x *GeoSiteRule) GetCode() string { + if x != nil { + return x.Code + } + return "" +} + +func (x *GeoSiteRule) GetAttrs() string { + if x != nil { + return x.Attrs + } + return "" +} + +type DomainRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Value: + // + // *DomainRule_Geosite + // *DomainRule_Custom + Value isDomainRule_Value `protobuf_oneof:"value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DomainRule) Reset() { + *x = DomainRule{} + mi := &file_common_geodata_geodat_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DomainRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DomainRule) ProtoMessage() {} + +func (x *DomainRule) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DomainRule.ProtoReflect.Descriptor instead. +func (*DomainRule) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{4} +} + +func (x *DomainRule) GetValue() isDomainRule_Value { + if x != nil { + return x.Value + } + return nil +} + +func (x *DomainRule) GetGeosite() *GeoSiteRule { + if x != nil { + if x, ok := x.Value.(*DomainRule_Geosite); ok { + return x.Geosite + } + } + return nil +} + +func (x *DomainRule) GetCustom() *Domain { + if x != nil { + if x, ok := x.Value.(*DomainRule_Custom); ok { + return x.Custom + } + } + return nil +} + +type isDomainRule_Value interface { + isDomainRule_Value() +} + +type DomainRule_Geosite struct { + Geosite *GeoSiteRule `protobuf:"bytes,1,opt,name=geosite,proto3,oneof"` +} + +type DomainRule_Custom struct { + Custom *Domain `protobuf:"bytes,2,opt,name=custom,proto3,oneof"` +} + +func (*DomainRule_Geosite) isDomainRule_Value() {} + +func (*DomainRule_Custom) isDomainRule_Value() {} + +type CIDR struct { + state protoimpl.MessageState `protogen:"open.v1"` + // IP address, should be either 4 or 16 bytes. + Ip []byte `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` + // Number of leading ones in the network mask. + Prefix uint32 `protobuf:"varint,2,opt,name=prefix,proto3" json:"prefix,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CIDR) Reset() { + *x = CIDR{} + mi := &file_common_geodata_geodat_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CIDR) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CIDR) ProtoMessage() {} + +func (x *CIDR) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CIDR.ProtoReflect.Descriptor instead. +func (*CIDR) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{5} +} + +func (x *CIDR) GetIp() []byte { + if x != nil { + return x.Ip + } + return nil +} + +func (x *CIDR) GetPrefix() uint32 { + if x != nil { + return x.Prefix + } + return 0 +} + +type GeoIP struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` + Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr,proto3" json:"cidr,omitempty"` + ReverseMatch bool `protobuf:"varint,3,opt,name=reverse_match,json=reverseMatch,proto3" json:"reverse_match,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoIP) Reset() { + *x = GeoIP{} + mi := &file_common_geodata_geodat_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoIP) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoIP) ProtoMessage() {} + +func (x *GeoIP) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoIP.ProtoReflect.Descriptor instead. +func (*GeoIP) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{6} +} + +func (x *GeoIP) GetCode() string { + if x != nil { + return x.Code + } + return "" +} + +func (x *GeoIP) GetCidr() []*CIDR { + if x != nil { + return x.Cidr + } + return nil +} + +func (x *GeoIP) GetReverseMatch() bool { + if x != nil { + return x.ReverseMatch + } + return false +} + +type GeoIPList struct { + state protoimpl.MessageState `protogen:"open.v1"` + Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoIPList) Reset() { + *x = GeoIPList{} + mi := &file_common_geodata_geodat_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoIPList) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoIPList) ProtoMessage() {} + +func (x *GeoIPList) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoIPList.ProtoReflect.Descriptor instead. +func (*GeoIPList) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{7} +} + +func (x *GeoIPList) GetEntry() []*GeoIP { + if x != nil { + return x.Entry + } + return nil +} + +type GeoIPRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + File string `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"` + Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` + ReverseMatch bool `protobuf:"varint,3,opt,name=reverse_match,json=reverseMatch,proto3" json:"reverse_match,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeoIPRule) Reset() { + *x = GeoIPRule{} + mi := &file_common_geodata_geodat_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeoIPRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeoIPRule) ProtoMessage() {} + +func (x *GeoIPRule) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeoIPRule.ProtoReflect.Descriptor instead. +func (*GeoIPRule) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{8} +} + +func (x *GeoIPRule) GetFile() string { + if x != nil { + return x.File + } + return "" +} + +func (x *GeoIPRule) GetCode() string { + if x != nil { + return x.Code + } + return "" +} + +func (x *GeoIPRule) GetReverseMatch() bool { + if x != nil { + return x.ReverseMatch + } + return false +} + +type IPRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Value: + // + // *IPRule_Geoip + // *IPRule_Custom + Value isIPRule_Value `protobuf_oneof:"value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IPRule) Reset() { + *x = IPRule{} + mi := &file_common_geodata_geodat_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IPRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IPRule) ProtoMessage() {} + +func (x *IPRule) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IPRule.ProtoReflect.Descriptor instead. +func (*IPRule) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{9} +} + +func (x *IPRule) GetValue() isIPRule_Value { + if x != nil { + return x.Value + } + return nil +} + +func (x *IPRule) GetGeoip() *GeoIPRule { + if x != nil { + if x, ok := x.Value.(*IPRule_Geoip); ok { + return x.Geoip + } + } + return nil +} + +func (x *IPRule) GetCustom() *CIDR { + if x != nil { + if x, ok := x.Value.(*IPRule_Custom); ok { + return x.Custom + } + } + return nil +} + +type isIPRule_Value interface { + isIPRule_Value() +} + +type IPRule_Geoip struct { + Geoip *GeoIPRule `protobuf:"bytes,1,opt,name=geoip,proto3,oneof"` +} + +type IPRule_Custom struct { + Custom *CIDR `protobuf:"bytes,2,opt,name=custom,proto3,oneof"` +} + +func (*IPRule_Geoip) isIPRule_Value() {} + +func (*IPRule_Custom) isIPRule_Value() {} + +type Domain_Attribute struct { + state protoimpl.MessageState `protogen:"open.v1"` + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + // Types that are valid to be assigned to TypedValue: + // + // *Domain_Attribute_BoolValue + // *Domain_Attribute_IntValue + TypedValue isDomain_Attribute_TypedValue `protobuf_oneof:"typed_value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Domain_Attribute) Reset() { + *x = Domain_Attribute{} + mi := &file_common_geodata_geodat_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Domain_Attribute) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Domain_Attribute) ProtoMessage() {} + +func (x *Domain_Attribute) ProtoReflect() protoreflect.Message { + mi := &file_common_geodata_geodat_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Domain_Attribute.ProtoReflect.Descriptor instead. +func (*Domain_Attribute) Descriptor() ([]byte, []int) { + return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0, 0} +} + +func (x *Domain_Attribute) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *Domain_Attribute) GetTypedValue() isDomain_Attribute_TypedValue { + if x != nil { + return x.TypedValue + } + return nil +} + +func (x *Domain_Attribute) GetBoolValue() bool { + if x != nil { + if x, ok := x.TypedValue.(*Domain_Attribute_BoolValue); ok { + return x.BoolValue + } + } + return false +} + +func (x *Domain_Attribute) GetIntValue() int64 { + if x != nil { + if x, ok := x.TypedValue.(*Domain_Attribute_IntValue); ok { + return x.IntValue + } + } + return 0 +} + +type isDomain_Attribute_TypedValue interface { + isDomain_Attribute_TypedValue() +} + +type Domain_Attribute_BoolValue struct { + BoolValue bool `protobuf:"varint,2,opt,name=bool_value,json=boolValue,proto3,oneof"` +} + +type Domain_Attribute_IntValue struct { + IntValue int64 `protobuf:"varint,3,opt,name=int_value,json=intValue,proto3,oneof"` +} + +func (*Domain_Attribute_BoolValue) isDomain_Attribute_TypedValue() {} + +func (*Domain_Attribute_IntValue) isDomain_Attribute_TypedValue() {} + +var File_common_geodata_geodat_proto protoreflect.FileDescriptor + +const file_common_geodata_geodat_proto_rawDesc = "" + + "\n" + + "\x1bcommon/geodata/geodat.proto\x12\x13xray.common.geodata\"\xbc\x02\n" + + "\x06Domain\x124\n" + + "\x04type\x18\x01 \x01(\x0e2 .xray.common.geodata.Domain.TypeR\x04type\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value\x12C\n" + + "\tattribute\x18\x03 \x03(\v2%.xray.common.geodata.Domain.AttributeR\tattribute\x1al\n" + + "\tAttribute\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x1f\n" + + "\n" + + "bool_value\x18\x02 \x01(\bH\x00R\tboolValue\x12\x1d\n" + + "\tint_value\x18\x03 \x01(\x03H\x00R\bintValueB\r\n" + + "\vtyped_value\"3\n" + + "\x04Type\x12\n" + + "\n" + + "\x06Substr\x10\x00\x12\t\n" + + "\x05Regex\x10\x01\x12\n" + + "\n" + + "\x06Domain\x10\x02\x12\b\n" + + "\x04Full\x10\x03\"R\n" + + "\aGeoSite\x12\x12\n" + + "\x04code\x18\x01 \x01(\tR\x04code\x123\n" + + "\x06domain\x18\x02 \x03(\v2\x1b.xray.common.geodata.DomainR\x06domain\"A\n" + + "\vGeoSiteList\x122\n" + + "\x05entry\x18\x01 \x03(\v2\x1c.xray.common.geodata.GeoSiteR\x05entry\"K\n" + + "\vGeoSiteRule\x12\x12\n" + + "\x04file\x18\x01 \x01(\tR\x04file\x12\x12\n" + + "\x04code\x18\x02 \x01(\tR\x04code\x12\x14\n" + + "\x05attrs\x18\x03 \x01(\tR\x05attrs\"\x8a\x01\n" + + "\n" + + "DomainRule\x12<\n" + + "\ageosite\x18\x01 \x01(\v2 .xray.common.geodata.GeoSiteRuleH\x00R\ageosite\x125\n" + + "\x06custom\x18\x02 \x01(\v2\x1b.xray.common.geodata.DomainH\x00R\x06customB\a\n" + + "\x05value\".\n" + + "\x04CIDR\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\fR\x02ip\x12\x16\n" + + "\x06prefix\x18\x02 \x01(\rR\x06prefix\"o\n" + + "\x05GeoIP\x12\x12\n" + + "\x04code\x18\x01 \x01(\tR\x04code\x12-\n" + + "\x04cidr\x18\x02 \x03(\v2\x19.xray.common.geodata.CIDRR\x04cidr\x12#\n" + + "\rreverse_match\x18\x03 \x01(\bR\freverseMatch\"=\n" + + "\tGeoIPList\x120\n" + + "\x05entry\x18\x01 \x03(\v2\x1a.xray.common.geodata.GeoIPR\x05entry\"X\n" + + "\tGeoIPRule\x12\x12\n" + + "\x04file\x18\x01 \x01(\tR\x04file\x12\x12\n" + + "\x04code\x18\x02 \x01(\tR\x04code\x12#\n" + + "\rreverse_match\x18\x03 \x01(\bR\freverseMatch\"~\n" + + "\x06IPRule\x126\n" + + "\x05geoip\x18\x01 \x01(\v2\x1e.xray.common.geodata.GeoIPRuleH\x00R\x05geoip\x123\n" + + "\x06custom\x18\x02 \x01(\v2\x19.xray.common.geodata.CIDRH\x00R\x06customB\a\n" + + "\x05valueB[\n" + + "\x17com.xray.common.geodataP\x01Z(github.com/xtls/xray-core/common/geodata\xaa\x02\x13Xray.Common.Geodatab\x06proto3" + +var ( + file_common_geodata_geodat_proto_rawDescOnce sync.Once + file_common_geodata_geodat_proto_rawDescData []byte +) + +func file_common_geodata_geodat_proto_rawDescGZIP() []byte { + file_common_geodata_geodat_proto_rawDescOnce.Do(func() { + file_common_geodata_geodat_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_common_geodata_geodat_proto_rawDesc), len(file_common_geodata_geodat_proto_rawDesc))) + }) + return file_common_geodata_geodat_proto_rawDescData +} + +var file_common_geodata_geodat_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_common_geodata_geodat_proto_msgTypes = make([]protoimpl.MessageInfo, 11) +var file_common_geodata_geodat_proto_goTypes = []any{ + (Domain_Type)(0), // 0: xray.common.geodata.Domain.Type + (*Domain)(nil), // 1: xray.common.geodata.Domain + (*GeoSite)(nil), // 2: xray.common.geodata.GeoSite + (*GeoSiteList)(nil), // 3: xray.common.geodata.GeoSiteList + (*GeoSiteRule)(nil), // 4: xray.common.geodata.GeoSiteRule + (*DomainRule)(nil), // 5: xray.common.geodata.DomainRule + (*CIDR)(nil), // 6: xray.common.geodata.CIDR + (*GeoIP)(nil), // 7: xray.common.geodata.GeoIP + (*GeoIPList)(nil), // 8: xray.common.geodata.GeoIPList + (*GeoIPRule)(nil), // 9: xray.common.geodata.GeoIPRule + (*IPRule)(nil), // 10: xray.common.geodata.IPRule + (*Domain_Attribute)(nil), // 11: xray.common.geodata.Domain.Attribute +} +var file_common_geodata_geodat_proto_depIdxs = []int32{ + 0, // 0: xray.common.geodata.Domain.type:type_name -> xray.common.geodata.Domain.Type + 11, // 1: xray.common.geodata.Domain.attribute:type_name -> xray.common.geodata.Domain.Attribute + 1, // 2: xray.common.geodata.GeoSite.domain:type_name -> xray.common.geodata.Domain + 2, // 3: xray.common.geodata.GeoSiteList.entry:type_name -> xray.common.geodata.GeoSite + 4, // 4: xray.common.geodata.DomainRule.geosite:type_name -> xray.common.geodata.GeoSiteRule + 1, // 5: xray.common.geodata.DomainRule.custom:type_name -> xray.common.geodata.Domain + 6, // 6: xray.common.geodata.GeoIP.cidr:type_name -> xray.common.geodata.CIDR + 7, // 7: xray.common.geodata.GeoIPList.entry:type_name -> xray.common.geodata.GeoIP + 9, // 8: xray.common.geodata.IPRule.geoip:type_name -> xray.common.geodata.GeoIPRule + 6, // 9: xray.common.geodata.IPRule.custom:type_name -> xray.common.geodata.CIDR + 10, // [10:10] is the sub-list for method output_type + 10, // [10:10] is the sub-list for method input_type + 10, // [10:10] is the sub-list for extension type_name + 10, // [10:10] is the sub-list for extension extendee + 0, // [0:10] is the sub-list for field type_name +} + +func init() { file_common_geodata_geodat_proto_init() } +func file_common_geodata_geodat_proto_init() { + if File_common_geodata_geodat_proto != nil { + return + } + file_common_geodata_geodat_proto_msgTypes[4].OneofWrappers = []any{ + (*DomainRule_Geosite)(nil), + (*DomainRule_Custom)(nil), + } + file_common_geodata_geodat_proto_msgTypes[9].OneofWrappers = []any{ + (*IPRule_Geoip)(nil), + (*IPRule_Custom)(nil), + } + file_common_geodata_geodat_proto_msgTypes[10].OneofWrappers = []any{ + (*Domain_Attribute_BoolValue)(nil), + (*Domain_Attribute_IntValue)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_common_geodata_geodat_proto_rawDesc), len(file_common_geodata_geodat_proto_rawDesc)), + NumEnums: 1, + NumMessages: 11, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_common_geodata_geodat_proto_goTypes, + DependencyIndexes: file_common_geodata_geodat_proto_depIdxs, + EnumInfos: file_common_geodata_geodat_proto_enumTypes, + MessageInfos: file_common_geodata_geodat_proto_msgTypes, + }.Build() + File_common_geodata_geodat_proto = out.File + file_common_geodata_geodat_proto_goTypes = nil + file_common_geodata_geodat_proto_depIdxs = nil +} diff --git a/common/geodata/geodat.proto b/common/geodata/geodat.proto new file mode 100644 index 00000000..1828b917 --- /dev/null +++ b/common/geodata/geodat.proto @@ -0,0 +1,90 @@ +syntax = "proto3"; + +package xray.common.geodata; +option csharp_namespace = "Xray.Common.Geodata"; +option go_package = "github.com/xtls/xray-core/common/geodata"; +option java_package = "com.xray.common.geodata"; +option java_multiple_files = true; + +message Domain { + // Type of domain value. + enum Type { + // The value is used as a sub string. + Substr = 0; + // The value is used as a regular expression. + Regex = 1; + // The value is a domain. + Domain = 2; + // The value is a full domain. + Full = 3; + } + // Domain matching type. + Type type = 1; + + // Domain value. + string value = 2; + + message Attribute { + string key = 1; + + oneof typed_value { + bool bool_value = 2; + int64 int_value = 3; + } + } + // Attributes of this domain. May be used for filtering. + repeated Attribute attribute = 3; +} + +message GeoSite { + string code = 1; + repeated Domain domain = 2; +} + +message GeoSiteList { + repeated GeoSite entry = 1; +} + +message GeoSiteRule { + string file = 1; + string code = 2; + string attrs = 3; +} + +message DomainRule { + oneof value { + GeoSiteRule geosite = 1; + Domain custom = 2; + } +} + +message CIDR { + // IP address, should be either 4 or 16 bytes. + bytes ip = 1; + + // Number of leading ones in the network mask. + uint32 prefix = 2; +} + +message GeoIP { + string code = 1; + repeated CIDR cidr = 2; + bool reverse_match = 3; +} + +message GeoIPList { + repeated GeoIP entry = 1; +} + +message GeoIPRule { + string file = 1; + string code = 2; + bool reverse_match = 3; +} + +message IPRule { + oneof value { + GeoIPRule geoip = 1; + CIDR custom = 2; + } +} diff --git a/common/geodata/geodat_loader.go b/common/geodata/geodat_loader.go new file mode 100644 index 00000000..0e12aa28 --- /dev/null +++ b/common/geodata/geodat_loader.go @@ -0,0 +1,207 @@ +package geodata + +import ( + "bufio" + "bytes" + "io" + "runtime" + "strings" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/platform/filesystem" + + "google.golang.org/protobuf/proto" +) + +func checkFile(file, code string) error { + r, err := filesystem.OpenAsset(file) + if err != nil { + return errors.New("failed to open ", file).Base(err) + } + defer r.Close() + if _, err := find(r, []byte(code), false); err != nil { + return errors.New("failed to check code ", code, " from ", file).Base(err) + } + return nil +} + +func loadFile(file, code string) ([]byte, error) { + runtime.GC() // peak mem + r, err := filesystem.OpenAsset(file) + if err != nil { + return nil, errors.New("failed to open ", file).Base(err) + } + defer r.Close() + bs, err := find(r, []byte(code), true) + if err != nil { + return nil, errors.New("failed to load code ", code, " from ", file).Base(err) + } + return bs, nil +} + +func loadIP(file, code string) ([]*CIDR, error) { + bs, err := loadFile(file, code) + if err != nil { + return nil, err + } + defer runtime.GC() // peak mem + var geoip GeoIP + if err := proto.Unmarshal(bs, &geoip); err != nil { + return nil, errors.New("error unmarshal IP in ", file, ":", code).Base(err) + } + return geoip.Cidr, nil +} + +func loadSite(file, code string) ([]*Domain, error) { + bs, err := loadFile(file, code) + if err != nil { + return nil, err + } + defer runtime.GC() // peak mem + var geosite GeoSite + if err := proto.Unmarshal(bs, &geosite); err != nil { + return nil, errors.New("error unmarshal Site in ", file, ":", code).Base(err) + } + return geosite.Domain, nil +} + +func decodeVarint(br *bufio.Reader) (uint64, error) { + var x uint64 + for shift := uint(0); shift < 64; shift += 7 { + b, err := br.ReadByte() + if err != nil { + return 0, err + } + x |= (uint64(b) & 0x7F) << shift + if (b & 0x80) == 0 { + return x, nil + } + } + // The number is too large to represent in a 64-bit value. + return 0, errors.New("varint overflow") +} + +func find(r io.Reader, code []byte, readBody bool) ([]byte, error) { + codeL := len(code) + if codeL == 0 { + return nil, errors.New("empty code") + } + + br := bufio.NewReaderSize(r, 64*1024) + need := 2 + codeL // TODO: if code too long + prefixBuf := make([]byte, need) + + for { + if _, err := br.ReadByte(); err != nil { + return nil, err + } + + x, err := decodeVarint(br) + if err != nil { + return nil, err + } + bodyL := int(x) + if bodyL <= 0 { + return nil, errors.New("invalid body length: ", bodyL) + } + + prefixL := bodyL + if prefixL > need { + prefixL = need + } + prefix := prefixBuf[:prefixL] + if _, err := io.ReadFull(br, prefix); err != nil { + return nil, err + } + + match := false + if bodyL >= need { + if int(prefix[1]) == codeL && bytes.Equal(prefix[2:need], code) { + if !readBody { + return nil, nil + } + match = true + } + } + + remain := bodyL - prefixL + if match { + out := make([]byte, bodyL) + copy(out, prefix) + if remain > 0 { + if _, err := io.ReadFull(br, out[prefixL:]); err != nil { + return nil, err + } + } + return out, nil + } + + if remain > 0 { + if _, err := br.Discard(remain); err != nil { + return nil, err + } + } + } +} + +type AttributeMatcher interface { + Match(*Domain) bool +} + +type HasAttrMatcher string + +// Match reports whether this matcher matches any attribute on the domain. +func (m HasAttrMatcher) Match(domain *Domain) bool { + for _, attr := range domain.Attribute { + if attr.Key == string(m) { + return true + } + } + return false +} + +type AllAttrsMatcher struct { + matchers []AttributeMatcher +} + +// Match reports whether the domain matches every matcher in the list. +func (m *AllAttrsMatcher) Match(domain *Domain) bool { + for _, matcher := range m.matchers { + if !matcher.Match(domain) { + return false + } + } + return true +} + +func NewAllAttrsMatcher(attrs string) AttributeMatcher { + if attrs == "" { + return nil + } + m := new(AllAttrsMatcher) + for _, attr := range strings.Split(attrs, "@") { + m.matchers = append(m.matchers, HasAttrMatcher(attr)) + } + return m +} + +func loadSiteWithAttrs(file, code, attrs string) ([]*Domain, error) { + domains, err := loadSite(file, code) + if err != nil { + return nil, err + } + + matcher := NewAllAttrsMatcher(attrs) + if matcher == nil { + return domains, nil + } + + filtered := make([]*Domain, 0, len(domains)) + for _, d := range domains { + if matcher.Match(d) { + filtered = append(filtered, d) + } + } + + return filtered, nil +} diff --git a/app/router/condition_geoip.go b/common/geodata/ip_matcher.go similarity index 71% rename from app/router/condition_geoip.go rename to common/geodata/ip_matcher.go index cdfcb9fe..565b38ea 100644 --- a/app/router/condition_geoip.go +++ b/common/geodata/ip_matcher.go @@ -1,8 +1,10 @@ -package router +package geodata import ( "context" "net/netip" + "runtime" + "slices" "sort" "strings" "sync" @@ -13,7 +15,7 @@ import ( "go4.org/netipx" ) -type GeoIPMatcher interface { +type IPMatcher interface { // TODO: (PERF) all net.IP -> netipx.Addr // Invalid IP always return false. @@ -33,13 +35,13 @@ type GeoIPMatcher interface { SetReverse(reverse bool) } -type GeoIPSet struct { +type IPSet struct { ipv4, ipv6 *netipx.IPSet max4, max6 uint8 } -type HeuristicGeoIPMatcher struct { - ipset *GeoIPSet +type HeuristicIPMatcher struct { + ipset *IPSet reverse bool } @@ -48,8 +50,8 @@ type ipBucket struct { ips []net.IP } -// Match implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool { +// Match implements IPMatcher. +func (m *HeuristicIPMatcher) Match(ip net.IP) bool { ipx, ok := netipx.FromStdIP(ip) if !ok { return false @@ -57,18 +59,24 @@ func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool { return m.matchAddr(ipx) } -func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool { +func (m *HeuristicIPMatcher) matchAddr(ipx netip.Addr) bool { if ipx.Is4() { + if m.ipset.max4 == 0xff { + return false + } return m.ipset.ipv4.Contains(ipx) != m.reverse } if ipx.Is6() { + if m.ipset.max6 == 0xff { + return false + } return m.ipset.ipv6.Contains(ipx) != m.reverse } return false } -// AnyMatch implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool { +// AnyMatch implements IPMatcher. +func (m *HeuristicIPMatcher) AnyMatch(ips []net.IP) bool { n := len(ips) if n == 0 { return false @@ -117,8 +125,8 @@ func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool { return false } -// Matches implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool { +// Matches implements IPMatcher. +func (m *HeuristicIPMatcher) Matches(ips []net.IP) bool { n := len(ips) if n == 0 { return false @@ -205,8 +213,8 @@ func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) { return key, false // illegal } -// FilterIPs implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { +// FilterIPs implements IPMatcher. +func (m *HeuristicIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { n := len(ips) if n == 0 { return []net.IP{}, []net.IP{} @@ -295,22 +303,22 @@ func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmat return } -// ToggleReverse implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) ToggleReverse() { +// ToggleReverse implements IPMatcher. +func (m *HeuristicIPMatcher) ToggleReverse() { m.reverse = !m.reverse } -// SetReverse implements GeoIPMatcher. -func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) { +// SetReverse implements IPMatcher. +func (m *HeuristicIPMatcher) SetReverse(reverse bool) { m.reverse = reverse } -type GeneralMultiGeoIPMatcher struct { - matchers []GeoIPMatcher +type GeneralMultiIPMatcher struct { + matchers []IPMatcher } -// Match implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool { +// Match implements IPMatcher. +func (mm *GeneralMultiIPMatcher) Match(ip net.IP) bool { for _, m := range mm.matchers { if m.Match(ip) { return true @@ -319,8 +327,8 @@ func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool { return false } -// AnyMatch implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { +// AnyMatch implements IPMatcher. +func (mm *GeneralMultiIPMatcher) AnyMatch(ips []net.IP) bool { for _, m := range mm.matchers { if m.AnyMatch(ips) { return true @@ -329,8 +337,8 @@ func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { return false } -// Matches implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool { +// Matches implements IPMatcher. +func (mm *GeneralMultiIPMatcher) Matches(ips []net.IP) bool { for _, m := range mm.matchers { if m.Matches(ips) { return true @@ -339,8 +347,8 @@ func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool { return false } -// FilterIPs implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { +// FilterIPs implements IPMatcher. +func (mm *GeneralMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { matched = make([]net.IP, 0, len(ips)) unmatched = ips for _, m := range mm.matchers { @@ -356,26 +364,26 @@ func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, u return } -// ToggleReverse implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() { +// ToggleReverse implements IPMatcher. +func (mm *GeneralMultiIPMatcher) ToggleReverse() { for _, m := range mm.matchers { m.ToggleReverse() } } -// SetReverse implements GeoIPMatcher. -func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) { +// SetReverse implements IPMatcher. +func (mm *GeneralMultiIPMatcher) SetReverse(reverse bool) { for _, m := range mm.matchers { m.SetReverse(reverse) } } -type HeuristicMultiGeoIPMatcher struct { - matchers []*HeuristicGeoIPMatcher +type HeuristicMultiIPMatcher struct { + matchers []*HeuristicIPMatcher } -// Match implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool { +// Match implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) Match(ip net.IP) bool { ipx, ok := netipx.FromStdIP(ip) if !ok { return false @@ -389,8 +397,8 @@ func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool { return false } -// AnyMatch implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { +// AnyMatch implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) AnyMatch(ips []net.IP) bool { n := len(ips) if n == 0 { return false @@ -439,8 +447,8 @@ func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool { return false } -// Matches implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool { +// Matches implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) Matches(ips []net.IP) bool { n := len(ips) if n == 0 { return false @@ -503,7 +511,7 @@ type ipViews struct { precise4, precise6 []netip.Addr } -func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool { +func (v *ipViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) bool { needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil @@ -581,8 +589,8 @@ func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool return true } -// FilterIPs implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { +// FilterIPs implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) { n := len(ips) if n == 0 { return []net.IP{}, []net.IP{} @@ -694,7 +702,7 @@ type ipBucketViews struct { precise4, precise6 map[netip.Addr]net.IP } -func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) { +func (v *ipBucketViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) { needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil @@ -782,28 +790,28 @@ func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) } } -// ToggleReverse implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() { +// ToggleReverse implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) ToggleReverse() { for _, m := range mm.matchers { m.ToggleReverse() } } -// SetReverse implements GeoIPMatcher. -func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) { +// SetReverse implements IPMatcher. +func (mm *HeuristicMultiIPMatcher) SetReverse(reverse bool) { for _, m := range mm.matchers { m.SetReverse(reverse) } } -type GeoIPSetFactory struct { +type IPSetFactory struct { sync.Mutex - shared map[string]*GeoIPSet // TODO: cleanup + shared map[string]*IPSet // TODO: cleanup } -var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)} +func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, error) { + key := buildGeoIPRulesKey(rules) -func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) { f.Lock() defer f.Unlock() @@ -811,41 +819,92 @@ func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPS return ipset, nil } - ipset, err := f.Create(cidrGroups...) + ipset, err := f.createFrom(func(add func(*CIDR)) error { + for _, r := range rules { + cidrs, err := loadIP(r.File, r.Code) + if err != nil { + return err + } + for i, c := range cidrs { + add(c) + cidrs[i] = nil // peak mem + } + } + return nil + }) if err == nil { f.shared[key] = ipset } return ipset, err } -func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) { - var ipv4Builder, ipv6Builder netipx.IPSetBuilder +func buildGeoIPRulesKey(rules []*GeoIPRule) string { + rules = slices.Clone(rules) - for _, cidrGroup := range cidrGroups { - for i, cidrEntry := range cidrGroup { - cidrGroup[i] = nil - ipBytes := cidrEntry.GetIp() - prefixLen := int(cidrEntry.GetPrefix()) + sort.Slice(rules, func(i, j int) bool { + ri, rj := rules[i], rules[j] + if ri.File != rj.File { + return ri.File < rj.File + } + return ri.Code < rj.Code + }) - addr, ok := netip.AddrFromSlice(ipBytes) - if !ok { - errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes) - continue - } - - prefix := netip.PrefixFrom(addr, prefixLen) - if !prefix.IsValid() { - errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen) - continue - } - - if addr.Is4() { - ipv4Builder.AddPrefix(prefix) - } else if addr.Is6() { - ipv6Builder.AddPrefix(prefix) - } + var sb strings.Builder + sb.Grow(len(rules) * 20) // geoip.dat:xx, + var last *GeoIPRule + for i, r := range rules { + if i == 0 || (r.File != last.File || r.Code != last.Code) { + last = r + sb.WriteString(r.File) + sb.WriteString(":") + sb.WriteString(r.Code) + sb.WriteString(",") } } + return sb.String() +} + +func (f *IPSetFactory) CreateFromCIDRs(cidrs []*CIDR) (*IPSet, error) { + return f.createFrom(func(add func(*CIDR)) error { + for _, c := range cidrs { + add(c) + } + return nil + }) +} + +func (f *IPSetFactory) createFrom(yield func(func(*CIDR)) error) (*IPSet, error) { + var ipv4Builder, ipv6Builder netipx.IPSetBuilder + + err := yield(func(c *CIDR) { + ipBytes := c.GetIp() + prefixLen := int(c.GetPrefix()) + + addr, ok := netip.AddrFromSlice(ipBytes) + if !ok { + errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes) + return + } + + prefix := netip.PrefixFrom(addr, prefixLen) + if !prefix.IsValid() { + errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen) + return + } + + if addr.Is4() { + ipv4Builder.AddPrefix(prefix) + } else if addr.Is6() { + ipv6Builder.AddPrefix(prefix) + } + }) + if err != nil { + return nil, err + } + + // peak mem + runtime.GC() + defer runtime.GC() ipv4, err := ipv4Builder.IPSet() if err != nil { @@ -876,87 +935,62 @@ func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) { max6 = 0xff } - return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil + return &IPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil } -func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) { - n := len(geoips) - if n == 0 { - return nil, errors.New("no geoip configs provided") - } +func buildOptimizedIPMatcher(f *IPSetFactory, rules []*IPRule) (IPMatcher, error) { + n := len(rules) + custom := make([]*CIDR, 0, n) + pos := make([]*GeoIPRule, 0, n) + neg := make([]*GeoIPRule, 0, n) - var subs []*HeuristicGeoIPMatcher - pos := make([]*GeoIP, 0, n) - neg := make([]*GeoIP, 0, n/2) - - for _, geoip := range geoips { - if geoip == nil { - return nil, errors.New("geoip entry is nil") - } - if geoip.CountryCode == "" { - ipset, err := ipsetFactory.Create(geoip.Cidr) - if err != nil { - return nil, err + for _, r := range rules { + switch v := r.Value.(type) { + case *IPRule_Custom: + custom = append(custom, v.Custom) + case *IPRule_Geoip: + if !v.Geoip.ReverseMatch { + pos = append(pos, v.Geoip) + } else { + neg = append(neg, v.Geoip) } - subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch}) - continue - } - if !geoip.ReverseMatch { - pos = append(pos, geoip) - } else { - neg = append(neg, geoip) + default: + panic("unknown ip rule type") } } - buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) { - n := len(mergeables) - if n == 0 { - return nil, nil + subs := make([]*HeuristicIPMatcher, 0, 3) + + if len(custom) > 0 { + ipset, err := f.CreateFromCIDRs(custom) + if err != nil { + return nil, err } + subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: false}) + } - sort.Slice(mergeables, func(i, j int) bool { - gi, gj := mergeables[i], mergeables[j] - return gi.CountryCode < gj.CountryCode - }) - - var sb strings.Builder - sb.Grow(n * 3) // xx, - cidrGroups := make([][]*CIDR, 0, n) - var last *GeoIP - for i, geoip := range mergeables { - if i == 0 || (geoip.CountryCode != last.CountryCode) { - last = geoip - sb.WriteString(geoip.CountryCode) - sb.WriteString(",") - cidrGroups = append(cidrGroups, geoip.Cidr) - } + if len(pos) > 0 { + ipset, err := f.GetOrCreateFromGeoIPRules(pos) + if err != nil { + return nil, err } - - return ipsetFactory.GetOrCreate(sb.String(), cidrGroups) + subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: false}) } - ipset, err := buildIPSet(pos) - if err != nil { - return nil, err - } - if ipset != nil { - subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false}) - } - - ipset, err = buildIPSet(neg) - if err != nil { - return nil, err - } - if ipset != nil { - subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true}) + if len(neg) > 0 { + ipset, err := f.GetOrCreateFromGeoIPRules(neg) + if err != nil { + return nil, err + } + subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: true}) } switch len(subs) { case 0: - return nil, errors.New("no valid geoip matcher") + return nil, errors.New("no valid ip matcher") case 1: return subs[0], nil default: - return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil + return &HeuristicMultiIPMatcher{matchers: subs}, nil } } diff --git a/common/geodata/ip_matcher_test.go b/common/geodata/ip_matcher_test.go new file mode 100644 index 00000000..ac627506 --- /dev/null +++ b/common/geodata/ip_matcher_test.go @@ -0,0 +1,325 @@ +package geodata + +import ( + "net" + "path/filepath" + "reflect" + "slices" + "testing" + + "github.com/xtls/xray-core/common" + xnet "github.com/xtls/xray-core/common/net" +) + +func buildIPMatcher(rawRules ...string) IPMatcher { + rules, err := ParseIPRules(rawRules) + common.Must(err) + + matcher, err := newIPRegistry().BuildIPMatcher(rules) + common.Must(err) + + return matcher +} + +func sortIPStrings(ips []net.IP) []string { + output := make([]string, 0, len(ips)) + for _, ip := range ips { + output = append(output, ip.String()) + } + slices.Sort(output) + return output +} + +func TestIPMatcher(t *testing.T) { + matcher := buildIPMatcher( + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.0.0.0/24", + "192.0.2.0/24", + "192.168.0.0/16", + "192.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "8.8.8.8/32", + "91.108.4.0/16", + ) + + testCases := []struct { + Input string + Output bool + }{ + { + Input: "192.168.1.1", + Output: true, + }, + { + Input: "192.0.0.0", + Output: true, + }, + { + Input: "192.0.1.0", + Output: false, + }, + { + Input: "0.1.0.0", + Output: true, + }, + { + Input: "1.0.0.1", + Output: false, + }, + { + Input: "8.8.8.7", + Output: false, + }, + { + Input: "8.8.8.8", + Output: true, + }, + { + Input: "2001:cdba::3257:9652", + Output: false, + }, + { + Input: "91.108.255.254", + Output: true, + }, + } + + for _, test := range testCases { + if v := matcher.Match(xnet.ParseAddress(test.Input).IP()); v != test.Output { + t.Error("unexpected output: ", v, " for test case ", test) + } + } +} + +func TestIPMatcherRegression(t *testing.T) { + matcher := buildIPMatcher( + "98.108.20.0/22", + "98.108.20.0/23", + ) + + testCases := []struct { + Input string + Output bool + }{ + { + Input: "98.108.22.11", + Output: true, + }, + { + Input: "98.108.25.0", + Output: false, + }, + } + + for _, test := range testCases { + if v := matcher.Match(xnet.ParseAddress(test.Input).IP()); v != test.Output { + t.Error("unexpected output: ", v, " for test case ", test) + } + } +} + +func TestIPReverseMatcher(t *testing.T) { + matcher := buildIPMatcher( + "8.8.8.8/32", + "91.108.4.0/16", + ) + matcher.SetReverse(true) + + testCases := []struct { + Input string + Output bool + }{ + { + Input: "8.8.8.8", + Output: false, + }, + { + Input: "2001:cdba::3257:9652", + Output: false, + }, + { + Input: "91.108.255.254", + Output: false, + }, + } + + for _, test := range testCases { + if v := matcher.Match(xnet.ParseAddress(test.Input).IP()); v != test.Output { + t.Error("unexpected output: ", v, " for test case ", test) + } + } +} + +func TestIPReverseMatcher2(t *testing.T) { + matcher := buildIPMatcher( + "8.8.8.8/32", + "91.108.4.0/16", + "fe80::", // Keep IPv6 family non-empty so reverse matching can evaluate IPv6 input. + ) + matcher.SetReverse(true) + + testCases := []struct { + Input string + Output bool + }{ + { + Input: "8.8.8.8", + Output: false, + }, + { + Input: "2001:cdba::3257:9652", + Output: true, + }, + { + Input: "91.108.255.254", + Output: false, + }, + } + + for _, test := range testCases { + if v := matcher.Match(xnet.ParseAddress(test.Input).IP()); v != test.Output { + t.Error("unexpected output: ", v, " for test case ", test) + } + } +} + +func TestIPMatcherAnyMatchAndMatches(t *testing.T) { + matcher := buildIPMatcher( + "8.8.8.8/32", + "2001:4860:4860::8888/128", + ) + ip := func(raw string) net.IP { + return xnet.ParseAddress(raw).IP() + } + + if matcher.AnyMatch(nil) { + t.Fatal("expect AnyMatch(nil) to be false") + } + + if !matcher.AnyMatch([]net.IP{ + net.IP{}, + ip("1.1.1.1"), + ip("8.8.8.8"), + }) { + t.Fatal("expect AnyMatch to ignore invalid IPs and return true when one valid IP matches") + } + + if matcher.AnyMatch([]net.IP{ + ip("1.1.1.1"), + ip("2001:db8::1"), + }) { + t.Fatal("expect AnyMatch to be false when no valid IP matches") + } + + if !matcher.Matches([]net.IP{ + ip("8.8.8.8"), + ip("2001:4860:4860::8888"), + }) { + t.Fatal("expect Matches to be true when all valid IPs match") + } + + if matcher.Matches([]net.IP{ + ip("8.8.8.8"), + ip("1.1.1.1"), + }) { + t.Fatal("expect Matches to be false when one valid IP does not match") + } + + if matcher.Matches([]net.IP{ + ip("8.8.8.8"), + net.IP{}, + }) { + t.Fatal("expect Matches to be false when any IP is invalid") + } +} + +func TestIPMatcherFilterIPs(t *testing.T) { + matcher := buildIPMatcher( + "8.8.8.8/32", + "91.108.4.0/16", + "2001:4860:4860::8888/128", + ) + ip := func(raw string) net.IP { + return xnet.ParseAddress(raw).IP() + } + + matched, unmatched := matcher.FilterIPs([]net.IP{ + net.IP{}, + ip("8.8.8.8"), + ip("91.108.255.254"), + ip("1.1.1.1"), + ip("2001:4860:4860::8888"), + ip("2001:db8::1"), + }) + + wantMatched := []string{ + "2001:4860:4860::8888", + "8.8.8.8", + "91.108.255.254", + } + slices.Sort(wantMatched) + if v := sortIPStrings(matched); !reflect.DeepEqual(v, wantMatched) { + t.Error("unexpected output: ", v, " want ", wantMatched) + } + + wantUnmatched := []string{ + "1.1.1.1", + "2001:db8::1", + } + slices.Sort(wantUnmatched) + if v := sortIPStrings(unmatched); !reflect.DeepEqual(v, wantUnmatched) { + t.Error("unexpected output: ", v, " want ", wantUnmatched) + } +} + +func TestIPMatcher4CN(t *testing.T) { + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + matcher := buildIPMatcher("geoip:cn") + + if matcher.Match([]byte{8, 8, 8, 8}) { + t.Error("expect CN geoip doesn't contain 8.8.8.8, but actually does") + } +} + +func TestIPMatcher6US(t *testing.T) { + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + matcher := buildIPMatcher("geoip:us") + + if !matcher.Match(xnet.ParseAddress("2001:4860:4860::8888").IP()) { + t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not") + } +} + +func BenchmarkIPMatcher4CN(b *testing.B) { + b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + matcher := buildIPMatcher("geoip:cn") + ip := net.IP{8, 8, 8, 8} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = matcher.Match(ip) + } +} + +func BenchmarkIPMatcher6US(b *testing.B) { + b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + matcher := buildIPMatcher("geoip:us") + ip := xnet.ParseAddress("2001:4860:4860::8888").IP() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = matcher.Match(ip) + } +} diff --git a/common/geodata/ip_registry.go b/common/geodata/ip_registry.go new file mode 100644 index 00000000..fab4a778 --- /dev/null +++ b/common/geodata/ip_registry.go @@ -0,0 +1,17 @@ +package geodata + +type IPRegistry struct { + ipsetFactory *IPSetFactory +} + +func (r *IPRegistry) BuildIPMatcher(rules []*IPRule) (IPMatcher, error) { + return buildOptimizedIPMatcher(r.ipsetFactory, rules) +} + +func newIPRegistry() *IPRegistry { + return &IPRegistry{ + ipsetFactory: &IPSetFactory{shared: make(map[string]*IPSet)}, + } +} + +var IPReg = newIPRegistry() diff --git a/common/geodata/rule_parser.go b/common/geodata/rule_parser.go new file mode 100644 index 00000000..1184f553 --- /dev/null +++ b/common/geodata/rule_parser.go @@ -0,0 +1,254 @@ +package geodata + +import ( + "strconv" + "strings" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" +) + +const ( + DefaultGeoIPDat = "geoip.dat" + DefaultGeoSiteDat = "geosite.dat" +) + +func ParseIPRules(rules []string) ([]*IPRule, error) { + var ipRules []*IPRule + + for i, r := range rules { + if strings.HasPrefix(r, "geoip:") { + r = "ext:" + DefaultGeoIPDat + ":" + r[len("geoip:"):] + } + + prefix := 0 + for _, ext := range [...]string{"ext:", "ext-ip:"} { + if strings.HasPrefix(r, ext) { + prefix = len(ext) + break + } + } + + var rule isIPRule_Value + var err error + if prefix > 0 { + rule, err = parseGeoIPRule(r[prefix:]) + } else { + rule, err = parseCustomIPRule(r) + } + if err != nil { + return nil, errors.New("illegal ip rule: ", rules[i]).Base(err) + } + ipRules = append(ipRules, &IPRule{Value: rule}) + } + + return ipRules, nil +} + +func parseGeoIPRule(rule string) (*IPRule_Geoip, error) { + file, code, ok := strings.Cut(rule, ":") + if !ok { + return nil, errors.New("syntax error") + } + + if file == "" { + return nil, errors.New("empty file") + } + + reverse := false + if strings.HasPrefix(code, "!") { + code = code[1:] + reverse = true + } + if code == "" { + return nil, errors.New("empty code") + } + code = strings.ToUpper(code) + + if err := checkFile(file, code); err != nil { + return nil, err + } + + return &IPRule_Geoip{ + Geoip: &GeoIPRule{ + File: file, + Code: code, + ReverseMatch: reverse, + }, + }, nil +} + +func parseCustomIPRule(rule string) (*IPRule_Custom, error) { + cidr, err := parseCIDR(rule) + if err != nil { + return nil, err + } + return &IPRule_Custom{ + Custom: cidr, + }, nil +} + +func parseCIDR(s string) (*CIDR, error) { + ipStr, prefixStr, _ := strings.Cut(s, "/") + + ipAddr := net.ParseAddress(ipStr) + + var maxPrefix uint32 + switch ipAddr.Family() { + case net.AddressFamilyIPv4: + maxPrefix = 32 + case net.AddressFamilyIPv6: + maxPrefix = 128 + default: + return nil, errors.New("unsupported address family") + } + + prefixBits := maxPrefix + if prefixStr != "" { + parsedPrefix, err := strconv.ParseUint(prefixStr, 10, 32) + if err != nil { + return nil, errors.New("invalid CIDR prefix length: ", prefixStr).Base(err) + } + prefixBits = uint32(parsedPrefix) + } + if prefixBits > maxPrefix { + return nil, errors.New("CIDR prefix length ", prefixBits, " exceeds max ", maxPrefix) + } + + return &CIDR{ + Ip: []byte(ipAddr.IP()), + Prefix: prefixBits, + }, nil +} + +func ParseDomainRule(r string, defaultType Domain_Type) (*DomainRule, error) { + if strings.HasPrefix(r, "geosite:") { + r = "ext:" + DefaultGeoSiteDat + ":" + r[len("geosite:"):] + } + + prefix := 0 + for _, ext := range [...]string{"ext:", "ext-domain:"} { + if strings.HasPrefix(r, ext) { + prefix = len(ext) + break + } + } + + var rule isDomainRule_Value + var err error + if prefix > 0 { + rule, err = parseGeoSiteRule(r[prefix:]) + } else { + rule, err = parseCustomDomainRule(r, defaultType) + } + if err != nil { + return nil, errors.New("illegal domain rule: ", r).Base(err) + } + return &DomainRule{Value: rule}, nil +} + +func ParseDomainRules(rules []string, defaultType Domain_Type) ([]*DomainRule, error) { + var domainRules []*DomainRule + + for i, r := range rules { + if strings.HasPrefix(r, "geosite:") { + r = "ext:" + DefaultGeoSiteDat + ":" + r[len("geosite:"):] + } + + prefix := 0 + for _, ext := range [...]string{"ext:", "ext-domain:"} { + if strings.HasPrefix(r, ext) { + prefix = len(ext) + break + } + } + + var rule isDomainRule_Value + var err error + if prefix > 0 { + rule, err = parseGeoSiteRule(r[prefix:]) + } else { + rule, err = parseCustomDomainRule(r, defaultType) + } + if err != nil { + return nil, errors.New("illegal domain rule: ", rules[i]).Base(err) + } + domainRules = append(domainRules, &DomainRule{Value: rule}) + } + + return domainRules, nil +} + +func parseGeoSiteRule(rule string) (*DomainRule_Geosite, error) { + file, codeWithAttrs, ok := strings.Cut(rule, ":") + if !ok { + return nil, errors.New("syntax error") + } + + if file == "" { + return nil, errors.New("empty file") + } + + if strings.HasSuffix(codeWithAttrs, "@") || strings.Contains(codeWithAttrs, "@@") { + return nil, errors.New("empty attr") + } + code, attrs, _ := strings.Cut(codeWithAttrs, "@") + + if code == "" { + return nil, errors.New("empty code") + } + code = strings.ToUpper(code) + + if err := checkFile(file, code); err != nil { + return nil, err + } + + return &DomainRule_Geosite{ + Geosite: &GeoSiteRule{ + File: file, + Code: code, + Attrs: strings.ToLower(attrs), + }, + }, nil +} + +func parseCustomDomainRule(rule string, defaultType Domain_Type) (*DomainRule_Custom, error) { + domain := new(Domain) + + switch { + case strings.HasPrefix(rule, "regexp:"): + domain.Type = Domain_Regex + domain.Value = rule[7:] + + case strings.HasPrefix(rule, "domain:"): + domain.Type = Domain_Domain + domain.Value = rule[7:] + + case strings.HasPrefix(rule, "full:"): + domain.Type = Domain_Full + domain.Value = rule[5:] + + case strings.HasPrefix(rule, "keyword:"): + domain.Type = Domain_Substr + domain.Value = rule[8:] + + case strings.HasPrefix(rule, "dotless:"): + domain.Type = Domain_Regex + switch substr := rule[8:]; { + case substr == "": + domain.Value = "^[^.]*$" + case !strings.Contains(substr, "."): + domain.Value = "^[^.]*" + substr + "[^.]*$" + default: + return nil, errors.New("substr in dotless rule should not contain a dot") + } + + default: + domain.Type = defaultType + domain.Value = rule + } + + return &DomainRule_Custom{ + Custom: domain, + }, nil +} diff --git a/common/geodata/rule_parser_test.go b/common/geodata/rule_parser_test.go new file mode 100644 index 00000000..87dbaaab --- /dev/null +++ b/common/geodata/rule_parser_test.go @@ -0,0 +1,51 @@ +package geodata_test + +import ( + "path/filepath" + "testing" + + "github.com/xtls/xray-core/common/geodata" +) + +func TestParseIPRules(t *testing.T) { + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + rules := []string{ + "geoip:us", + "geoip:cn", + "geoip:!cn", + "ext:geoip.dat:!cn", + "ext:geoip.dat:ca", + "ext-ip:geoip.dat:!cn", + "ext-ip:geoip.dat:!ca", + "192.168.0.0/24", + "192.168.0.1", + "fe80::/64", + "fe80::", + } + + _, err := geodata.ParseIPRules(rules) + if err != nil { + t.Fatalf("Failed to parse ip rules, got %s", err) + } +} + +func TestParseDomainRules(t *testing.T) { + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + rules := []string{ + "geosite:cn", + "geosite:geolocation-!cn", + "geosite:cn@!cn", + "ext:geosite.dat:geolocation-!cn", + "ext:geosite.dat:cn@!cn", + "ext-site:geosite.dat:geolocation-!cn", + "ext-site:geosite.dat:cn@!cn", + "domain:google.com", + } + + _, err := geodata.ParseDomainRules(rules, geodata.Domain_Domain) + if err != nil { + t.Fatalf("Failed to parse domain rules, got %s", err) + } +} diff --git a/common/geodata/strmatcher/benchmark_indexmatcher_test.go b/common/geodata/strmatcher/benchmark_indexmatcher_test.go new file mode 100644 index 00000000..41a7345a --- /dev/null +++ b/common/geodata/strmatcher/benchmark_indexmatcher_test.go @@ -0,0 +1,58 @@ +package strmatcher_test + +import ( + "testing" + + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func BenchmarkLinearIndexMatcher(b *testing.B) { + benchmarkIndexMatcher(b, func() IndexMatcher { + return NewLinearIndexMatcher() + }) +} + +func BenchmarkMphIndexMatcher(b *testing.B) { + benchmarkIndexMatcher(b, func() IndexMatcher { + return NewMphIndexMatcher() + }) +} + +func benchmarkIndexMatcher(b *testing.B, ctor func() IndexMatcher) { + b.Run("Match", func(b *testing.B) { + b.Run("Domain------------", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: true}) + }) + b.Run("Domain+Full-------", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: true, Full: true}) + }) + b.Run("Domain+Full+Substr", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: true, Full: true, Substr: true}) + }) + b.Run("All-Fail----------", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: false, Full: false, Substr: false}) + }) + }) + b.Run("Match/Dotless", func(b *testing.B) { // Dotless domain matcher automatically inserted in DNS app when "localhost" DNS is used. + b.Run("All-Succ", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: true, Full: true, Substr: true, Regex: true}) + }) + b.Run("All-Fail", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{Domain: false, Full: false, Substr: false, Regex: false}) + }) + }) + b.Run("MatchAny", func(b *testing.B) { + b.Run("First-Full--", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{Full: true, Domain: true, Substr: true}) + }) + b.Run("First-Domain", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{Full: false, Domain: true, Substr: true}) + }) + b.Run("First-Substr", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{Full: false, Domain: false, Substr: true}) + }) + b.Run("All-Fail----", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{Full: false, Domain: false, Substr: false}) + }) + }) +} diff --git a/common/geodata/strmatcher/benchmark_matchers_test.go b/common/geodata/strmatcher/benchmark_matchers_test.go new file mode 100644 index 00000000..9e00c816 --- /dev/null +++ b/common/geodata/strmatcher/benchmark_matchers_test.go @@ -0,0 +1,149 @@ +package strmatcher_test + +import ( + "strconv" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func BenchmarkFullMatcher(b *testing.B) { + b.Run("SimpleMatcherGroup------", func(b *testing.B) { + benchmarkMatcherType(b, Full, func() MatcherGroup { + return new(SimpleMatcherGroup) + }) + }) + b.Run("FullMatcherGroup--------", func(b *testing.B) { + benchmarkMatcherType(b, Full, func() MatcherGroup { + return NewFullMatcherGroup() + }) + }) + b.Run("ACAutomationMatcherGroup", func(b *testing.B) { + benchmarkMatcherType(b, Full, func() MatcherGroup { + return NewACAutomatonMatcherGroup() + }) + }) + b.Run("MphMatcherGroup---------", func(b *testing.B) { + benchmarkMatcherType(b, Full, func() MatcherGroup { + return NewMphMatcherGroup() + }) + }) +} + +func BenchmarkDomainMatcher(b *testing.B) { + b.Run("SimpleMatcherGroup------", func(b *testing.B) { + benchmarkMatcherType(b, Domain, func() MatcherGroup { + return new(SimpleMatcherGroup) + }) + }) + b.Run("DomainMatcherGroup------", func(b *testing.B) { + benchmarkMatcherType(b, Domain, func() MatcherGroup { + return NewDomainMatcherGroup() + }) + }) + b.Run("ACAutomationMatcherGroup", func(b *testing.B) { + benchmarkMatcherType(b, Domain, func() MatcherGroup { + return NewACAutomatonMatcherGroup() + }) + }) + b.Run("MphMatcherGroup---------", func(b *testing.B) { + benchmarkMatcherType(b, Domain, func() MatcherGroup { + return NewMphMatcherGroup() + }) + }) +} + +func BenchmarkSubstrMatcher(b *testing.B) { + b.Run("SimpleMatcherGroup------", func(b *testing.B) { + benchmarkMatcherType(b, Substr, func() MatcherGroup { + return new(SimpleMatcherGroup) + }) + }) + b.Run("SubstrMatcherGroup------", func(b *testing.B) { + benchmarkMatcherType(b, Substr, func() MatcherGroup { + return new(SubstrMatcherGroup) + }) + }) + b.Run("ACAutomationMatcherGroup", func(b *testing.B) { + benchmarkMatcherType(b, Substr, func() MatcherGroup { + return NewACAutomatonMatcherGroup() + }) + }) +} + +// Utility functions for benchmark + +func benchmarkMatcherType(b *testing.B, t Type, ctor func() MatcherGroup) { + b.Run("Match", func(b *testing.B) { + b.Run("Succ", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{t: true}) + }) + b.Run("Fail", func(b *testing.B) { + benchmarkMatch(b, ctor(), map[Type]bool{t: false}) + }) + }) + b.Run("MatchAny", func(b *testing.B) { + b.Run("Succ", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{t: true}) + }) + b.Run("Fail", func(b *testing.B) { + benchmarkMatchAny(b, ctor(), map[Type]bool{t: false}) + }) + }) +} + +func benchmarkMatch(b *testing.B, g MatcherGroup, enabledTypes map[Type]bool) { + prepareMatchers(g, enabledTypes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = g.Match("0.example.com") + } +} + +func benchmarkMatchAny(b *testing.B, g MatcherGroup, enabledTypes map[Type]bool) { + prepareMatchers(g, enabledTypes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = g.MatchAny("0.example.com") + } +} + +func prepareMatchers(g MatcherGroup, enabledTypes map[Type]bool) { + for matcherType, hasMatch := range enabledTypes { + switch matcherType { + case Domain: + if hasMatch { + AddMatcherToGroup(g, DomainMatcher("example.com"), 0) + } + for i := 1; i < 1024; i++ { + AddMatcherToGroup(g, DomainMatcher(strconv.Itoa(i)+".example.com"), uint32(i)) + } + case Full: + if hasMatch { + AddMatcherToGroup(g, FullMatcher("0.example.com"), 0) + } + for i := 1; i < 64; i++ { + AddMatcherToGroup(g, FullMatcher(strconv.Itoa(i)+".example.com"), uint32(i)) + } + case Substr: + if hasMatch { + AddMatcherToGroup(g, SubstrMatcher("example.com"), 0) + } + for i := 1; i < 4; i++ { + AddMatcherToGroup(g, SubstrMatcher(strconv.Itoa(i)+".example.com"), uint32(i)) + } + case Regex: + matcher, err := Regex.New("^[^.]*$") // Dotless domain matcher automatically inserted in DNS app when "localhost" DNS is used. + common.Must(err) + AddMatcherToGroup(g, matcher, 0) + } + } + if g, ok := g.(buildable); ok { + common.Must(g.Build()) + } +} + +type buildable interface { + Build() error +} diff --git a/common/geodata/strmatcher/indexmatcher_linear.go b/common/geodata/strmatcher/indexmatcher_linear.go new file mode 100644 index 00000000..dcdc1d33 --- /dev/null +++ b/common/geodata/strmatcher/indexmatcher_linear.go @@ -0,0 +1,96 @@ +package strmatcher + +// LinearIndexMatcher is an implementation of IndexMatcher. +type LinearIndexMatcher struct { + count uint32 + full *FullMatcherGroup + domain *DomainMatcherGroup + substr *SubstrMatcherGroup + regex *SimpleMatcherGroup +} + +func NewLinearIndexMatcher() *LinearIndexMatcher { + return new(LinearIndexMatcher) +} + +// Add implements IndexMatcher.Add. +func (g *LinearIndexMatcher) Add(matcher Matcher) uint32 { + g.count++ + index := g.count + + switch matcher := matcher.(type) { + case FullMatcher: + if g.full == nil { + g.full = NewFullMatcherGroup() + } + g.full.AddFullMatcher(matcher, index) + case DomainMatcher: + if g.domain == nil { + g.domain = NewDomainMatcherGroup() + } + g.domain.AddDomainMatcher(matcher, index) + case SubstrMatcher: + if g.substr == nil { + g.substr = new(SubstrMatcherGroup) + } + g.substr.AddSubstrMatcher(matcher, index) + default: + if g.regex == nil { + g.regex = new(SimpleMatcherGroup) + } + g.regex.AddMatcher(matcher, index) + } + + return index +} + +// Build implements IndexMatcher.Build. +func (*LinearIndexMatcher) Build() error { + return nil +} + +// Match implements IndexMatcher.Match. +func (g *LinearIndexMatcher) Match(input string) []uint32 { + // Allocate capacity to prevent matches escaping to heap + result := make([][]uint32, 0, 5) + if g.full != nil { + if matches := g.full.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.domain != nil { + if matches := g.domain.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.substr != nil { + if matches := g.substr.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.regex != nil { + if matches := g.regex.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + return CompositeMatches(result) +} + +// MatchAny implements IndexMatcher.MatchAny. +func (g *LinearIndexMatcher) MatchAny(input string) bool { + if g.full != nil && g.full.MatchAny(input) { + return true + } + if g.domain != nil && g.domain.MatchAny(input) { + return true + } + if g.substr != nil && g.substr.MatchAny(input) { + return true + } + return g.regex != nil && g.regex.MatchAny(input) +} + +// Size implements IndexMatcher.Size. +func (g *LinearIndexMatcher) Size() uint32 { + return g.count +} diff --git a/common/geodata/strmatcher/indexmatcher_linear_test.go b/common/geodata/strmatcher/indexmatcher_linear_test.go new file mode 100644 index 00000000..2f518ed3 --- /dev/null +++ b/common/geodata/strmatcher/indexmatcher_linear_test.go @@ -0,0 +1,95 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489 +func TestLinearIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + { + Type: Regex, + Domain: "apis\\.us$", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4, 2, 6}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4, 2, 6}, + }, + { + Input: "testapis.us", + Output: []uint32{2, 6, 1}, + }, + { + Input: "example.com", + Output: []uint32{10, 4}, + }, + } + matcherGroup := NewLinearIndexMatcher() + for _, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + matcherGroup.Add(matcher) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} diff --git a/common/geodata/strmatcher/indexmatcher_mph.go b/common/geodata/strmatcher/indexmatcher_mph.go new file mode 100644 index 00000000..b23f8376 --- /dev/null +++ b/common/geodata/strmatcher/indexmatcher_mph.go @@ -0,0 +1,100 @@ +package strmatcher + +import "runtime" + +// A MphIndexMatcher is divided into three parts: +// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; +// 2. `substr` patterns are matched by ac automaton; +// 3. `regex` patterns are matched with the regex library. +type MphIndexMatcher struct { + count uint32 + mph *MphMatcherGroup + ac *ACAutomatonMatcherGroup + regex *SimpleMatcherGroup +} + +func NewMphIndexMatcher() *MphIndexMatcher { + return new(MphIndexMatcher) +} + +// Add implements IndexMatcher.Add. +func (g *MphIndexMatcher) Add(matcher Matcher) uint32 { + g.count++ + index := g.count + + switch matcher := matcher.(type) { + case FullMatcher: + if g.mph == nil { + g.mph = NewMphMatcherGroup() + } + g.mph.AddFullMatcher(matcher, index) + case DomainMatcher: + if g.mph == nil { + g.mph = NewMphMatcherGroup() + } + g.mph.AddDomainMatcher(matcher, index) + case SubstrMatcher: + if g.ac == nil { + g.ac = NewACAutomatonMatcherGroup() + } + g.ac.AddSubstrMatcher(matcher, index) + case *RegexMatcher: + if g.regex == nil { + g.regex = &SimpleMatcherGroup{} + } + g.regex.AddMatcher(matcher, index) + } + + return index +} + +// Build implements IndexMatcher.Build. +func (g *MphIndexMatcher) Build() error { + if g.mph != nil { + runtime.GC() // peak mem + g.mph.Build() + } + runtime.GC() // peak mem + if g.ac != nil { + g.ac.Build() + runtime.GC() // peak mem + } + return nil +} + +// Match implements IndexMatcher.Match. +func (g *MphIndexMatcher) Match(input string) []uint32 { + result := make([][]uint32, 0, 5) + if g.mph != nil { + if matches := g.mph.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.ac != nil { + if matches := g.ac.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.regex != nil { + if matches := g.regex.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + return CompositeMatches(result) +} + +// MatchAny implements IndexMatcher.MatchAny. +func (g *MphIndexMatcher) MatchAny(input string) bool { + if g.mph != nil && g.mph.MatchAny(input) { + return true + } + if g.ac != nil && g.ac.MatchAny(input) { + return true + } + return g.regex != nil && g.regex.MatchAny(input) +} + +// Size implements IndexMatcher.Size. +func (g *MphIndexMatcher) Size() uint32 { + return g.count +} diff --git a/common/geodata/strmatcher/indexmatcher_mph_test.go b/common/geodata/strmatcher/indexmatcher_mph_test.go new file mode 100644 index 00000000..2e1c70dd --- /dev/null +++ b/common/geodata/strmatcher/indexmatcher_mph_test.go @@ -0,0 +1,94 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestMphIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + { + Type: Regex, + Domain: "apis\\.us$", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4, 2, 6}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4, 2, 6}, + }, + { + Input: "testapis.us", + Output: []uint32{2, 6, 1}, + }, + { + Input: "example.com", + Output: []uint32{10, 4}, + }, + } + matcherGroup := NewMphIndexMatcher() + for _, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + matcherGroup.Add(matcher) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} diff --git a/common/geodata/strmatcher/matchergroup_ac_automation.go b/common/geodata/strmatcher/matchergroup_ac_automation.go new file mode 100644 index 00000000..bf07648a --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_ac_automation.go @@ -0,0 +1,282 @@ +package strmatcher + +import ( + "container/list" +) + +const ( + acValidCharCount = 39 // aA-zZ (26), 0-9 (10), - (1), . (1), invalid(1) + acMatchTypeCount = 3 // Full, Domain and Substr +) + +type acEdge byte + +const ( + acTrieEdge acEdge = 1 + acFailEdge acEdge = 0 +) + +type acNode struct { + next [acValidCharCount]uint32 // EdgeIdx -> Next NodeIdx (Next trie node or fail node) + edge [acValidCharCount]acEdge // EdgeIdx -> Trie Edge / Fail Edge + fail uint32 // NodeIdx of *next matched* Substr Pattern on its fail path + match uint32 // MatchIdx of matchers registered on this node, 0 indicates no match +} // Sizeof acNode: (4+1)*acValidCharCount + + 4 + 4 + +type acValue [acMatchTypeCount][]uint32 // MatcherType -> Registered Matcher Values + +// ACAutoMationMatcherGroup is an implementation of MatcherGroup. +// It uses an AC Automata to provide support for Full, Domain and Substr matcher. Trie node is char based. +// +// NOTICE: ACAutomatonMatcherGroup currently uses a restricted charset (LDH Subset), +// upstream should manually in a way to ensure all patterns and inputs passed to it to be in this charset. +type ACAutomatonMatcherGroup struct { + nodes []acNode // NodeIdx -> acNode + values []acValue // MatchIdx -> acValue +} + +func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup { + ac := new(ACAutomatonMatcherGroup) + ac.addNode() // Create root node (NodeIdx 0) + ac.addMatchEntry() // Create sentinel match entry (MatchIdx 0) + return ac +} + +// AddFullMatcher implements MatcherGroupForFull.AddFullMatcher. +func (ac *ACAutomatonMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { + ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) +} + +// AddDomainMatcher implements MatcherGroupForDomain.AddDomainMatcher. +func (ac *ACAutomatonMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { + node := ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) // For full domain match + ac.addPattern(node, ".", matcher.Type(), value) // For partial domain match +} + +// AddSubstrMatcher implements MatcherGroupForSubstr.AddSubstrMatcher. +func (ac *ACAutomatonMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, value uint32) { + ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) +} + +func (ac *ACAutomatonMatcherGroup) addPattern(nodeIdx uint32, pattern string, matcherType Type, value uint32) uint32 { + node := &ac.nodes[nodeIdx] + for i := len(pattern) - 1; i >= 0; i-- { + edgeIdx := acCharset[pattern[i]] + nextIdx := node.next[edgeIdx] + if nextIdx == 0 { // Add new Trie Edge + nextIdx = ac.addNode() + ac.nodes[nodeIdx].next[edgeIdx] = nextIdx + ac.nodes[nodeIdx].edge[edgeIdx] = acTrieEdge + } + nodeIdx = nextIdx + node = &ac.nodes[nodeIdx] + } + if node.match == 0 { // Add new match entry + node.match = ac.addMatchEntry() + } + ac.values[node.match][matcherType] = append(ac.values[node.match][matcherType], value) + return nodeIdx +} + +func (ac *ACAutomatonMatcherGroup) addNode() uint32 { + ac.nodes = append(ac.nodes, acNode{}) + return uint32(len(ac.nodes) - 1) +} + +func (ac *ACAutomatonMatcherGroup) addMatchEntry() uint32 { + ac.values = append(ac.values, acValue{}) + return uint32(len(ac.values) - 1) +} + +func (ac *ACAutomatonMatcherGroup) Build() error { + fail := make([]uint32, len(ac.nodes)) + queue := list.New() + for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ { + if nextIdx := ac.nodes[0].next[edgeIdx]; nextIdx != 0 { + queue.PushBack(nextIdx) + } + } + for { + front := queue.Front() + if front == nil { + break + } + queue.Remove(front) + nodeIdx := front.Value.(uint32) + node := &ac.nodes[nodeIdx] // Current node + failNode := &ac.nodes[fail[nodeIdx]] // Fail node of currrent node + for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ { + nodeIdx := node.next[edgeIdx] // Next node through trie edge + failIdx := failNode.next[edgeIdx] // Next node through fail edge + if nodeIdx != 0 { + queue.PushBack(nodeIdx) + fail[nodeIdx] = failIdx + if match := ac.nodes[failIdx].match; match != 0 && len(ac.values[match][Substr]) > 0 { // Fail node is a Substr match node + ac.nodes[nodeIdx].fail = failIdx + } else { // Use path compression to reduce fail path to only contain match nodes + ac.nodes[nodeIdx].fail = ac.nodes[failIdx].fail + } + } else { // Add new fail edge + node.next[edgeIdx] = failIdx + node.edge[edgeIdx] = acFailEdge + } + } + } + return nil +} + +// Match implements MatcherGroup.Match. +func (ac *ACAutomatonMatcherGroup) Match(input string) []uint32 { + suffixMatches := make([][]uint32, 0, 5) + substrMatches := make([][]uint32, 0, 5) + fullMatch := true // fullMatch indicates no fail edge traversed so far. + node := &ac.nodes[0] // start from root node. + // 1. the match string is all through trie edge. FULL MATCH or DOMAIN + // 2. the match string is through a fail edge. NOT FULL MATCH + // 2.1 Through a fail edge, but there exists a valid node. SUBSTR + for i := len(input) - 1; i >= 0; i-- { + edge := acCharset[input[i]] + fullMatch = fullMatch && (node.edge[edge] == acTrieEdge) + node = &ac.nodes[node.next[edge]] // Advance to next node + // When entering a new node, traverse the fail path to find all possible Substr patterns: + // 1. The fail path is compressed to only contains match nodes and root node (for terminate condition). + // 2. node.fail != 0 is added here for better performance (as shown by benchmark), possibly it helps branch prediction. + if node.fail != 0 { + for failIdx, failNode := node.fail, &ac.nodes[node.fail]; failIdx != 0; failIdx, failNode = failNode.fail, &ac.nodes[failIdx] { + substrMatches = append(substrMatches, ac.values[failNode.match][Substr]) + } + } + // When entering a new node, check whether this node is a match. + // For Substr matchers: + // 1. Matched in any situation, whether a failNode edge is traversed or not. + // For Domain matchers: + // 1. Should not traverse any fail edge (fullMatch). + // 2. Only check on dot separator (input[i] == '.'). + if node.match != 0 { + values := ac.values[node.match] + if len(values[Substr]) > 0 { + substrMatches = append(substrMatches, values[Substr]) + } + if fullMatch && input[i] == '.' && len(values[Domain]) > 0 { + suffixMatches = append(suffixMatches, values[Domain]) + } + } + } + // At the end of input, check if the whole string matches a pattern. + // For Domain matchers: + // 1. Exact match on Domain Matcher works like Full Match. e.g. foo.com is a full match for domain:foo.com. + // For Full matchers: + // 1. Only when no fail edge is traversed (fullMatch). + // 2. Takes the highest priority (added at last). + if fullMatch && node.match != 0 { + values := ac.values[node.match] + if len(values[Domain]) > 0 { + suffixMatches = append(suffixMatches, values[Domain]) + } + if len(values[Full]) > 0 { + suffixMatches = append(suffixMatches, values[Full]) + } + } + if len(substrMatches) == 0 { + return CompositeMatchesReverse(suffixMatches) + } + return CompositeMatchesReverse(append(substrMatches, suffixMatches...)) +} + +// MatchAny implements MatcherGroup.MatchAny. +func (ac *ACAutomatonMatcherGroup) MatchAny(input string) bool { + fullMatch := true + node := &ac.nodes[0] + for i := len(input) - 1; i >= 0; i-- { + edge := acCharset[input[i]] + fullMatch = fullMatch && (node.edge[edge] == acTrieEdge) + node = &ac.nodes[node.next[edge]] + if node.fail != 0 { // There is a match on this node's fail path + return true + } + if node.match != 0 { // There is a match on this node + values := ac.values[node.match] + if len(values[Substr]) > 0 { // Substr match succeeds unconditionally + return true + } + if fullMatch && input[i] == '.' && len(values[Domain]) > 0 { // Domain match only succeeds with dot separator on trie path + return true + } + } + } + return fullMatch && node.match != 0 // At the end of input, Domain and Full match will succeed if no fail edge is traversed +} + +// Letter-Digit-Hyphen (LDH) subset (https://tools.ietf.org/html/rfc952): +// - Letters A to Z (no distinction is made between uppercase and lowercase) +// - Digits 0 to 9 +// - Hyphens(-) and Periods(.) +// +// If for future the strmatcher are used for other scenarios than domain, +// we could add a new Charset interface to represent variable charsets. +var acCharset = [256]int{ + 'A': 1, + 'a': 1, + 'B': 2, + 'b': 2, + 'C': 3, + 'c': 3, + 'D': 4, + 'd': 4, + 'E': 5, + 'e': 5, + 'F': 6, + 'f': 6, + 'G': 7, + 'g': 7, + 'H': 8, + 'h': 8, + 'I': 9, + 'i': 9, + 'J': 10, + 'j': 10, + 'K': 11, + 'k': 11, + 'L': 12, + 'l': 12, + 'M': 13, + 'm': 13, + 'N': 14, + 'n': 14, + 'O': 15, + 'o': 15, + 'P': 16, + 'p': 16, + 'Q': 17, + 'q': 17, + 'R': 18, + 'r': 18, + 'S': 19, + 's': 19, + 'T': 20, + 't': 20, + 'U': 21, + 'u': 21, + 'V': 22, + 'v': 22, + 'W': 23, + 'w': 23, + 'X': 24, + 'x': 24, + 'Y': 25, + 'y': 25, + 'Z': 26, + 'z': 26, + '-': 27, + '.': 28, + '0': 29, + '1': 30, + '2': 31, + '3': 32, + '4': 33, + '5': 34, + '6': 35, + '7': 36, + '8': 37, + '9': 38, +} diff --git a/common/geodata/strmatcher/matchergroup_ac_automation_test.go b/common/geodata/strmatcher/matchergroup_ac_automation_test.go new file mode 100644 index 00000000..f0b88962 --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_ac_automation_test.go @@ -0,0 +1,365 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestACAutomatonMatcherGroup(t *testing.T) { + cases1 := []struct { + pattern string + mType Type + input string + output bool + }{ + { + pattern: "example.com", + mType: Domain, + input: "www.example.com", + output: true, + }, + { + pattern: "example.com", + mType: Domain, + input: "example.com", + output: true, + }, + { + pattern: "example.com", + mType: Domain, + input: "www.e3ample.com", + output: false, + }, + { + pattern: "example.com", + mType: Domain, + input: "xample.com", + output: false, + }, + { + pattern: "example.com", + mType: Domain, + input: "xexample.com", + output: false, + }, + { + pattern: "example.com", + mType: Full, + input: "example.com", + output: true, + }, + { + pattern: "example.com", + mType: Full, + input: "xexample.com", + output: false, + }, + } + for _, test := range cases1 { + ac := NewACAutomatonMatcherGroup() + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(ac, matcher, 0)) + ac.Build() + if m := ac.MatchAny(test.input); m != test.output { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + { + cases2Input := []struct { + pattern string + mType Type + }{ + { + pattern: "163.com", + mType: Domain, + }, + { + pattern: "m.126.com", + mType: Full, + }, + { + pattern: "3.com", + mType: Full, + }, + { + pattern: "google.com", + mType: Substr, + }, + { + pattern: "vgoogle.com", + mType: Substr, + }, + } + ac := NewACAutomatonMatcherGroup() + for _, test := range cases2Input { + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(ac, matcher, 0)) + } + ac.Build() + cases2Output := []struct { + pattern string + res bool + }{ + { + pattern: "126.com", + res: false, + }, + { + pattern: "m.163.com", + res: true, + }, + { + pattern: "mm163.com", + res: false, + }, + { + pattern: "m.126.com", + res: true, + }, + { + pattern: "163.com", + res: true, + }, + { + pattern: "63.com", + res: false, + }, + { + pattern: "oogle.com", + res: false, + }, + { + pattern: "vvgoogle.com", + res: true, + }, + } + for _, test := range cases2Output { + if m := ac.MatchAny(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } + + { + cases3Input := []struct { + pattern string + mType Type + }{ + { + pattern: "video.google.com", + mType: Domain, + }, + { + pattern: "gle.com", + mType: Domain, + }, + } + ac := NewACAutomatonMatcherGroup() + for _, test := range cases3Input { + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(ac, matcher, 0)) + } + ac.Build() + cases3Output := []struct { + pattern string + res bool + }{ + { + pattern: "google.com", + res: false, + }, + } + for _, test := range cases3Output { + if m := ac.MatchAny(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } + + { + cases4Input := []struct { + pattern string + mType Type + }{ + { + pattern: "apis", + mType: Substr, + }, + { + pattern: "googleapis.com", + mType: Domain, + }, + } + ac := NewACAutomatonMatcherGroup() + for _, test := range cases4Input { + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(ac, matcher, 0)) + } + ac.Build() + cases4Output := []struct { + pattern string + res bool + }{ + { + pattern: "gapis.com", + res: true, + }, + } + for _, test := range cases4Output { + if m := ac.MatchAny(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } +} + +func TestACAutomatonMatcherGroupSubstr(t *testing.T) { + patterns := []struct { + pattern string + mType Type + }{ + { + pattern: "apis", + mType: Substr, + }, + { + pattern: "google", + mType: Substr, + }, + { + pattern: "apis", + mType: Substr, + }, + } + cases := []struct { + input string + output []uint32 + }{ + { + input: "google.com", + output: []uint32{1}, + }, + { + input: "apis.com", + output: []uint32{0, 2}, + }, + { + input: "googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "fonts.googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "apis.googleapis.com", + output: []uint32{0, 2, 1, 0, 2}, + }, + } + matcherGroup := NewACAutomatonMatcherGroup() + for id, entry := range patterns { + matcher, err := entry.mType.New(entry.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id))) + } + matcherGroup.Build() + for _, test := range cases { + if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) { + t.Error("unexpected output: ", r, " for test case ", test) + } + } +} + +// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489 +func TestACAutomatonMatcherGroupAsIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + // Regex not supported by ACAutomationMatcherGroup + // { + // Type: Regex, + // Domain: "apis\\.us$", + // }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4, 2, 6}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4, 2, 6}, + }, + { + Input: "testapis.us", + Output: []uint32{2, 6 /*, 1*/}, + }, + { + Input: "example.com", + Output: []uint32{10, 4}, + }, + } + matcherGroup := NewACAutomatonMatcherGroup() + for i, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+2))) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} diff --git a/common/geodata/strmatcher/matchergroup_domain.go b/common/geodata/strmatcher/matchergroup_domain.go new file mode 100644 index 00000000..29334214 --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_domain.go @@ -0,0 +1,109 @@ +package strmatcher + +type trieNode struct { + values []uint32 + children map[string]*trieNode +} + +// DomainMatcherGroup is an implementation of MatcherGroup. +// It uses trie to optimize both memory consumption and lookup speed. Trie node is domain label based. +type DomainMatcherGroup struct { + root *trieNode +} + +func NewDomainMatcherGroup() *DomainMatcherGroup { + return &DomainMatcherGroup{ + root: new(trieNode), + } +} + +// AddDomainMatcher implements MatcherGroupForDomain.AddDomainMatcher. +func (g *DomainMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { + node := g.root + pattern := matcher.Pattern() + for i := len(pattern); i > 0; { + var part string + for j := i - 1; ; j-- { + if pattern[j] == '.' { + part = pattern[j+1 : i] + i = j + break + } + if j == 0 { + part = pattern[j:i] + i = j + break + } + } + if node.children == nil { + node.children = make(map[string]*trieNode) + } + next := node.children[part] + if next == nil { + next = new(trieNode) + node.children[part] = next + } + node = next + } + + node.values = append(node.values, value) +} + +// Match implements MatcherGroup.Match. +func (g *DomainMatcherGroup) Match(input string) []uint32 { + matches := make([][]uint32, 0, 5) + node := g.root + for i := len(input); i > 0; { + for j := i - 1; ; j-- { + if input[j] == '.' { // Domain label found + node = node.children[input[j+1:i]] + i = j + break + } + if j == 0 { // The last part of domain label + node = node.children[input[j:i]] + i = j + break + } + } + if node == nil { // No more match if no trie edge transition + break + } + if len(node.values) > 0 { // Found matched matchers + matches = append(matches, node.values) + } + if node.children == nil { // No more match if leaf node reached + break + } + } + return CompositeMatchesReverse(matches) +} + +// MatchAny implements MatcherGroup.MatchAny. +func (g *DomainMatcherGroup) MatchAny(input string) bool { + node := g.root + for i := len(input); i > 0; { + for j := i - 1; ; j-- { + if input[j] == '.' { + node = node.children[input[j+1:i]] + i = j + break + } + if j == 0 { + node = node.children[input[j:i]] + i = j + break + } + } + if node == nil { + return false + } + if len(node.values) > 0 { + return true + } + if node.children == nil { + return false + } + } + return false +} diff --git a/common/strmatcher/domain_matcher_test.go b/common/geodata/strmatcher/matchergroup_domain_test.go similarity index 61% rename from common/strmatcher/domain_matcher_test.go rename to common/geodata/strmatcher/matchergroup_domain_test.go index 5a8ca35b..f9f6266d 100644 --- a/common/strmatcher/domain_matcher_test.go +++ b/common/geodata/strmatcher/matchergroup_domain_test.go @@ -4,19 +4,43 @@ import ( "reflect" "testing" - . "github.com/xtls/xray-core/common/strmatcher" + . "github.com/xtls/xray-core/common/geodata/strmatcher" ) func TestDomainMatcherGroup(t *testing.T) { - g := new(DomainMatcherGroup) - g.Add("example.com", 1) - g.Add("google.com", 2) - g.Add("x.a.com", 3) - g.Add("a.b.com", 4) - g.Add("c.a.b.com", 5) - g.Add("x.y.com", 4) - g.Add("x.y.com", 6) - + patterns := []struct { + Pattern string + Value uint32 + }{ + { + Pattern: "example.com", + Value: 1, + }, + { + Pattern: "google.com", + Value: 2, + }, + { + Pattern: "x.a.com", + Value: 3, + }, + { + Pattern: "a.b.com", + Value: 4, + }, + { + Pattern: "c.a.b.com", + Value: 5, + }, + { + Pattern: "x.y.com", + Value: 4, + }, + { + Pattern: "x.y.com", + Value: 6, + }, + } testCases := []struct { Domain string Result []uint32 @@ -58,7 +82,10 @@ func TestDomainMatcherGroup(t *testing.T) { Result: []uint32{4, 6}, }, } - + g := NewDomainMatcherGroup() + for _, pattern := range patterns { + AddMatcherToGroup(g, DomainMatcher(pattern.Pattern), pattern.Value) + } for _, testCase := range testCases { r := g.Match(testCase.Domain) if !reflect.DeepEqual(r, testCase.Result) { @@ -68,7 +95,7 @@ func TestDomainMatcherGroup(t *testing.T) { } func TestEmptyDomainMatcherGroup(t *testing.T) { - g := new(DomainMatcherGroup) + g := NewDomainMatcherGroup() r := g.Match("example.com") if len(r) != 0 { t.Error("Expect [], but ", r) diff --git a/common/geodata/strmatcher/matchergroup_full.go b/common/geodata/strmatcher/matchergroup_full.go new file mode 100644 index 00000000..85057b36 --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_full.go @@ -0,0 +1,30 @@ +package strmatcher + +// FullMatcherGroup is an implementation of MatcherGroup. +// It uses a hash table to facilitate exact match lookup. +type FullMatcherGroup struct { + matchers map[string][]uint32 +} + +func NewFullMatcherGroup() *FullMatcherGroup { + return &FullMatcherGroup{ + matchers: make(map[string][]uint32), + } +} + +// AddFullMatcher implements MatcherGroupForFull.AddFullMatcher. +func (g *FullMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { + domain := matcher.Pattern() + g.matchers[domain] = append(g.matchers[domain], value) +} + +// Match implements MatcherGroup.Match. +func (g *FullMatcherGroup) Match(input string) []uint32 { + return g.matchers[input] +} + +// MatchAny implements MatcherGroup.Any. +func (g *FullMatcherGroup) MatchAny(input string) bool { + _, found := g.matchers[input] + return found +} diff --git a/common/strmatcher/full_matcher_test.go b/common/geodata/strmatcher/matchergroup_full_test.go similarity index 56% rename from common/strmatcher/full_matcher_test.go rename to common/geodata/strmatcher/matchergroup_full_test.go index 73d60d51..8e754921 100644 --- a/common/strmatcher/full_matcher_test.go +++ b/common/geodata/strmatcher/matchergroup_full_test.go @@ -4,17 +4,35 @@ import ( "reflect" "testing" - . "github.com/xtls/xray-core/common/strmatcher" + . "github.com/xtls/xray-core/common/geodata/strmatcher" ) func TestFullMatcherGroup(t *testing.T) { - g := new(FullMatcherGroup) - g.Add("example.com", 1) - g.Add("google.com", 2) - g.Add("x.a.com", 3) - g.Add("x.y.com", 4) - g.Add("x.y.com", 6) - + patterns := []struct { + Pattern string + Value uint32 + }{ + { + Pattern: "example.com", + Value: 1, + }, + { + Pattern: "google.com", + Value: 2, + }, + { + Pattern: "x.a.com", + Value: 3, + }, + { + Pattern: "x.y.com", + Value: 4, + }, + { + Pattern: "x.y.com", + Value: 6, + }, + } testCases := []struct { Domain string Result []uint32 @@ -32,7 +50,10 @@ func TestFullMatcherGroup(t *testing.T) { Result: []uint32{4, 6}, }, } - + g := NewFullMatcherGroup() + for _, pattern := range patterns { + AddMatcherToGroup(g, FullMatcher(pattern.Pattern), pattern.Value) + } for _, testCase := range testCases { r := g.Match(testCase.Domain) if !reflect.DeepEqual(r, testCase.Result) { @@ -42,7 +63,7 @@ func TestFullMatcherGroup(t *testing.T) { } func TestEmptyFullMatcherGroup(t *testing.T) { - g := new(FullMatcherGroup) + g := NewFullMatcherGroup() r := g.Match("example.com") if len(r) != 0 { t.Error("Expect [], but ", r) diff --git a/common/geodata/strmatcher/matchergroup_mph.go b/common/geodata/strmatcher/matchergroup_mph.go new file mode 100644 index 00000000..ebf5ff7f --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_mph.go @@ -0,0 +1,198 @@ +package strmatcher + +import ( + "math/bits" + "runtime" + "sort" + "strings" + "unsafe" +) + +// PrimeRK is the prime base used in Rabin-Karp algorithm. +const PrimeRK = 16777619 + +// RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash. +func RollingHash(hash uint32, input string) uint32 { + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) + } + return hash +} + +// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves +// as aeshash if aes instruction is available). +// With different seed, each MemHash performs as distinct hash functions. +func MemHash(seed uint32, input string) uint32 { + return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep +} + +const ( + mphMatchTypeCount = 2 // Full and Domain +) + +type mphRuleInfo struct { + rollingHash uint32 + matchers [mphMatchTypeCount][]uint32 +} + +// MphMatcherGroup is an implementation of MatcherGroup. +// It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher. +type MphMatcherGroup struct { + rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup + values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence) + level0 []uint32 // RollingHash & Mask -> seed for Memhash + level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0) + level1 []uint32 // Memhash & Mask -> stored index for rules + level1Mask uint32 // Mask for restricting Memhash to 0 ~ len(level1) + ruleInfos *map[string]mphRuleInfo +} + +func NewMphMatcherGroup() *MphMatcherGroup { + return &MphMatcherGroup{ + rules: []string{""}, + values: [][]uint32{nil}, + level0: nil, + level0Mask: 0, + level1: nil, + level1Mask: 0, + ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete + } +} + +// AddFullMatcher implements MatcherGroupForFull. +func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { + pattern := strings.ToLower(matcher.Pattern()) + g.addPattern(0, "", pattern, matcher.Type(), value) +} + +// AddDomainMatcher implements MatcherGroupForDomain. +func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { + pattern := strings.ToLower(matcher.Pattern()) + hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match + g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match +} + +func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 { + fullPattern := pattern + suffixPattern + info, found := (*g.ruleInfos)[fullPattern] + if !found { + info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)} + g.rules = append(g.rules, fullPattern) + g.values = append(g.values, nil) + } + info.matchers[matcherType] = append(info.matchers[matcherType], value) + (*g.ruleInfos)[fullPattern] = info + return info.rollingHash +} + +// Build builds a minimal perfect hash table for insert rules. +// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf +func (g *MphMatcherGroup) Build() error { + ruleCount := len(*g.ruleInfos) + g.level0 = make([]uint32, nextPow2(ruleCount/4)) + g.level0Mask = uint32(len(g.level0) - 1) + g.level1 = make([]uint32, nextPow2(ruleCount)) + g.level1Mask = uint32(len(g.level1) - 1) + + // Create buckets based on all rule's rolling hash + buckets := make([][]uint32, len(g.level0)) + for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup) + ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]] + bucketIdx := ruleInfo.rollingHash & g.level0Mask + buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx)) + g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic + } + g.ruleInfos = nil // Set ruleInfos nil to release memory + runtime.GC() // peak mem + + // Sort buckets in descending order with respect to each bucket's size + bucketIdxs := make([]int, len(buckets)) + for bucketIdx := range buckets { + bucketIdxs[bucketIdx] = bucketIdx + } + sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) }) + + // Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table + occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used + hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket + for _, bucketIdx := range bucketIdxs { + bucket := buckets[bucketIdx] + hashedBucket = hashedBucket[:0] + seed := uint32(0) + for len(hashedBucket) != len(bucket) { + for _, ruleIdx := range bucket { + memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask + if occupied[memHash] { // Collision occurred with this seed + for _, hash := range hashedBucket { // Revert all values in this hashed bucket + occupied[hash] = false + g.level1[hash] = 0 + } + hashedBucket = hashedBucket[:0] + seed++ // Try next seed + break + } + occupied[memHash] = true + g.level1[memHash] = ruleIdx // The final value in the hash table + hashedBucket = append(hashedBucket, memHash) + } + } + g.level0[bucketIdx] = seed // Displacement value for this bucket + } + return nil +} + +// Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found. +func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 { + i0 := rollingHash & g.level0Mask + seed := g.level0[i0] + i1 := MemHash(seed, input) & g.level1Mask + if n := g.level1[i1]; g.rules[n] == input { + return n + } + return 0 +} + +// Match implements MatcherGroup.Match. +func (g *MphMatcherGroup) Match(input string) []uint32 { + matches := make([][]uint32, 0, 5) + hash := uint32(0) + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) + if input[i] == '.' { + if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 { + matches = append(matches, g.values[mphIdx]) + } + } + } + if mphIdx := g.Lookup(hash, input); mphIdx != 0 { + matches = append(matches, g.values[mphIdx]) + } + return CompositeMatchesReverse(matches) +} + +// MatchAny implements MatcherGroup.MatchAny. +func (g *MphMatcherGroup) MatchAny(input string) bool { + hash := uint32(0) + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) + if input[i] == '.' { + if g.Lookup(hash, input[i:]) != 0 { + return true + } + } + } + return g.Lookup(hash, input) != 0 +} + +func nextPow2(v int) int { + if v <= 1 { + return 1 + } + const MaxUInt = ^uint(0) + n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 + return int(n) +} + +//go:noescape +//go:linkname strhash runtime.strhash +func strhash(p unsafe.Pointer, h uintptr) uintptr diff --git a/common/strmatcher/strmatcher_test.go b/common/geodata/strmatcher/matchergroup_mph_test.go similarity index 56% rename from common/strmatcher/strmatcher_test.go rename to common/geodata/strmatcher/matchergroup_mph_test.go index 408ae628..f710c5b9 100644 --- a/common/strmatcher/strmatcher_test.go +++ b/common/geodata/strmatcher/matchergroup_mph_test.go @@ -5,94 +5,10 @@ import ( "testing" "github.com/xtls/xray-core/common" - . "github.com/xtls/xray-core/common/strmatcher" + . "github.com/xtls/xray-core/common/geodata/strmatcher" ) -func TestMatcherGroup(t *testing.T) { - rules := []struct { - Type Type - Domain string - }{ - { - Type: Regex, - Domain: "apis\\.us$", - }, - { - Type: Substr, - Domain: "apis", - }, - { - Type: Domain, - Domain: "googleapis.com", - }, - { - Type: Domain, - Domain: "com", - }, - { - Type: Full, - Domain: "www.baidu.com", - }, - { - Type: Substr, - Domain: "apis", - }, - { - Type: Domain, - Domain: "googleapis.com", - }, - { - Type: Full, - Domain: "fonts.googleapis.com", - }, - { - Type: Full, - Domain: "www.baidu.com", - }, - { - Type: Domain, - Domain: "example.com", - }, - } - cases := []struct { - Input string - Output []uint32 - }{ - { - Input: "www.baidu.com", - Output: []uint32{5, 9, 4}, - }, - { - Input: "fonts.googleapis.com", - Output: []uint32{8, 3, 7, 4, 2, 6}, - }, - { - Input: "example.googleapis.com", - Output: []uint32{3, 7, 4, 2, 6}, - }, - { - Input: "testapis.us", - Output: []uint32{1, 2, 6}, - }, - { - Input: "example.com", - Output: []uint32{10, 4}, - }, - } - matcherGroup := &MatcherGroup{} - for _, rule := range rules { - matcher, err := rule.Type.New(rule.Domain) - common.Must(err) - matcherGroup.Add(matcher) - } - for _, test := range cases { - if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { - t.Error("unexpected output: ", m, " for test case ", test) - } - } -} - -func TestACAutomaton(t *testing.T) { +func TestMphMatcherGroup(t *testing.T) { cases1 := []struct { pattern string mType Type @@ -100,53 +16,55 @@ func TestACAutomaton(t *testing.T) { output bool }{ { - pattern: "xtls.github.io", + pattern: "example.com", mType: Domain, - input: "www.xtls.github.io", + input: "www.example.com", output: true, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Domain, - input: "xtls.github.io", + input: "example.com", output: true, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Domain, - input: "www.xtis.github.io", + input: "www.e3ample.com", output: false, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Domain, - input: "tls.github.io", + input: "xample.com", output: false, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Domain, - input: "xxtls.github.io", + input: "xexample.com", output: false, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Full, - input: "xtls.github.io", + input: "example.com", output: true, }, { - pattern: "xtls.github.io", + pattern: "example.com", mType: Full, - input: "xxtls.github.io", + input: "xexample.com", output: false, }, } for _, test := range cases1 { - ac := NewACAutomaton() - ac.Add(test.pattern, test.mType) - ac.Build() - if m := ac.Match(test.input); m != test.output { + mph := NewMphMatcherGroup() + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(mph, matcher, 0)) + mph.Build() + if m := mph.MatchAny(test.input); m != test.output { t.Error("unexpected output: ", m, " for test case ", test) } } @@ -167,20 +85,14 @@ func TestACAutomaton(t *testing.T) { pattern: "3.com", mType: Full, }, - { - pattern: "google.com", - mType: Substr, - }, - { - pattern: "vgoogle.com", - mType: Substr, - }, } - ac := NewACAutomaton() + mph := NewMphMatcherGroup() for _, test := range cases2Input { - ac.Add(test.pattern, test.mType) + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(mph, matcher, 0)) } - ac.Build() + mph.Build() cases2Output := []struct { pattern string res bool @@ -215,15 +127,11 @@ func TestACAutomaton(t *testing.T) { }, { pattern: "vvgoogle.com", - res: true, - }, - { - pattern: "½", res: false, }, } for _, test := range cases2Output { - if m := ac.Match(test.pattern); m != test.res { + if m := mph.MatchAny(test.pattern); m != test.res { t.Error("unexpected output: ", m, " for test case ", test) } } @@ -242,11 +150,13 @@ func TestACAutomaton(t *testing.T) { mType: Domain, }, } - ac := NewACAutomaton() + mph := NewMphMatcherGroup() for _, test := range cases3Input { - ac.Add(test.pattern, test.mType) + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(mph, matcher, 0)) } - ac.Build() + mph.Build() cases3Output := []struct { pattern string res bool @@ -257,9 +167,112 @@ func TestACAutomaton(t *testing.T) { }, } for _, test := range cases3Output { - if m := ac.Match(test.pattern); m != test.res { + if m := mph.MatchAny(test.pattern); m != test.res { t.Error("unexpected output: ", m, " for test case ", test) } } } } + +// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489 +func TestMphMatcherGroupAsIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + // Regex not supported by MphMatcherGroup + // { + // Type: Regex, + // Domain: "apis\\.us$", + // }, + // Substr not supported by MphMatcherGroup + // { + // Type: Substr, + // Domain: "apis", + // }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + // Substr not supported by MphMatcherGroup, We add another matcher to preserve index + { + Type: Domain, // Substr, + Domain: "example.com", // "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority. + Type: Full, + Domain: "example.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4 /*2, 6*/}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4 /*2, 6*/}, + }, + { + Input: "testapis.us", + // Output: []uint32{ /*2, 6*/ /*1,*/ }, + Output: nil, + }, + { + Input: "example.com", + Output: []uint32{10, 6, 11, 4}, + }, + } + matcherGroup := NewMphMatcherGroup() + for i, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3))) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} + +func TestEmptyMphMatcherGroup(t *testing.T) { + g := NewMphMatcherGroup() + g.Build() + r := g.Match("example.com") + if len(r) != 0 { + t.Error("Expect [], but ", r) + } +} diff --git a/common/geodata/strmatcher/matchergroup_simple.go b/common/geodata/strmatcher/matchergroup_simple.go new file mode 100644 index 00000000..fa6f0eb2 --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_simple.go @@ -0,0 +1,41 @@ +package strmatcher + +type matcherEntry struct { + matcher Matcher + value uint32 +} + +// SimpleMatcherGroup is an implementation of MatcherGroup. +// It simply stores all matchers in an array and sequentially matches them. +type SimpleMatcherGroup struct { + matchers []matcherEntry +} + +// AddMatcher implements MatcherGroupForAll.AddMatcher. +func (g *SimpleMatcherGroup) AddMatcher(matcher Matcher, value uint32) { + g.matchers = append(g.matchers, matcherEntry{ + matcher: matcher, + value: value, + }) +} + +// Match implements MatcherGroup.Match. +func (g *SimpleMatcherGroup) Match(input string) []uint32 { + result := []uint32{} + for _, e := range g.matchers { + if e.matcher.Match(input) { + result = append(result, e.value) + } + } + return result +} + +// MatchAny implements MatcherGroup.MatchAny. +func (g *SimpleMatcherGroup) MatchAny(input string) bool { + for _, e := range g.matchers { + if e.matcher.Match(input) { + return true + } + } + return false +} diff --git a/common/geodata/strmatcher/matchergroup_simple_test.go b/common/geodata/strmatcher/matchergroup_simple_test.go new file mode 100644 index 00000000..dfeea99a --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_simple_test.go @@ -0,0 +1,69 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestSimpleMatcherGroup(t *testing.T) { + patterns := []struct { + pattern string + mType Type + }{ + { + pattern: "example.com", + mType: Domain, + }, + { + pattern: "example.com", + mType: Full, + }, + { + pattern: "example.com", + mType: Regex, + }, + } + cases := []struct { + input string + output []uint32 + }{ + { + input: "www.example.com", + output: []uint32{0, 2}, + }, + { + input: "example.com", + output: []uint32{0, 1, 2}, + }, + { + input: "www.e3ample.com", + output: []uint32{}, + }, + { + input: "xample.com", + output: []uint32{}, + }, + { + input: "xexample.com", + output: []uint32{2}, + }, + { + input: "examplexcom", + output: []uint32{2}, + }, + } + matcherGroup := &SimpleMatcherGroup{} + for id, entry := range patterns { + matcher, err := entry.mType.New(entry.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id))) + } + for _, test := range cases { + if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) { + t.Error("unexpected output: ", r, " for test case ", test) + } + } +} diff --git a/common/geodata/strmatcher/matchergroup_substr.go b/common/geodata/strmatcher/matchergroup_substr.go new file mode 100644 index 00000000..ccaa0c9f --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_substr.go @@ -0,0 +1,61 @@ +package strmatcher + +import ( + "sort" + "strings" +) + +// SubstrMatcherGroup is implementation of MatcherGroup, +// It is simply implmeneted to comply with the priority specification of Substr matchers. +type SubstrMatcherGroup struct { + patterns []string + values []uint32 +} + +// AddSubstrMatcher implements MatcherGroupForSubstr.AddSubstrMatcher. +func (g *SubstrMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, value uint32) { + g.patterns = append(g.patterns, matcher.Pattern()) + g.values = append(g.values, value) +} + +// Match implements MatcherGroup.Match. +func (g *SubstrMatcherGroup) Match(input string) []uint32 { + var result []uint32 + for i, pattern := range g.patterns { + for j := strings.LastIndex(input, pattern); j != -1; j = strings.LastIndex(input[:j], pattern) { + result = append(result, uint32(j)<<16|uint32(i)&0xffff) // uint32: position (higher 16 bit) | patternIdx (lower 16 bit) + } + } + // sort.Slice will trigger allocation no matter what input is. See https://github.com/golang/go/issues/17332 + // We optimize the sorting by length to prevent memory allocation as possible. + switch len(result) { + case 0: + return nil + case 1: + // No need to sort + case 2: + // Do a simple swap if unsorted + if result[0] > result[1] { + result[0], result[1] = result[1], result[0] + } + default: + // Sort the match results in dictionary order, so that: + // 1. Pattern matched at smaller position (meaning matched further) takes precedence. + // 2. When patterns matched at same position, pattern with smaller index (meaning inserted early) takes precedence. + sort.Slice(result, func(i, j int) bool { return result[i] < result[j] }) + } + for i, entry := range result { + result[i] = g.values[entry&0xffff] // Get pattern value from its index (the lower 16 bit) + } + return result +} + +// MatchAny implements MatcherGroup.MatchAny. +func (g *SubstrMatcherGroup) MatchAny(input string) bool { + for _, pattern := range g.patterns { + if strings.Contains(input, pattern) { + return true + } + } + return false +} diff --git a/common/geodata/strmatcher/matchergroup_substr_test.go b/common/geodata/strmatcher/matchergroup_substr_test.go new file mode 100644 index 00000000..bfcce682 --- /dev/null +++ b/common/geodata/strmatcher/matchergroup_substr_test.go @@ -0,0 +1,65 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestSubstrMatcherGroup(t *testing.T) { + patterns := []struct { + pattern string + mType Type + }{ + { + pattern: "apis", + mType: Substr, + }, + { + pattern: "google", + mType: Substr, + }, + { + pattern: "apis", + mType: Substr, + }, + } + cases := []struct { + input string + output []uint32 + }{ + { + input: "google.com", + output: []uint32{1}, + }, + { + input: "apis.com", + output: []uint32{0, 2}, + }, + { + input: "googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "fonts.googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "apis.googleapis.com", + output: []uint32{0, 2, 1, 0, 2}, + }, + } + matcherGroup := &SubstrMatcherGroup{} + for id, entry := range patterns { + matcher, err := entry.mType.New(entry.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id))) + } + for _, test := range cases { + if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) { + t.Error("unexpected output: ", r, " for test case ", test) + } + } +} diff --git a/common/geodata/strmatcher/matchers.go b/common/geodata/strmatcher/matchers.go new file mode 100644 index 00000000..7e073764 --- /dev/null +++ b/common/geodata/strmatcher/matchers.go @@ -0,0 +1,290 @@ +package strmatcher + +import ( + "errors" + "regexp" + "strings" + "unicode/utf8" + + "golang.org/x/net/idna" +) + +// FullMatcher is an implementation of Matcher. +type FullMatcher string + +func (FullMatcher) Type() Type { + return Full +} + +func (m FullMatcher) Pattern() string { + return string(m) +} + +func (m FullMatcher) String() string { + return "full:" + m.Pattern() +} + +func (m FullMatcher) Match(s string) bool { + return string(m) == s +} + +// DomainMatcher is an implementation of Matcher. +type DomainMatcher string + +func (DomainMatcher) Type() Type { + return Domain +} + +func (m DomainMatcher) Pattern() string { + return string(m) +} + +func (m DomainMatcher) String() string { + return "domain:" + m.Pattern() +} + +func (m DomainMatcher) Match(s string) bool { + pattern := m.Pattern() + if !strings.HasSuffix(s, pattern) { + return false + } + return len(s) == len(pattern) || s[len(s)-len(pattern)-1] == '.' +} + +// SubstrMatcher is an implementation of Matcher. +type SubstrMatcher string + +func (SubstrMatcher) Type() Type { + return Substr +} + +func (m SubstrMatcher) Pattern() string { + return string(m) +} + +func (m SubstrMatcher) String() string { + return "keyword:" + m.Pattern() +} + +func (m SubstrMatcher) Match(s string) bool { + return strings.Contains(s, m.Pattern()) +} + +// RegexMatcher is an implementation of Matcher. +type RegexMatcher struct { + pattern *regexp.Regexp +} + +func (*RegexMatcher) Type() Type { + return Regex +} + +func (m *RegexMatcher) Pattern() string { + return m.pattern.String() +} + +func (m *RegexMatcher) String() string { + return "regexp:" + m.Pattern() +} + +func (m *RegexMatcher) Match(s string) bool { + return m.pattern.MatchString(s) +} + +// New creates a new Matcher based on the given pattern. +func (t Type) New(pattern string) (Matcher, error) { + switch t { + case Full: + return FullMatcher(pattern), nil + case Substr: + return SubstrMatcher(pattern), nil + case Domain: + pattern, err := ToDomain(pattern) + if err != nil { + return nil, err + } + return DomainMatcher(pattern), nil + case Regex: // 1. regex matching is case-sensitive + regex, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + return &RegexMatcher{pattern: regex}, nil + default: + return nil, errors.New("unknown matcher type") + } +} + +// NewDomainPattern creates a new Matcher based on the given domain pattern. +// It works like `Type.New`, but will do validation and conversion to ensure it's a valid domain pattern. +func (t Type) NewDomainPattern(pattern string) (Matcher, error) { + switch t { + case Full: + pattern, err := ToDomain(pattern) + if err != nil { + return nil, err + } + return FullMatcher(pattern), nil + case Substr: + pattern, err := ToDomain(pattern) + if err != nil { + return nil, err + } + return SubstrMatcher(pattern), nil + case Domain: + pattern, err := ToDomain(pattern) + if err != nil { + return nil, err + } + return DomainMatcher(pattern), nil + case Regex: // Regex's charset not in LDH subset + regex, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + return &RegexMatcher{pattern: regex}, nil + default: + return nil, errors.New("unknown matcher type") + } +} + +// ToDomain converts input pattern to a domain string, and return error if such a conversion cannot be made. +// 1. Conforms to Letter-Digit-Hyphen (LDH) subset (https://tools.ietf.org/html/rfc952): +// * Letters A to Z (no distinction between uppercase and lowercase, we convert to lowers) +// * Digits 0 to 9 +// * Hyphens(-) and Periods(.) +// 2. If any non-ASCII characters, domain are converted from Internationalized domain name to Punycode. +func ToDomain(pattern string) (string, error) { + for { + isASCII, hasUpper := true, false + for i := 0; i < len(pattern); i++ { + c := pattern[i] + if c >= utf8.RuneSelf { + isASCII = false + break + } + switch { + case 'A' <= c && c <= 'Z': + hasUpper = true + case 'a' <= c && c <= 'z': + case '0' <= c && c <= '9': + case c == '-': + case c == '.': + default: + return "", errors.New("pattern string does not conform to Letter-Digit-Hyphen (LDH) subset") + } + } + if !isASCII { + var err error + pattern, err = idna.Punycode.ToASCII(pattern) + if err != nil { + return "", err + } + continue + } + if hasUpper { + pattern = strings.ToLower(pattern) + } + break + } + return pattern, nil +} + +// MatcherGroupForAll is an interface indicating a MatcherGroup could accept all types of matchers. +type MatcherGroupForAll interface { + AddMatcher(matcher Matcher, value uint32) +} + +// MatcherGroupForFull is an interface indicating a MatcherGroup could accept FullMatchers. +type MatcherGroupForFull interface { + AddFullMatcher(matcher FullMatcher, value uint32) +} + +// MatcherGroupForDomain is an interface indicating a MatcherGroup could accept DomainMatchers. +type MatcherGroupForDomain interface { + AddDomainMatcher(matcher DomainMatcher, value uint32) +} + +// MatcherGroupForSubstr is an interface indicating a MatcherGroup could accept SubstrMatchers. +type MatcherGroupForSubstr interface { + AddSubstrMatcher(matcher SubstrMatcher, value uint32) +} + +// MatcherGroupForRegex is an interface indicating a MatcherGroup could accept RegexMatchers. +type MatcherGroupForRegex interface { + AddRegexMatcher(matcher *RegexMatcher, value uint32) +} + +// AddMatcherToGroup is a helper function to try to add a Matcher to any kind of MatcherGroup. +// It returns error if the MatcherGroup does not accept the provided Matcher's type. +// This function is provided to help writing code to test a MatcherGroup. +func AddMatcherToGroup(g MatcherGroup, matcher Matcher, value uint32) error { + if g, ok := g.(IndexMatcher); ok { + g.Add(matcher) + return nil + } + if g, ok := g.(MatcherGroupForAll); ok { + g.AddMatcher(matcher, value) + return nil + } + switch matcher := matcher.(type) { + case FullMatcher: + if g, ok := g.(MatcherGroupForFull); ok { + g.AddFullMatcher(matcher, value) + return nil + } + case DomainMatcher: + if g, ok := g.(MatcherGroupForDomain); ok { + g.AddDomainMatcher(matcher, value) + return nil + } + case SubstrMatcher: + if g, ok := g.(MatcherGroupForSubstr); ok { + g.AddSubstrMatcher(matcher, value) + return nil + } + case *RegexMatcher: + if g, ok := g.(MatcherGroupForRegex); ok { + g.AddRegexMatcher(matcher, value) + return nil + } + } + return errors.New("cannot add matcher to matcher group") +} + +// CompositeMatches flattens the matches slice to produce a single matched indices slice. +// It is designed to avoid new memory allocation as possible. +func CompositeMatches(matches [][]uint32) []uint32 { + switch len(matches) { + case 0: + return nil + case 1: + return matches[0] + default: + result := make([]uint32, 0, 5) + for i := 0; i < len(matches); i++ { + result = append(result, matches[i]...) + } + return result + } +} + +// CompositeMatches flattens the matches slice to produce a single matched indices slice. +// It is designed that: +// 1. All matchers are concatenated in reverse order, so the matcher that matches further ranks higher. +// 2. Indices in the same matcher keeps their original order. +// 3. Avoid new memory allocation as possible. +func CompositeMatchesReverse(matches [][]uint32) []uint32 { + switch len(matches) { + case 0: + return nil + case 1: + return matches[0] + default: + result := make([]uint32, 0, 5) + for i := len(matches) - 1; i >= 0; i-- { + result = append(result, matches[i]...) + } + return result + } +} diff --git a/common/geodata/strmatcher/matchers_test.go b/common/geodata/strmatcher/matchers_test.go new file mode 100644 index 00000000..5f3f05b3 --- /dev/null +++ b/common/geodata/strmatcher/matchers_test.go @@ -0,0 +1,149 @@ +package strmatcher_test + +import ( + "reflect" + "testing" + "unsafe" + + "github.com/xtls/xray-core/common" + . "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestMatcher(t *testing.T) { + cases := []struct { + pattern string + mType Type + input string + output bool + }{ + { + pattern: "example.com", + mType: Domain, + input: "www.example.com", + output: true, + }, + { + pattern: "example.com", + mType: Domain, + input: "example.com", + output: true, + }, + { + pattern: "example.com", + mType: Domain, + input: "www.e3ample.com", + output: false, + }, + { + pattern: "example.com", + mType: Domain, + input: "xample.com", + output: false, + }, + { + pattern: "example.com", + mType: Domain, + input: "xexample.com", + output: false, + }, + { + pattern: "example.com", + mType: Full, + input: "example.com", + output: true, + }, + { + pattern: "example.com", + mType: Full, + input: "xexample.com", + output: false, + }, + { + pattern: "example.com", + mType: Regex, + input: "examplexcom", + output: true, + }, + } + for _, test := range cases { + matcher, err := test.mType.New(test.pattern) + common.Must(err) + if m := matcher.Match(test.input); m != test.output { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} + +func TestToDomain(t *testing.T) { + { // Test normal ASCII domain, which should not trigger new string data allocation + input := "example.com" + domain, err := ToDomain(input) + if err != nil { + t.Error("unexpected error: ", err) + } + if domain != input { + t.Error("unexpected output: ", domain, " for test case ", input) + } + if (*reflect.StringHeader)(unsafe.Pointer(&input)).Data != (*reflect.StringHeader)(unsafe.Pointer(&domain)).Data { + t.Error("different string data of output: ", domain, " and test case ", input) + } + } + { // Test ASCII domain containing upper case letter, which should be converted to lower case + input := "eXAMPLE.cOm" + domain, err := ToDomain(input) + if err != nil { + t.Error("unexpected error: ", err) + } + if domain != "example.com" { + t.Error("unexpected output: ", domain, " for test case ", input) + } + } + { // Test internationalized domain, which should be translated to ASCII punycode + input := "example.公益" + domain, err := ToDomain(input) + if err != nil { + t.Error("unexpected error: ", err) + } + if domain != "example.xn--55qw42g" { + t.Error("unexpected output: ", domain, " for test case ", input) + } + } + { // Test internationalized domain containing upper case letter + input := "eXAMPLE.公益" + domain, err := ToDomain(input) + if err != nil { + t.Error("unexpected error: ", err) + } + if domain != "example.xn--55qw42g" { + t.Error("unexpected output: ", domain, " for test case ", input) + } + } + { // Test domain name of invalid character, which should return with error + input := "{" + _, err := ToDomain(input) + if err == nil { + t.Error("unexpected non error for test case ", input) + } + } + { // Test domain name containing a space, which should return with error + input := "Mijia Cloud" + _, err := ToDomain(input) + if err == nil { + t.Error("unexpected non error for test case ", input) + } + } + { // Test domain name containing an underscore, which should return with error + input := "Mijia_Cloud.com" + _, err := ToDomain(input) + if err == nil { + t.Error("unexpected non error for test case ", input) + } + } + { // Test internationalized domain containing invalid character + input := "Mijia Cloud.公司" + _, err := ToDomain(input) + if err == nil { + t.Error("unexpected non error for test case ", input) + } + } +} diff --git a/common/geodata/strmatcher/strmatcher.go b/common/geodata/strmatcher/strmatcher.go new file mode 100644 index 00000000..e4187f63 --- /dev/null +++ b/common/geodata/strmatcher/strmatcher.go @@ -0,0 +1,101 @@ +package strmatcher + +// Type is the type of the matcher. +type Type byte + +const ( + // Full is the type of matcher that the input string must exactly equal to the pattern. + Full Type = 0 + // Domain is the type of matcher that the input string must be a sub-domain or itself of the pattern. + Domain Type = 1 + // Substr is the type of matcher that the input string must contain the pattern as a sub-string. + Substr Type = 2 + // Regex is the type of matcher that the input string must matches the regular-expression pattern. + Regex Type = 3 +) + +// Matcher is the interface to determine a string matches a pattern. +// - This is a basic matcher to represent a certain kind of match semantic(full, substr, domain or regex). +type Matcher interface { + // Type returns the matcher's type. + Type() Type + + // Pattern returns the matcher's raw string representation. + Pattern() string + + // String returns a string representation of the matcher containing its type and pattern. + String() string + + // Match returns true if the given string matches a predefined pattern. + // * This method is seldom used for performance reason + // and is generally taken over by their corresponding MatcherGroup. + Match(input string) bool +} + +// MatcherGroup is an advanced type of matcher to accept a bunch of basic Matchers (of certain type, not all matcher types). +// For example: +// - FullMatcherGroup accepts FullMatcher and uses a hash table to facilitate lookup. +// - DomainMatcherGroup accepts DomainMatcher and uses a trie to optimize both memory consumption and lookup speed. +type MatcherGroup interface { + // Match returns all matched matchers with their corresponding values. + Match(input string) []uint32 + + // MatchAny returns true as soon as one matching matcher is found. + MatchAny(input string) bool +} + +// IndexMatcher is a general type of matcher thats accepts all kinds of basic matchers. +// It should: +// - Accept all Matcher types with no exception. +// - Optimize string matching with a combination of MatcherGroups. +// - Obey certain priority order specification when returning matched Matchers. +type IndexMatcher interface { + // Size returns number of matchers added to IndexMatcher. + Size() uint32 + + // Add adds a new Matcher to IndexMatcher, and returns its index. The index will never be 0. + Add(matcher Matcher) uint32 + + // Build builds the IndexMatcher to be ready for matching. + Build() error + + // Match returns the indices of all matchers that matches the input. + // * Empty array is returned if no such matcher exists. + // * The order of returned matchers should follow priority specification. + // Priority specification: + // 1. Priority between matcher types: full > domain > substr > regex. + // 2. Priority of same-priority matchers matching at same position: the early added takes precedence. + // 3. Priority of domain matchers matching at different levels: the further matched domain takes precedence. + // 4. Priority of substr matchers matching at different positions: the further matched substr takes precedence. + Match(input string) []uint32 + + // MatchAny returns true as soon as one matching matcher is found. + MatchAny(input string) bool +} + +// ValueMatcher is a general type of matcher that accepts all kinds of basic matchers. +// It should: +// - Accept all Matcher types with no exception. +// - Optimize string matching with a combination of MatcherGroups. +// - Obey certain priority order specification when returning matched values. +type ValueMatcher interface { + // Add adds a new Matcher to ValueMatcher, binding it to the provided value. + Add(matcher Matcher, value uint32) + + // Build builds the ValueMatcher to be ready for matching. + Build() error + + // Match returns the values of all matchers that matches the input. + // * Empty array is returned if no such matcher exists. + // * The order of returned values should follow priority specification. + // * Same value may appear multiple times if multiple matched matchers were added with that value. + // Priority specification: + // 1. Priority between matcher types: full > domain > substr > regex. + // 2. Priority of same-priority matchers matching at same position: the early added takes precedence. + // 3. Priority of domain matchers matching at different levels: the further matched domain takes precedence. + // 4. Priority of substr matchers matching at different positions: the further matched substr takes precedence. + Match(input string) []uint32 + + // MatchAny returns true as soon as one matching matcher is found. + MatchAny(input string) bool +} diff --git a/common/geodata/strmatcher/valuematcher_linear.go b/common/geodata/strmatcher/valuematcher_linear.go new file mode 100644 index 00000000..e4d9e974 --- /dev/null +++ b/common/geodata/strmatcher/valuematcher_linear.go @@ -0,0 +1,85 @@ +package strmatcher + +// LinearValueMatcher is an implementation of ValueMatcher. +type LinearValueMatcher struct { + full *FullMatcherGroup + domain *DomainMatcherGroup + substr *SubstrMatcherGroup + regex *SimpleMatcherGroup +} + +func NewLinearValueMatcher() *LinearValueMatcher { + return new(LinearValueMatcher) +} + +// Add implements ValueMatcher.Add. +func (g *LinearValueMatcher) Add(matcher Matcher, value uint32) { + switch matcher := matcher.(type) { + case FullMatcher: + if g.full == nil { + g.full = NewFullMatcherGroup() + } + g.full.AddFullMatcher(matcher, value) + case DomainMatcher: + if g.domain == nil { + g.domain = NewDomainMatcherGroup() + } + g.domain.AddDomainMatcher(matcher, value) + case SubstrMatcher: + if g.substr == nil { + g.substr = new(SubstrMatcherGroup) + } + g.substr.AddSubstrMatcher(matcher, value) + default: + if g.regex == nil { + g.regex = new(SimpleMatcherGroup) + } + g.regex.AddMatcher(matcher, value) + } +} + +// Build implements ValueMatcher.Build. +func (*LinearValueMatcher) Build() error { + return nil +} + +// Match implements ValueMatcher.Match. +func (g *LinearValueMatcher) Match(input string) []uint32 { + // Allocate capacity to prevent matches escaping to heap + result := make([][]uint32, 0, 5) + if g.full != nil { + if matches := g.full.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.domain != nil { + if matches := g.domain.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.substr != nil { + if matches := g.substr.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.regex != nil { + if matches := g.regex.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + return CompositeMatches(result) +} + +// MatchAny implements ValueMatcher.MatchAny. +func (g *LinearValueMatcher) MatchAny(input string) bool { + if g.full != nil && g.full.MatchAny(input) { + return true + } + if g.domain != nil && g.domain.MatchAny(input) { + return true + } + if g.substr != nil && g.substr.MatchAny(input) { + return true + } + return g.regex != nil && g.regex.MatchAny(input) +} diff --git a/common/geodata/strmatcher/valuematcher_mph.go b/common/geodata/strmatcher/valuematcher_mph.go new file mode 100644 index 00000000..ced6f442 --- /dev/null +++ b/common/geodata/strmatcher/valuematcher_mph.go @@ -0,0 +1,89 @@ +package strmatcher + +import "runtime" + +// A MphValueMatcher is divided into three parts: +// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; +// 2. `substr` patterns are matched by ac automaton; +// 3. `regex` patterns are matched with the regex library. +type MphValueMatcher struct { + mph *MphMatcherGroup + ac *ACAutomatonMatcherGroup + regex *SimpleMatcherGroup +} + +func NewMphValueMatcher() *MphValueMatcher { + return new(MphValueMatcher) +} + +// Add implements ValueMatcher.Add. +func (g *MphValueMatcher) Add(matcher Matcher, value uint32) { + switch matcher := matcher.(type) { + case FullMatcher: + if g.mph == nil { + g.mph = NewMphMatcherGroup() + } + g.mph.AddFullMatcher(matcher, value) + case DomainMatcher: + if g.mph == nil { + g.mph = NewMphMatcherGroup() + } + g.mph.AddDomainMatcher(matcher, value) + case SubstrMatcher: + if g.ac == nil { + g.ac = NewACAutomatonMatcherGroup() + } + g.ac.AddSubstrMatcher(matcher, value) + case *RegexMatcher: + if g.regex == nil { + g.regex = &SimpleMatcherGroup{} + } + g.regex.AddMatcher(matcher, value) + } +} + +// Build implements ValueMatcher.Build. +func (g *MphValueMatcher) Build() error { + if g.mph != nil { + runtime.GC() // peak mem + g.mph.Build() + } + runtime.GC() // peak mem + if g.ac != nil { + g.ac.Build() + runtime.GC() // peak mem + } + return nil +} + +// Match implements ValueMatcher.Match. +func (g *MphValueMatcher) Match(input string) []uint32 { + result := make([][]uint32, 0, 5) + if g.mph != nil { + if matches := g.mph.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.ac != nil { + if matches := g.ac.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + if g.regex != nil { + if matches := g.regex.Match(input); len(matches) > 0 { + result = append(result, matches) + } + } + return CompositeMatches(result) +} + +// MatchAny implements ValueMatcher.MatchAny. +func (g *MphValueMatcher) MatchAny(input string) bool { + if g.mph != nil && g.mph.MatchAny(input) { + return true + } + if g.ac != nil && g.ac.MatchAny(input) { + return true + } + return g.regex != nil && g.regex.MatchAny(input) +} diff --git a/common/platform/platform.go b/common/platform/platform.go index 6446873b..80e62874 100644 --- a/common/platform/platform.go +++ b/common/platform/platform.go @@ -24,8 +24,6 @@ const ( XUDPBaseKey = "xray.xudp.basekey" TunFdKey = "xray.tun.fd" - - MphCachePath = "xray.mph.cache" ) type EnvFlag struct { diff --git a/common/strmatcher/ac_automaton_matcher.go b/common/strmatcher/ac_automaton_matcher.go deleted file mode 100644 index 7844333d..00000000 --- a/common/strmatcher/ac_automaton_matcher.go +++ /dev/null @@ -1,247 +0,0 @@ -package strmatcher - -import ( - "container/list" -) - -const validCharCount = 53 - -type MatchType struct { - Type Type - Exist bool -} - -const ( - TrieEdge bool = true - FailEdge bool = false -) - -type Edge struct { - Type bool - NextNode int -} - -type ACAutomaton struct { - Trie [][validCharCount]Edge - Fail []int - Exists []MatchType - Count int -} - -func newNode() [validCharCount]Edge { - var s [validCharCount]Edge - for i := range s { - s[i] = Edge{ - Type: FailEdge, - NextNode: 0, - } - } - return s -} - -var char2Index = []int{ - 'A': 0, - 'a': 0, - 'B': 1, - 'b': 1, - 'C': 2, - 'c': 2, - 'D': 3, - 'd': 3, - 'E': 4, - 'e': 4, - 'F': 5, - 'f': 5, - 'G': 6, - 'g': 6, - 'H': 7, - 'h': 7, - 'I': 8, - 'i': 8, - 'J': 9, - 'j': 9, - 'K': 10, - 'k': 10, - 'L': 11, - 'l': 11, - 'M': 12, - 'm': 12, - 'N': 13, - 'n': 13, - 'O': 14, - 'o': 14, - 'P': 15, - 'p': 15, - 'Q': 16, - 'q': 16, - 'R': 17, - 'r': 17, - 'S': 18, - 's': 18, - 'T': 19, - 't': 19, - 'U': 20, - 'u': 20, - 'V': 21, - 'v': 21, - 'W': 22, - 'w': 22, - 'X': 23, - 'x': 23, - 'Y': 24, - 'y': 24, - 'Z': 25, - 'z': 25, - '!': 26, - '$': 27, - '&': 28, - '\'': 29, - '(': 30, - ')': 31, - '*': 32, - '+': 33, - ',': 34, - ';': 35, - '=': 36, - ':': 37, - '%': 38, - '-': 39, - '.': 40, - '_': 41, - '~': 42, - '0': 43, - '1': 44, - '2': 45, - '3': 46, - '4': 47, - '5': 48, - '6': 49, - '7': 50, - '8': 51, - '9': 52, -} - -func NewACAutomaton() *ACAutomaton { - ac := new(ACAutomaton) - ac.Trie = append(ac.Trie, newNode()) - ac.Fail = append(ac.Fail, 0) - ac.Exists = append(ac.Exists, MatchType{ - Type: Full, - Exist: false, - }) - return ac -} - -func (ac *ACAutomaton) Add(domain string, t Type) { - node := 0 - for i := len(domain) - 1; i >= 0; i-- { - idx := char2Index[domain[i]] - if ac.Trie[node][idx].NextNode == 0 { - ac.Count++ - if len(ac.Trie) < ac.Count+1 { - ac.Trie = append(ac.Trie, newNode()) - ac.Fail = append(ac.Fail, 0) - ac.Exists = append(ac.Exists, MatchType{ - Type: Full, - Exist: false, - }) - } - ac.Trie[node][idx] = Edge{ - Type: TrieEdge, - NextNode: ac.Count, - } - } - node = ac.Trie[node][idx].NextNode - } - ac.Exists[node] = MatchType{ - Type: t, - Exist: true, - } - switch t { - case Domain: - ac.Exists[node] = MatchType{ - Type: Full, - Exist: true, - } - idx := char2Index['.'] - if ac.Trie[node][idx].NextNode == 0 { - ac.Count++ - if len(ac.Trie) < ac.Count+1 { - ac.Trie = append(ac.Trie, newNode()) - ac.Fail = append(ac.Fail, 0) - ac.Exists = append(ac.Exists, MatchType{ - Type: Full, - Exist: false, - }) - } - ac.Trie[node][idx] = Edge{ - Type: TrieEdge, - NextNode: ac.Count, - } - } - node = ac.Trie[node][idx].NextNode - ac.Exists[node] = MatchType{ - Type: t, - Exist: true, - } - default: - break - } -} - -func (ac *ACAutomaton) Build() { - queue := list.New() - for i := 0; i < validCharCount; i++ { - if ac.Trie[0][i].NextNode != 0 { - queue.PushBack(ac.Trie[0][i]) - } - } - for { - front := queue.Front() - if front == nil { - break - } else { - node := front.Value.(Edge).NextNode - queue.Remove(front) - for i := 0; i < validCharCount; i++ { - if ac.Trie[node][i].NextNode != 0 { - ac.Fail[ac.Trie[node][i].NextNode] = ac.Trie[ac.Fail[node]][i].NextNode - queue.PushBack(ac.Trie[node][i]) - } else { - ac.Trie[node][i] = Edge{ - Type: FailEdge, - NextNode: ac.Trie[ac.Fail[node]][i].NextNode, - } - } - } - } - } -} - -func (ac *ACAutomaton) Match(s string) bool { - node := 0 - fullMatch := true - // 1. the match string is all through trie edge. FULL MATCH or DOMAIN - // 2. the match string is through a fail edge. NOT FULL MATCH - // 2.1 Through a fail edge, but there exists a valid node. SUBSTR - for i := len(s) - 1; i >= 0; i-- { - chr := int(s[i]) - if chr >= len(char2Index) { - return false - } - idx := char2Index[chr] - fullMatch = fullMatch && ac.Trie[node][idx].Type - node = ac.Trie[node][idx].NextNode - switch ac.Exists[node].Type { - case Substr: - return true - case Domain: - if fullMatch { - return true - } - default: - break - } - } - return fullMatch && ac.Exists[node].Exist -} diff --git a/common/strmatcher/benchmark_test.go b/common/strmatcher/benchmark_test.go deleted file mode 100644 index 972570ce..00000000 --- a/common/strmatcher/benchmark_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package strmatcher_test - -import ( - "strconv" - "testing" - - "github.com/xtls/xray-core/common" - . "github.com/xtls/xray-core/common/strmatcher" -) - -func BenchmarkACAutomaton(b *testing.B) { - ac := NewACAutomaton() - for i := 1; i <= 1024; i++ { - ac.Add(strconv.Itoa(i)+".xray.com", Domain) - } - ac.Build() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ac.Match("0.xray.com") - } -} - -func BenchmarkDomainMatcherGroup(b *testing.B) { - g := new(DomainMatcherGroup) - - for i := 1; i <= 1024; i++ { - g.Add(strconv.Itoa(i)+".example.com", uint32(i)) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = g.Match("0.example.com") - } -} - -func BenchmarkFullMatcherGroup(b *testing.B) { - g := new(FullMatcherGroup) - - for i := 1; i <= 1024; i++ { - g.Add(strconv.Itoa(i)+".example.com", uint32(i)) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = g.Match("0.example.com") - } -} - -func BenchmarkMarchGroup(b *testing.B) { - g := new(MatcherGroup) - for i := 1; i <= 1024; i++ { - m, err := Domain.New(strconv.Itoa(i) + ".example.com") - common.Must(err) - g.Add(m) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = g.Match("0.example.com") - } -} diff --git a/common/strmatcher/domain_matcher.go b/common/strmatcher/domain_matcher.go deleted file mode 100644 index ae8e65bc..00000000 --- a/common/strmatcher/domain_matcher.go +++ /dev/null @@ -1,98 +0,0 @@ -package strmatcher - -import "strings" - -func breakDomain(domain string) []string { - return strings.Split(domain, ".") -} - -type node struct { - values []uint32 - sub map[string]*node -} - -// DomainMatcherGroup is a IndexMatcher for a large set of Domain matchers. -// Visible for testing only. -type DomainMatcherGroup struct { - root *node -} - -func (g *DomainMatcherGroup) Add(domain string, value uint32) { - if g.root == nil { - g.root = new(node) - } - - current := g.root - parts := breakDomain(domain) - for i := len(parts) - 1; i >= 0; i-- { - part := parts[i] - if current.sub == nil { - current.sub = make(map[string]*node) - } - next := current.sub[part] - if next == nil { - next = new(node) - current.sub[part] = next - } - current = next - } - - current.values = append(current.values, value) -} - -func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) { - g.Add(string(m), value) -} - -func (g *DomainMatcherGroup) Match(domain string) []uint32 { - if domain == "" { - return nil - } - - current := g.root - if current == nil { - return nil - } - - nextPart := func(idx int) int { - for i := idx - 1; i >= 0; i-- { - if domain[i] == '.' { - return i - } - } - return -1 - } - - matches := [][]uint32{} - idx := len(domain) - for { - if idx == -1 || current.sub == nil { - break - } - - nidx := nextPart(idx) - part := domain[nidx+1 : idx] - next := current.sub[part] - if next == nil { - break - } - current = next - idx = nidx - if len(current.values) > 0 { - matches = append(matches, current.values) - } - } - switch len(matches) { - case 0: - return nil - case 1: - return matches[0] - default: - result := []uint32{} - for idx := range matches { - // Insert reversely, the subdomain that matches further ranks higher - result = append(result, matches[len(matches)-1-idx]...) - } - return result - } -} diff --git a/common/strmatcher/full_matcher.go b/common/strmatcher/full_matcher.go deleted file mode 100644 index e00d02aa..00000000 --- a/common/strmatcher/full_matcher.go +++ /dev/null @@ -1,25 +0,0 @@ -package strmatcher - -type FullMatcherGroup struct { - matchers map[string][]uint32 -} - -func (g *FullMatcherGroup) Add(domain string, value uint32) { - if g.matchers == nil { - g.matchers = make(map[string][]uint32) - } - - g.matchers[domain] = append(g.matchers[domain], value) -} - -func (g *FullMatcherGroup) addMatcher(m fullMatcher, value uint32) { - g.Add(string(m), value) -} - -func (g *FullMatcherGroup) Match(str string) []uint32 { - if g.matchers == nil { - return nil - } - - return g.matchers[str] -} diff --git a/common/strmatcher/matchers.go b/common/strmatcher/matchers.go deleted file mode 100644 index 915927db..00000000 --- a/common/strmatcher/matchers.go +++ /dev/null @@ -1,56 +0,0 @@ -package strmatcher - -import ( - "regexp" - "strings" -) - -type fullMatcher string - -func (m fullMatcher) Match(s string) bool { - return string(m) == s -} - -func (m fullMatcher) String() string { - return "full:" + string(m) -} - -type substrMatcher string - -func (m substrMatcher) Match(s string) bool { - return strings.Contains(s, string(m)) -} - -func (m substrMatcher) String() string { - return "keyword:" + string(m) -} - -type domainMatcher string - -func (m domainMatcher) Match(s string) bool { - pattern := string(m) - if !strings.HasSuffix(s, pattern) { - return false - } - return len(s) == len(pattern) || s[len(s)-len(pattern)-1] == '.' -} - -func (m domainMatcher) String() string { - return "domain:" + string(m) -} - -type RegexMatcher struct { - Pattern string - reg *regexp.Regexp -} - -func (m *RegexMatcher) Match(s string) bool { - if m.reg == nil { - m.reg = regexp.MustCompile(m.Pattern) - } - return m.reg.MatchString(s) -} - -func (m *RegexMatcher) String() string { - return "regexp:" + m.Pattern -} diff --git a/common/strmatcher/matchers_test.go b/common/strmatcher/matchers_test.go deleted file mode 100644 index d39c522c..00000000 --- a/common/strmatcher/matchers_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package strmatcher_test - -import ( - "testing" - - "github.com/xtls/xray-core/common" - . "github.com/xtls/xray-core/common/strmatcher" -) - -func TestMatcher(t *testing.T) { - cases := []struct { - pattern string - mType Type - input string - output bool - }{ - { - pattern: "example.com", - mType: Domain, - input: "www.example.com", - output: true, - }, - { - pattern: "example.com", - mType: Domain, - input: "example.com", - output: true, - }, - { - pattern: "example.com", - mType: Domain, - input: "www.fxample.com", - output: false, - }, - { - pattern: "example.com", - mType: Domain, - input: "xample.com", - output: false, - }, - { - pattern: "example.com", - mType: Domain, - input: "xexample.com", - output: false, - }, - { - pattern: "example.com", - mType: Full, - input: "example.com", - output: true, - }, - { - pattern: "example.com", - mType: Full, - input: "xexample.com", - output: false, - }, - { - pattern: "example.com", - mType: Regex, - input: "examplexcom", - output: true, - }, - } - for _, test := range cases { - matcher, err := test.mType.New(test.pattern) - common.Must(err) - if m := matcher.Match(test.input); m != test.output { - t.Error("unexpected output: ", m, " for test case ", test) - } - } -} diff --git a/common/strmatcher/mph_matcher.go b/common/strmatcher/mph_matcher.go deleted file mode 100644 index ff3dea65..00000000 --- a/common/strmatcher/mph_matcher.go +++ /dev/null @@ -1,308 +0,0 @@ -package strmatcher - -import ( - "math/bits" - "regexp" - "sort" - "strings" - "unsafe" -) - -// PrimeRK is the prime base used in Rabin-Karp algorithm. -const PrimeRK = 16777619 - -// calculate the rolling murmurHash of given string -func RollingHash(s string) uint32 { - h := uint32(0) - for i := len(s) - 1; i >= 0; i-- { - h = h*PrimeRK + uint32(s[i]) - } - return h -} - -// A MphMatcherGroup is divided into three parts: -// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; -// 2. `substr` patterns are matched by ac automaton; -// 3. `regex` patterns are matched with the regex library. -type MphMatcherGroup struct { - Ac *ACAutomaton - OtherMatchers []MatcherEntry - Rules []string - Level0 []uint32 - Level0Mask int - Level1 []uint32 - Level1Mask int - Count uint32 - RuleMap *map[string]uint32 -} - -func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { - h := RollingHash(pattern) - switch t { - case Domain: - (*g.RuleMap)["."+pattern] = h*PrimeRK + uint32('.') - fallthrough - case Full: - (*g.RuleMap)[pattern] = h - default: - } -} - -func NewMphMatcherGroup() *MphMatcherGroup { - return &MphMatcherGroup{ - Ac: nil, - OtherMatchers: nil, - Rules: nil, - Level0: nil, - Level0Mask: 0, - Level1: nil, - Level1Mask: 0, - Count: 1, - RuleMap: &map[string]uint32{}, - } -} - -// AddPattern adds a pattern to MphMatcherGroup -func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { - switch t { - case Substr: - if g.Ac == nil { - g.Ac = NewACAutomaton() - } - g.Ac.Add(pattern, t) - case Full, Domain: - pattern = strings.ToLower(pattern) - g.AddFullOrDomainPattern(pattern, t) - case Regex: - r, err := regexp.Compile(pattern) - if err != nil { - return 0, err - } - g.OtherMatchers = append(g.OtherMatchers, MatcherEntry{ - M: &RegexMatcher{Pattern: pattern, reg: r}, - Id: g.Count, - }) - default: - panic("Unknown type") - } - return g.Count, nil -} - -// Build builds a minimal perfect hash table and ac automaton from insert rules -func (g *MphMatcherGroup) Build() { - if g.Ac != nil { - g.Ac.Build() - } - keyLen := len(*g.RuleMap) - if keyLen == 0 { - keyLen = 1 - (*g.RuleMap)["empty___"] = RollingHash("empty___") - } - g.Level0 = make([]uint32, nextPow2(keyLen/4)) - g.Level0Mask = len(g.Level0) - 1 - g.Level1 = make([]uint32, nextPow2(keyLen)) - g.Level1Mask = len(g.Level1) - 1 - sparseBuckets := make([][]int, len(g.Level0)) - var ruleIdx int - for rule, hash := range *g.RuleMap { - n := int(hash) & g.Level0Mask - g.Rules = append(g.Rules, rule) - sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) - ruleIdx++ - } - g.RuleMap = nil - var buckets []indexBucket - for n, vals := range sparseBuckets { - if len(vals) > 0 { - buckets = append(buckets, indexBucket{n, vals}) - } - } - sort.Sort(bySize(buckets)) - - occ := make([]bool, len(g.Level1)) - var tmpOcc []int - for _, bucket := range buckets { - seed := uint32(0) - for { - findSeed := true - tmpOcc = tmpOcc[:0] - for _, i := range bucket.vals { - n := int(strhashFallback(unsafe.Pointer(&g.Rules[i]), uintptr(seed))) & g.Level1Mask - if occ[n] { - for _, n := range tmpOcc { - occ[n] = false - } - seed++ - findSeed = false - break - } - occ[n] = true - tmpOcc = append(tmpOcc, n) - g.Level1[n] = uint32(i) - } - if findSeed { - g.Level0[bucket.n] = seed - break - } - } - } -} - -func nextPow2(v int) int { - if v <= 1 { - return 1 - } - const MaxUInt = ^uint(0) - n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 - return int(n) -} - -// Lookup searches for s in t and returns its index and whether it was found. -func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { - i0 := int(h) & g.Level0Mask - seed := g.Level0[i0] - i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.Level1Mask - n := g.Level1[i1] - return s == g.Rules[int(n)] -} - -// Match implements IndexMatcher.Match. -func (g *MphMatcherGroup) Match(pattern string) []uint32 { - result := []uint32{} - hash := uint32(0) - for i := len(pattern) - 1; i >= 0; i-- { - hash = hash*PrimeRK + uint32(pattern[i]) - if pattern[i] == '.' { - if g.Lookup(hash, pattern[i:]) { - result = append(result, 1) - return result - } - } - } - if g.Lookup(hash, pattern) { - result = append(result, 1) - return result - } - if g.Ac != nil && g.Ac.Match(pattern) { - result = append(result, 1) - return result - } - for _, e := range g.OtherMatchers { - if e.M.Match(pattern) { - result = append(result, e.Id) - return result - } - } - return nil -} - -type indexBucket struct { - n int - vals []int -} - -type bySize []indexBucket - -func (s bySize) Len() int { return len(s) } -func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) } -func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stringStruct struct { - str unsafe.Pointer - len int -} - -func strhashFallback(a unsafe.Pointer, h uintptr) uintptr { - x := (*stringStruct)(a) - return memhashFallback(x.str, h, uintptr(x.len)) -} - -const ( - // Constants for multiplication: four random odd 64-bit numbers. - m1 = 16877499708836156737 - m2 = 2820277070424839065 - m3 = 9497967016996688599 - m4 = 15839092249703872147 -) - -var hashkey = [4]uintptr{1, 1, 1, 1} - -func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr { - h := uint64(seed + s*hashkey[0]) -tail: - switch { - case s == 0: - case s < 4: - h ^= uint64(*(*byte)(p)) - h ^= uint64(*(*byte)(add(p, s>>1))) << 8 - h ^= uint64(*(*byte)(add(p, s-1))) << 16 - h = rotl31(h*m1) * m2 - case s <= 8: - h ^= uint64(readUnaligned32(p)) - h ^= uint64(readUnaligned32(add(p, s-4))) << 32 - h = rotl31(h*m1) * m2 - case s <= 16: - h ^= readUnaligned64(p) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) - h = rotl31(h*m1) * m2 - case s <= 32: - h ^= readUnaligned64(p) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, 8)) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-16)) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) - h = rotl31(h*m1) * m2 - default: - v1 := h - v2 := uint64(seed * hashkey[1]) - v3 := uint64(seed * hashkey[2]) - v4 := uint64(seed * hashkey[3]) - for s >= 32 { - v1 ^= readUnaligned64(p) - v1 = rotl31(v1*m1) * m2 - p = add(p, 8) - v2 ^= readUnaligned64(p) - v2 = rotl31(v2*m2) * m3 - p = add(p, 8) - v3 ^= readUnaligned64(p) - v3 = rotl31(v3*m3) * m4 - p = add(p, 8) - v4 ^= readUnaligned64(p) - v4 = rotl31(v4*m4) * m1 - p = add(p, 8) - s -= 32 - } - h = v1 ^ v2 ^ v3 ^ v4 - goto tail - } - - h ^= h >> 29 - h *= m3 - h ^= h >> 32 - return uintptr(h) -} - -func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { - return unsafe.Pointer(uintptr(p) + x) -} - -func readUnaligned32(p unsafe.Pointer) uint32 { - q := (*[4]byte)(p) - return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24 -} - -func rotl31(x uint64) uint64 { - return (x << 31) | (x >> (64 - 31)) -} - -func readUnaligned64(p unsafe.Pointer) uint64 { - q := (*[8]byte)(p) - return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 -} - -func (g *MphMatcherGroup) Size() uint32 { - return g.Count -} diff --git a/common/strmatcher/mph_matcher_compact.go b/common/strmatcher/mph_matcher_compact.go deleted file mode 100644 index a40b9f56..00000000 --- a/common/strmatcher/mph_matcher_compact.go +++ /dev/null @@ -1,47 +0,0 @@ -package strmatcher - -import ( - "bytes" - "encoding/gob" - "io" -) - -func init() { - gob.Register(&RegexMatcher{}) - gob.Register(fullMatcher("")) - gob.Register(substrMatcher("")) - gob.Register(domainMatcher("")) -} - -func (g *MphMatcherGroup) Serialize(w io.Writer) error { - data := MphMatcherGroup{ - Ac: g.Ac, - OtherMatchers: g.OtherMatchers, - Rules: g.Rules, - Level0: g.Level0, - Level0Mask: g.Level0Mask, - Level1: g.Level1, - Level1Mask: g.Level1Mask, - Count: g.Count, - } - return gob.NewEncoder(w).Encode(data) -} - -func NewMphMatcherGroupFromBuffer(data []byte) (*MphMatcherGroup, error) { - var gData MphMatcherGroup - if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&gData); err != nil { - return nil, err - } - - g := NewMphMatcherGroup() - g.Ac = gData.Ac - g.OtherMatchers = gData.OtherMatchers - g.Rules = gData.Rules - g.Level0 = gData.Level0 - g.Level0Mask = gData.Level0Mask - g.Level1 = gData.Level1 - g.Level1Mask = gData.Level1Mask - g.Count = gData.Count - - return g, nil -} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go deleted file mode 100644 index 89e7dae6..00000000 --- a/common/strmatcher/strmatcher.go +++ /dev/null @@ -1,141 +0,0 @@ -package strmatcher - -import ( - "errors" - "regexp" -) - -// Matcher is the interface to determine a string matches a pattern. -type Matcher interface { - // Match returns true if the given string matches a predefined pattern. - Match(string) bool - String() string -} - -// Type is the type of the matcher. -type Type byte - -const ( - // Full is the type of matcher that the input string must exactly equal to the pattern. - Full Type = iota - // Substr is the type of matcher that the input string must contain the pattern as a sub-string. - Substr - // Domain is the type of matcher that the input string must be a sub-domain or itself of the pattern. - Domain - // Regex is the type of matcher that the input string must matches the regular-expression pattern. - Regex -) - -// New creates a new Matcher based on the given pattern. -func (t Type) New(pattern string) (Matcher, error) { - // 1. regex matching is case-sensitive - switch t { - case Full: - return fullMatcher(pattern), nil - case Substr: - return substrMatcher(pattern), nil - case Domain: - return domainMatcher(pattern), nil - case Regex: - r, err := regexp.Compile(pattern) - if err != nil { - return nil, err - } - return &RegexMatcher{ - Pattern: pattern, - reg: r, - }, nil - default: - return nil, errors.New("unk type") - } -} - -// IndexMatcher is the interface for matching with a group of matchers. -type IndexMatcher interface { - // Match returns the index of a matcher that matches the input. It returns empty array if no such matcher exists. - Match(input string) []uint32 - // Size returns the number of matchers in the group. - Size() uint32 -} - -type MatcherEntry struct { - M Matcher - Id uint32 -} - -// MatcherGroup is an implementation of IndexMatcher. -// Empty initialization works. -type MatcherGroup struct { - count uint32 - fullMatcher FullMatcherGroup - domainMatcher DomainMatcherGroup - otherMatchers []MatcherEntry -} - -// Add adds a new Matcher into the MatcherGroup, and returns its index. The index will never be 0. -func (g *MatcherGroup) Add(m Matcher) uint32 { - g.count++ - c := g.count - - switch tm := m.(type) { - case fullMatcher: - g.fullMatcher.addMatcher(tm, c) - case domainMatcher: - g.domainMatcher.addMatcher(tm, c) - default: - g.otherMatchers = append(g.otherMatchers, MatcherEntry{ - M: m, - Id: c, - }) - } - - return c -} - -// Match implements IndexMatcher.Match. -func (g *MatcherGroup) Match(pattern string) []uint32 { - result := []uint32{} - result = append(result, g.fullMatcher.Match(pattern)...) - result = append(result, g.domainMatcher.Match(pattern)...) - for _, e := range g.otherMatchers { - if e.M.Match(pattern) { - result = append(result, e.Id) - } - } - return result -} - -// Size returns the number of matchers in the MatcherGroup. -func (g *MatcherGroup) Size() uint32 { - return g.count -} - -type IndexMatcherGroup struct { - Matchers []IndexMatcher -} - -func (g *IndexMatcherGroup) Match(input string) []uint32 { - var offset uint32 - for _, m := range g.Matchers { - if res := m.Match(input); len(res) > 0 { - if offset == 0 { - return res - } - shifted := make([]uint32, len(res)) - for i, id := range res { - shifted[i] = id + offset - } - return shifted - } - offset += m.Size() - } - return nil -} - -func (g *IndexMatcherGroup) Size() uint32 { - var count uint32 - for _, m := range g.Matchers { - count += m.Size() - } - return count -} diff --git a/infra/conf/dns.go b/infra/conf/dns.go index a65f0ee8..d55dada6 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -3,6 +3,7 @@ package conf import ( "bufio" "encoding/json" + "io" "os" "path/filepath" "runtime" @@ -10,8 +11,8 @@ import ( "strings" "github.com/xtls/xray-core/app/dns" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" ) @@ -20,7 +21,7 @@ type NameServerConfig struct { ClientIP *Address `json:"clientIp"` Port uint16 `json:"port"` SkipFallback bool `json:"skipFallback"` - Domains []string `json:"domains"` + Domains StringList `json:"domains"` ExpectedIPs StringList `json:"expectedIPs"` ExpectIPs StringList `json:"expectIPs"` QueryStrategy string `json:"queryStrategy"` @@ -46,7 +47,7 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error { ClientIP *Address `json:"clientIp"` Port uint16 `json:"port"` SkipFallback bool `json:"skipFallback"` - Domains []string `json:"domains"` + Domains StringList `json:"domains"` ExpectedIPs StringList `json:"expectedIPs"` ExpectIPs StringList `json:"expectIPs"` QueryStrategy string `json:"queryStrategy"` @@ -80,45 +81,14 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error { return errors.New("failed to parse name server: ", string(data)) } -func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType { - switch t { - case router.Domain_Domain: - return dns.DomainMatchingType_Subdomain - case router.Domain_Full: - return dns.DomainMatchingType_Full - case router.Domain_Plain: - return dns.DomainMatchingType_Keyword - case router.Domain_Regex: - return dns.DomainMatchingType_Regex - default: - panic("unknown domain type") - } -} - func (c *NameServerConfig) Build() (*dns.NameServer, error) { if c.Address == nil { - return nil, errors.New("NameServer address is not specified.") + return nil, errors.New("nameserver address is not specified") } - var domains []*dns.NameServer_PriorityDomain - var originalRules []*dns.NameServer_OriginalRule - - for _, rule := range c.Domains { - parsedDomain, err := parseDomainRule(rule) - if err != nil { - return nil, errors.New("invalid domain rule: ", rule).Base(err) - } - - for _, pd := range parsedDomain { - domains = append(domains, &dns.NameServer_PriorityDomain{ - Type: toDomainMatchingType(pd.Type), - Domain: pd.Value, - }) - } - originalRules = append(originalRules, &dns.NameServer_OriginalRule{ - Rule: rule, - Size: uint32(len(parsedDomain)), - }) + domainRules, err := geodata.ParseDomainRules(c.Domains, geodata.Domain_Substr) + if err != nil { + return nil, err } if len(c.ExpectedIPs) == 0 { @@ -145,14 +115,14 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { } } - expectedGeoipList, err := ToCidrList(newExpectedIPs) + expectedIPRules, err := geodata.ParseIPRules(newExpectedIPs) if err != nil { - return nil, errors.New("invalid expected IP rule: ", c.ExpectedIPs).Base(err) + return nil, err } - unexpectedGeoipList, err := ToCidrList(newUnexpectedIPs) + unexpectedIPRules, err := geodata.ParseIPRules(newUnexpectedIPs) if err != nil { - return nil, errors.New("invalid unexpected IP rule: ", c.UnexpectedIPs).Base(err) + return nil, err } var myClientIP []byte @@ -169,32 +139,24 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { Address: c.Address.Build(), Port: uint32(c.Port), }, - ClientIp: myClientIP, - SkipFallback: c.SkipFallback, - PrioritizedDomain: domains, - ExpectedGeoip: expectedGeoipList, - OriginalRules: originalRules, - QueryStrategy: resolveQueryStrategy(c.QueryStrategy), - ActPrior: actPrior, - Tag: c.Tag, - TimeoutMs: c.TimeoutMs, - DisableCache: c.DisableCache, - ServeStale: c.ServeStale, - ServeExpiredTTL: c.ServeExpiredTTL, - FinalQuery: c.FinalQuery, - UnexpectedGeoip: unexpectedGeoipList, - ActUnprior: actUnprior, + ClientIp: myClientIP, + SkipFallback: c.SkipFallback, + Domain: domainRules, + ExpectedIp: expectedIPRules, + QueryStrategy: resolveQueryStrategy(c.QueryStrategy), + ActPrior: actPrior, + Tag: c.Tag, + TimeoutMs: c.TimeoutMs, + DisableCache: c.DisableCache, + ServeStale: c.ServeStale, + ServeExpiredTTL: c.ServeExpiredTTL, + FinalQuery: c.FinalQuery, + UnexpectedIp: unexpectedIPRules, + ActUnprior: actUnprior, }, nil } -var typeMap = map[router.Domain_Type]dns.DomainMatchingType{ - router.Domain_Full: dns.DomainMatchingType_Full, - router.Domain_Domain: dns.DomainMatchingType_Subdomain, - router.Domain_Plain: dns.DomainMatchingType_Keyword, - router.Domain_Regex: dns.DomainMatchingType_Regex, -} - -// DNSConfig is a JSON serializable object for dns.Config. +// DNSConfig is a JSON serializable object for dns.Config type DNSConfig struct { Servers []*NameServerConfig `json:"servers"` Hosts *HostsWrapper `json:"hosts"` @@ -246,7 +208,7 @@ type HostsWrapper struct { Hosts map[string]*HostAddress } -func getHostMapping(ha *HostAddress) *dns.Config_HostMapping { +func newHostMapping(ha *HostAddress) *dns.Config_HostMapping { if ha.addr != nil { if ha.addr.Family().IsDomain() { return &dns.Config_HostMapping{ @@ -290,109 +252,15 @@ func (m *HostsWrapper) UnmarshalJSON(data []byte) error { // Build implements Buildable func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { - mappings := make([]*dns.Config_HostMapping, 0, 20) - - domains := make([]string, 0, len(m.Hosts)) - for domain := range m.Hosts { - domains = append(domains, domain) - } - sort.Strings(domains) - - for _, domain := range domains { - switch { - case strings.HasPrefix(domain, "domain:"): - domainName := domain[7:] - if len(domainName) == 0 { - return nil, errors.New("empty domain type of rule: ", domain) - } - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Subdomain - mapping.Domain = domainName - mappings = append(mappings, mapping) - - case strings.HasPrefix(domain, "geosite:"): - listName := domain[8:] - if len(listName) == 0 { - return nil, errors.New("empty geosite rule: ", domain) - } - geositeList, err := loadGeositeWithAttr("geosite.dat", listName) - if err != nil { - return nil, errors.New("failed to load geosite: ", listName).Base(err) - } - for _, d := range geositeList { - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = typeMap[d.Type] - mapping.Domain = d.Value - mappings = append(mappings, mapping) - } - - case strings.HasPrefix(domain, "regexp:"): - regexpVal := domain[7:] - if len(regexpVal) == 0 { - return nil, errors.New("empty regexp type of rule: ", domain) - } - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Regex - mapping.Domain = regexpVal - mappings = append(mappings, mapping) - - case strings.HasPrefix(domain, "keyword:"): - keywordVal := domain[8:] - if len(keywordVal) == 0 { - return nil, errors.New("empty keyword type of rule: ", domain) - } - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Keyword - mapping.Domain = keywordVal - mappings = append(mappings, mapping) - - case strings.HasPrefix(domain, "full:"): - fullVal := domain[5:] - if len(fullVal) == 0 { - return nil, errors.New("empty full domain type of rule: ", domain) - } - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Full - mapping.Domain = fullVal - mappings = append(mappings, mapping) - - case strings.HasPrefix(domain, "dotless:"): - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Regex - switch substr := domain[8:]; { - case substr == "": - mapping.Domain = "^[^.]*$" - case !strings.Contains(substr, "."): - mapping.Domain = "^[^.]*" + substr + "[^.]*$" - default: - return nil, errors.New("substr in dotless rule should not contain a dot: ", substr) - } - mappings = append(mappings, mapping) - - case strings.HasPrefix(domain, "ext:"): - kv := strings.Split(domain[4:], ":") - if len(kv) != 2 { - return nil, errors.New("invalid external resource: ", domain) - } - filename := kv[0] - list := kv[1] - geositeList, err := loadGeositeWithAttr(filename, list) - if err != nil { - return nil, errors.New("failed to load domain list: ", list, " from ", filename).Base(err) - } - for _, d := range geositeList { - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = typeMap[d.Type] - mapping.Domain = d.Value - mappings = append(mappings, mapping) - } - - default: - mapping := getHostMapping(m.Hosts[domain]) - mapping.Type = dns.DomainMatchingType_Full - mapping.Domain = domain - mappings = append(mappings, mapping) + mappings := make([]*dns.Config_HostMapping, 0, len(m.Hosts)) + for rule, addrs := range m.Hosts { + mapping := newHostMapping(addrs) + rule, err := geodata.ParseDomainRule(rule, geodata.Domain_Full) + if err != nil { + return nil, err } + mapping.Domain = rule + mappings = append(mappings, mapping) } return mappings, nil } @@ -504,9 +372,7 @@ func (c *DNSConfig) Build() (*dns.Config, error) { if err != nil { return nil, errors.New("failed to read system hosts").Base(err) } - for domain, ips := range systemHosts { - config.StaticHosts = append(config.StaticHosts, &dns.Config_HostMapping{Ip: ips, Domain: domain, Type: dns.DomainMatchingType_Full}) - } + config.StaticHosts = append(config.StaticHosts, systemHosts...) } return config, nil @@ -527,7 +393,7 @@ func resolveQueryStrategy(queryStrategy string) dns.QueryStrategy { } } -func readSystemHosts() (map[string][][]byte, error) { +func readSystemHosts() ([]*dns.Config_HostMapping, error) { var hostsPath string switch runtime.GOOS { case "windows": @@ -542,12 +408,16 @@ func readSystemHosts() (map[string][][]byte, error) { } defer file.Close() - hostsMap := make(map[string][][]byte) - scanner := bufio.NewScanner(file) + return readSystemHostsFrom(file) +} + +func readSystemHostsFrom(r io.Reader) ([]*dns.Config_HostMapping, error) { + hosts := make(map[string][][]byte, 16) + scanner := bufio.NewScanner(r) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if i := strings.IndexByte(line, '#'); i >= 0 { - // Discard comments. + // Strip inline comments before splitting the line into fields. line = line[0:i] } f := strings.Fields(line) @@ -558,19 +428,28 @@ func readSystemHosts() (map[string][][]byte, error) { if addr.Family().IsDomain() { continue } - ip := addr.IP() for i := 1; i < len(f); i++ { domain := strings.TrimSuffix(f[i], ".") domain = strings.ToLower(domain) - if v, ok := hostsMap[domain]; ok { - hostsMap[domain] = append(v, ip) - } else { - hostsMap[domain] = [][]byte{ip} - } + hosts[domain] = append(hosts[domain], addr.IP()) } } if err := scanner.Err(); err != nil { return nil, err } + + hostsMap := make([]*dns.Config_HostMapping, 0, len(hosts)) + for domain, ips := range hosts { + // ParseDomainRule accepts rule syntax too, not just plain domains. + rule, err := geodata.ParseDomainRule(domain, geodata.Domain_Full) + if err != nil { + return nil, err + } + hostsMap = append(hostsMap, &dns.Config_HostMapping{ + Domain: rule, + Ip: ips, + }) + } + return hostsMap, nil } diff --git a/infra/conf/dns_test.go b/infra/conf/dns_test.go index a9739668..278f34c3 100644 --- a/infra/conf/dns_test.go +++ b/infra/conf/dns_test.go @@ -4,10 +4,13 @@ import ( "encoding/json" "testing" + "github.com/google/go-cmp/cmp" "github.com/xtls/xray-core/app/dns" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" . "github.com/xtls/xray-core/infra/conf" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" ) func TestDNSConfigParsing(t *testing.T) { @@ -22,7 +25,7 @@ func TestDNSConfigParsing(t *testing.T) { } expectedServeStale := true expectedServeExpiredTTL := uint32(172800) - runMultiTestCase(t, []TestCase{ + testCases := []TestCase{ { Input: `{ "servers": [{ @@ -61,16 +64,9 @@ func TestDNSConfigParsing(t *testing.T) { Port: 5353, }, SkipFallback: true, - PrioritizedDomain: []*dns.NameServer_PriorityDomain{ + Domain: []*geodata.DomainRule{ { - Type: dns.DomainMatchingType_Subdomain, - Domain: "example.com", - }, - }, - OriginalRules: []*dns.NameServer_OriginalRule{ - { - Rule: "domain:example.com", - Size: 1, + Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "example.com"}}, }, }, ServeStale: &expectedServeStale, @@ -80,28 +76,23 @@ func TestDNSConfigParsing(t *testing.T) { }, StaticHosts: []*dns.Config_HostMapping{ { - Type: dns.DomainMatchingType_Subdomain, - Domain: "example.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Domain, Value: "example.com"}}}, ProxiedDomain: "google.com", }, { - Type: dns.DomainMatchingType_Full, - Domain: "example.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "example.com"}}}, Ip: [][]byte{{127, 0, 0, 1}}, }, { - Type: dns.DomainMatchingType_Keyword, - Domain: "google", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "google"}}}, Ip: [][]byte{{8, 8, 8, 8}, {8, 8, 4, 4}}, }, { - Type: dns.DomainMatchingType_Regex, - Domain: ".*\\.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Regex, Value: ".*\\.com"}}}, Ip: [][]byte{{8, 8, 4, 4}}, }, { - Type: dns.DomainMatchingType_Full, - Domain: "www.example.org", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "www.example.org"}}}, Ip: [][]byte{{127, 0, 0, 1}, {127, 0, 0, 2}}, }, }, @@ -113,5 +104,21 @@ func TestDNSConfigParsing(t *testing.T) { DisableFallback: true, }, }, - }) + } + + for _, testCase := range testCases { + actual, err := testCase.Parser(testCase.Input) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff( + testCase.Output, + actual, + protocmp.Transform(), + protocmp.SortRepeatedFields(&dns.Config{}, "static_hosts"), + ); diff != "" { + t.Fatalf("Failed in test case:\n%s\nDiff (-want +got):\n%s", testCase.Input, diff) + } + } } diff --git a/infra/conf/router.go b/infra/conf/router.go index a488d397..af83c743 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -1,20 +1,14 @@ package conf import ( - "bufio" - "bytes" "encoding/json" - "io" - "runtime" - "strconv" "strings" "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/serial" + "google.golang.org/protobuf/proto" ) @@ -104,15 +98,14 @@ func (c *RouterConfig) Build() (*router.Config, error) { if c != nil { rawRuleList = c.RuleList } - for _, rawRule := range rawRuleList { rule, err := parseRule(rawRule) if err != nil { return nil, err } - config.Rule = append(config.Rule, rule) } + for _, rawBalancer := range c.Balancers { balancer, err := rawBalancer.Build() if err != nil { @@ -120,6 +113,7 @@ func (c *RouterConfig) Build() (*router.Config, error) { } config.BalancingRule = append(config.BalancingRule, balancer) } + return config, nil } @@ -129,399 +123,6 @@ type RouterRule struct { BalancerTag string `json:"balancerTag"` } -func parseIP(s string) (*router.CIDR, error) { - var addr, mask string - i := strings.Index(s, "/") - if i < 0 { - addr = s - } else { - addr = s[:i] - mask = s[i+1:] - } - ip := net.ParseAddress(addr) - switch ip.Family() { - case net.AddressFamilyIPv4: - bits := uint32(32) - if len(mask) > 0 { - bits64, err := strconv.ParseUint(mask, 10, 32) - if err != nil { - return nil, errors.New("invalid network mask for router: ", mask).Base(err) - } - bits = uint32(bits64) - } - if bits > 32 { - return nil, errors.New("invalid network mask for router: ", bits) - } - return &router.CIDR{ - Ip: []byte(ip.IP()), - Prefix: bits, - }, nil - case net.AddressFamilyIPv6: - bits := uint32(128) - if len(mask) > 0 { - bits64, err := strconv.ParseUint(mask, 10, 32) - if err != nil { - return nil, errors.New("invalid network mask for router: ", mask).Base(err) - } - bits = uint32(bits64) - } - if bits > 128 { - return nil, errors.New("invalid network mask for router: ", bits) - } - return &router.CIDR{ - Ip: []byte(ip.IP()), - Prefix: bits, - }, nil - default: - return nil, errors.New("unsupported address for router: ", s) - } -} - -func loadFile(file, code string) ([]byte, error) { - runtime.GC() - r, err := filesystem.OpenAsset(file) - defer r.Close() - if err != nil { - return nil, errors.New("failed to open file: ", file).Base(err) - } - bs := find(r, []byte(code)) - if bs == nil { - return nil, errors.New("code not found in ", file, ": ", code) - } - return bs, nil -} - -func loadIP(file, code string) ([]*router.CIDR, error) { - bs, err := loadFile(file, code) - if err != nil { - return nil, err - } - var geoip router.GeoIP - if err := proto.Unmarshal(bs, &geoip); err != nil { - return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err) - } - defer runtime.GC() // or debug.FreeOSMemory() - return geoip.Cidr, nil -} - -func loadSite(file, code string) ([]*router.Domain, error) { - - // Check if domain matcher cache is provided via environment - domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) - if domainMatcherPath != "" { - return []*router.Domain{{}}, nil - } - - bs, err := loadFile(file, code) - if err != nil { - return nil, err - } - var geosite router.GeoSite - if err := proto.Unmarshal(bs, &geosite); err != nil { - return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err) - } - defer runtime.GC() // or debug.FreeOSMemory() - return geosite.Domain, nil -} - -func decodeVarint(r *bufio.Reader) (uint64, error) { - var x uint64 - for shift := uint(0); shift < 64; shift += 7 { - b, err := r.ReadByte() - if err != nil { - return 0, err - } - x |= (uint64(b) & 0x7F) << shift - if (b & 0x80) == 0 { - return x, nil - } - } - // The number is too large to represent in a 64-bit value. - return 0, errors.New("varint overflow") -} - -func find(r io.Reader, code []byte) []byte { - codeL := len(code) - if codeL == 0 { - return nil - } - - br := bufio.NewReaderSize(r, 64*1024) - need := 2 + codeL - prefixBuf := make([]byte, need) - - for { - if _, err := br.ReadByte(); err != nil { - return nil - } - - x, err := decodeVarint(br) - if err != nil { - return nil - } - bodyL := int(x) - if bodyL <= 0 { - return nil - } - - prefixL := bodyL - if prefixL > need { - prefixL = need - } - prefix := prefixBuf[:prefixL] - if _, err := io.ReadFull(br, prefix); err != nil { - return nil - } - - match := false - if bodyL >= need { - if int(prefix[1]) == codeL && bytes.Equal(prefix[2:need], code) { - match = true - } - } - - remain := bodyL - prefixL - if match { - out := make([]byte, bodyL) - copy(out, prefix) - if remain > 0 { - if _, err := io.ReadFull(br, out[prefixL:]); err != nil { - return nil - } - } - return out - } - - if remain > 0 { - if _, err := br.Discard(remain); err != nil { - return nil - } - } - } -} - -type AttributeMatcher interface { - Match(*router.Domain) bool -} - -type BooleanMatcher string - -func (m BooleanMatcher) Match(domain *router.Domain) bool { - for _, attr := range domain.Attribute { - if attr.Key == string(m) { - return true - } - } - return false -} - -type AttributeList struct { - matcher []AttributeMatcher -} - -func (al *AttributeList) Match(domain *router.Domain) bool { - for _, matcher := range al.matcher { - if !matcher.Match(domain) { - return false - } - } - return true -} - -func (al *AttributeList) IsEmpty() bool { - return len(al.matcher) == 0 -} - -func parseAttrs(attrs []string) *AttributeList { - al := new(AttributeList) - for _, attr := range attrs { - lc := strings.ToLower(attr) - al.matcher = append(al.matcher, BooleanMatcher(lc)) - } - return al -} - -func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) { - parts := strings.Split(siteWithAttr, "@") - if len(parts) == 0 { - return nil, errors.New("empty site") - } - country := strings.ToUpper(parts[0]) - attrs := parseAttrs(parts[1:]) - domains, err := loadSite(file, country) - if err != nil { - return nil, err - } - - if attrs.IsEmpty() { - return domains, nil - } - - filteredDomains := make([]*router.Domain, 0, len(domains)) - for _, domain := range domains { - if attrs.Match(domain) { - filteredDomains = append(filteredDomains, domain) - } - } - - return filteredDomains, nil -} - -func parseDomainRule(domain string) ([]*router.Domain, error) { - if strings.HasPrefix(domain, "geosite:") { - country := strings.ToUpper(domain[8:]) - domains, err := loadGeositeWithAttr("geosite.dat", country) - if err != nil { - return nil, errors.New("failed to load geosite: ", country).Base(err) - } - return domains, nil - } - isExtDatFile := 0 - { - const prefix = "ext:" - if strings.HasPrefix(domain, prefix) { - isExtDatFile = len(prefix) - } - const prefixQualified = "ext-domain:" - if strings.HasPrefix(domain, prefixQualified) { - isExtDatFile = len(prefixQualified) - } - } - if isExtDatFile != 0 { - kv := strings.Split(domain[isExtDatFile:], ":") - if len(kv) != 2 { - return nil, errors.New("invalid external resource: ", domain) - } - filename := kv[0] - country := kv[1] - domains, err := loadGeositeWithAttr(filename, country) - if err != nil { - return nil, errors.New("failed to load external sites: ", country, " from ", filename).Base(err) - } - return domains, nil - } - - domainRule := new(router.Domain) - switch { - case strings.HasPrefix(domain, "regexp:"): - domainRule.Type = router.Domain_Regex - domainRule.Value = domain[7:] - - case strings.HasPrefix(domain, "domain:"): - domainRule.Type = router.Domain_Domain - domainRule.Value = domain[7:] - - case strings.HasPrefix(domain, "full:"): - domainRule.Type = router.Domain_Full - domainRule.Value = domain[5:] - - case strings.HasPrefix(domain, "keyword:"): - domainRule.Type = router.Domain_Plain - domainRule.Value = domain[8:] - - case strings.HasPrefix(domain, "dotless:"): - domainRule.Type = router.Domain_Regex - switch substr := domain[8:]; { - case substr == "": - domainRule.Value = "^[^.]*$" - case !strings.Contains(substr, "."): - domainRule.Value = "^[^.]*" + substr + "[^.]*$" - default: - return nil, errors.New("substr in dotless rule should not contain a dot: ", substr) - } - - default: - domainRule.Type = router.Domain_Plain - domainRule.Value = domain - } - return []*router.Domain{domainRule}, nil -} - -func ToCidrList(ips StringList) ([]*router.GeoIP, error) { - var geoipList []*router.GeoIP - var customCidrs []*router.CIDR - - for _, ip := range ips { - if strings.HasPrefix(ip, "geoip:") { - country := ip[6:] - isReverseMatch := false - if strings.HasPrefix(ip, "geoip:!") { - country = ip[7:] - isReverseMatch = true - } - if len(country) == 0 { - return nil, errors.New("empty country name in rule") - } - geoip, err := loadIP("geoip.dat", strings.ToUpper(country)) - if err != nil { - return nil, errors.New("failed to load GeoIP: ", country).Base(err) - } - - geoipList = append(geoipList, &router.GeoIP{ - CountryCode: strings.ToUpper(country), - Cidr: geoip, - ReverseMatch: isReverseMatch, - }) - continue - } - isExtDatFile := 0 - { - const prefix = "ext:" - if strings.HasPrefix(ip, prefix) { - isExtDatFile = len(prefix) - } - const prefixQualified = "ext-ip:" - if strings.HasPrefix(ip, prefixQualified) { - isExtDatFile = len(prefixQualified) - } - } - if isExtDatFile != 0 { - kv := strings.Split(ip[isExtDatFile:], ":") - if len(kv) != 2 { - return nil, errors.New("invalid external resource: ", ip) - } - - filename := kv[0] - country := kv[1] - if len(filename) == 0 || len(country) == 0 { - return nil, errors.New("empty filename or empty country in rule") - } - - isReverseMatch := false - if strings.HasPrefix(country, "!") { - country = country[1:] - isReverseMatch = true - } - geoip, err := loadIP(filename, strings.ToUpper(country)) - if err != nil { - return nil, errors.New("failed to load IPs: ", country, " from ", filename).Base(err) - } - - geoipList = append(geoipList, &router.GeoIP{ - CountryCode: strings.ToUpper(filename + "_" + country), - Cidr: geoip, - ReverseMatch: isReverseMatch, - }) - - continue - } - - ipRule, err := parseIP(ip) - if err != nil { - return nil, errors.New("invalid IP: ", ip).Base(err) - } - customCidrs = append(customCidrs, ipRule) - } - - if len(customCidrs) > 0 { - geoipList = append(geoipList, &router.GeoIP{ - Cidr: customCidrs, - }) - } - - return geoipList, nil -} - type WebhookRuleConfig struct { URL string `json:"url"` Deduplication uint32 `json:"deduplication"` @@ -571,31 +172,27 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { } if rawFieldRule.Domain != nil { - for _, domain := range *rawFieldRule.Domain { - rules, err := parseDomainRule(domain) - if err != nil { - return nil, errors.New("failed to parse domain rule: ", domain).Base(err) - } - rule.Domain = append(rule.Domain, rules...) - } - } - - if rawFieldRule.Domains != nil { - for _, domain := range *rawFieldRule.Domains { - rules, err := parseDomainRule(domain) - if err != nil { - return nil, errors.New("failed to parse domain rule: ", domain).Base(err) - } - rule.Domain = append(rule.Domain, rules...) - } - } - - if rawFieldRule.IP != nil { - geoipList, err := ToCidrList(*rawFieldRule.IP) + rules, err := geodata.ParseDomainRules(*rawFieldRule.Domain, geodata.Domain_Substr) if err != nil { return nil, err } - rule.Geoip = geoipList + rule.Domain = rules + } + + if rawFieldRule.Domains != nil { + rules, err := geodata.ParseDomainRules(*rawFieldRule.Domains, geodata.Domain_Substr) + if err != nil { + return nil, err + } + rule.Domain = rules + } + + if rawFieldRule.IP != nil { + rules, err := geodata.ParseIPRules(*rawFieldRule.IP) + if err != nil { + return nil, err + } + rule.Ip = rules } if rawFieldRule.Port != nil { @@ -611,11 +208,11 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { } if rawFieldRule.SourceIP != nil { - geoipList, err := ToCidrList(*rawFieldRule.SourceIP) + rules, err := geodata.ParseIPRules(*rawFieldRule.SourceIP) if err != nil { return nil, err } - rule.SourceGeoip = geoipList + rule.SourceIp = rules } if rawFieldRule.SourcePort != nil { @@ -623,11 +220,11 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { } if rawFieldRule.LocalIP != nil { - geoipList, err := ToCidrList(*rawFieldRule.LocalIP) + rules, err := geodata.ParseIPRules(*rawFieldRule.LocalIP) if err != nil { return nil, err } - rule.LocalGeoip = geoipList + rule.LocalIp = rules } if rawFieldRule.LocalPort != nil { diff --git a/infra/conf/router_test.go b/infra/conf/router_test.go index 2533046c..26ebff1c 100644 --- a/infra/conf/router_test.go +++ b/infra/conf/router_test.go @@ -2,78 +2,19 @@ package conf_test import ( "encoding/json" - "fmt" - "os" - "path/filepath" "testing" "time" _ "unsafe" "github.com/xtls/xray-core/app/router" - "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/serial" . "github.com/xtls/xray-core/infra/conf" + "google.golang.org/protobuf/proto" ) -func getAssetPath(file string) (string, error) { - path := platform.GetAssetLocation(file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - path := filepath.Join("..", "..", "resources", file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file) - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - return path, nil - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - - return path, nil -} - -func TestToCidrList(t *testing.T) { - tempDir, err := os.MkdirTemp("", "test-") - if err != nil { - t.Fatalf("can't create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - - geoipPath, err := getAssetPath("geoip.dat") - if err != nil { - t.Fatal(err) - } - - common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoip.dat"), geoipPath)) - common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoiptestrouter.dat"), geoipPath)) - - os.Setenv("xray.location.asset", tempDir) - defer os.Unsetenv("xray.location.asset") - - ips := StringList([]string{ - "geoip:us", - "geoip:cn", - "geoip:!cn", - "ext:geoiptestrouter.dat:!cn", - "ext:geoiptestrouter.dat:ca", - "ext-ip:geoiptestrouter.dat:!cn", - "ext-ip:geoiptestrouter.dat:!ca", - }) - - _, err = ToCidrList(ips) - if err != nil { - t.Fatalf("Failed to parse geoip list, got %s", err) - } -} - func TestRouterConfig(t *testing.T) { createParser := func() func(string) (proto.Message, error) { return func(s string) (proto.Message, error) { @@ -182,29 +123,27 @@ func TestRouterConfig(t *testing.T) { }, Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - { - Type: router.Domain_Plain, - Value: "baidu.com", - }, - { - Type: router.Domain_Plain, - Value: "qq.com", - }, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "baidu.com"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "qq.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "direct", }, }, { - Geoip: []*router.GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{10, 0, 0, 0}, Prefix: 8, }, - { + }, + }, + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Prefix: 128, }, @@ -265,29 +204,27 @@ func TestRouterConfig(t *testing.T) { DomainStrategy: router.Config_IpIfNonMatch, Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - { - Type: router.Domain_Plain, - Value: "baidu.com", - }, - { - Type: router.Domain_Plain, - Value: "qq.com", - }, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "baidu.com"}}}, + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Substr, Value: "qq.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "direct", }, }, { - Geoip: []*router.GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{10, 0, 0, 0}, Prefix: 8, }, - { + }, + }, + { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Prefix: 128, }, diff --git a/infra/conf/xray.go b/infra/conf/xray.go index a2410244..fce642f0 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -1,21 +1,16 @@ package conf import ( - "bytes" "context" "encoding/json" - "os" "path/filepath" - "sort" "strings" "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/proxyman" - "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/serial" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/transport/internet" @@ -617,187 +612,6 @@ func (c *Config) Build() (*core.Config, error) { return config, nil } -func (c *Config) BuildMPHCache(customMatcherFilePath *string) error { - var geosite []*router.GeoSite - deps := make(map[string][]string) - uniqueGeosites := make(map[string]bool) - uniqueTags := make(map[string]bool) - matcherFilePath := platform.GetAssetLocation("matcher.cache") - - if customMatcherFilePath != nil { - matcherFilePath = *customMatcherFilePath - } - - processGeosite := func(dStr string) bool { - prefix := "" - if strings.HasPrefix(dStr, "geosite:") { - prefix = "geosite:" - } else if strings.HasPrefix(dStr, "ext-domain:") { - prefix = "ext-domain:" - } - if prefix == "" { - return false - } - key := strings.ToLower(dStr) - country := strings.ToUpper(dStr[len(prefix):]) - if !uniqueGeosites[country] { - ds, err := loadGeositeWithAttr("geosite.dat", country) - if err == nil { - uniqueGeosites[country] = true - geosite = append(geosite, &router.GeoSite{CountryCode: key, Domain: ds}) - } - } - return true - } - - processDomains := func(tag string, rawDomains []string) { - var manualDomains []*router.Domain - var dDeps []string - for _, dStr := range rawDomains { - if processGeosite(dStr) { - dDeps = append(dDeps, strings.ToLower(dStr)) - } else { - ds, err := parseDomainRule(dStr) - if err == nil { - manualDomains = append(manualDomains, ds...) - } - } - } - if len(manualDomains) > 0 { - if !uniqueTags[tag] { - uniqueTags[tag] = true - geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualDomains}) - } - } - if len(dDeps) > 0 { - deps[tag] = append(deps[tag], dDeps...) - } - } - - // proccess rules - if c.RouterConfig != nil { - for _, rawRule := range c.RouterConfig.RuleList { - type SimpleRule struct { - RuleTag string `json:"ruleTag"` - Domain *StringList `json:"domain"` - Domains *StringList `json:"domains"` - } - var sr SimpleRule - json.Unmarshal(rawRule, &sr) - if sr.RuleTag == "" { - continue - } - var allDomains []string - if sr.Domain != nil { - allDomains = append(allDomains, *sr.Domain...) - } - if sr.Domains != nil { - allDomains = append(allDomains, *sr.Domains...) - } - processDomains(sr.RuleTag, allDomains) - } - } - - // proccess dns servers - if c.DNSConfig != nil { - for _, ns := range c.DNSConfig.Servers { - if ns.Tag == "" { - continue - } - processDomains(ns.Tag, ns.Domains) - } - } - - var hostIPs map[string][]string - if c.DNSConfig != nil && c.DNSConfig.Hosts != nil { - hostIPs = make(map[string][]string) - var hostDeps []string - var hostPatterns []string - - // use raw map to avoid expanding geosites - var domains []string - for domain := range c.DNSConfig.Hosts.Hosts { - domains = append(domains, domain) - } - sort.Strings(domains) - - manualHostGroups := make(map[string][]*router.Domain) - manualHostIPs := make(map[string][]string) - manualHostNames := make(map[string]string) - - for _, domain := range domains { - ha := c.DNSConfig.Hosts.Hosts[domain] - m := getHostMapping(ha) - - var ips []string - if m.ProxiedDomain != "" { - ips = append(ips, m.ProxiedDomain) - } else { - for _, ip := range m.Ip { - ips = append(ips, net.IPAddress(ip).String()) - } - } - - if processGeosite(domain) { - tag := strings.ToLower(domain) - hostDeps = append(hostDeps, tag) - hostIPs[tag] = ips - hostPatterns = append(hostPatterns, domain) - } else { - // build manual domains by their destination IPs - sort.Strings(ips) - ipKey := strings.Join(ips, ",") - ds, err := parseDomainRule(domain) - if err == nil { - manualHostGroups[ipKey] = append(manualHostGroups[ipKey], ds...) - manualHostIPs[ipKey] = ips - if _, ok := manualHostNames[ipKey]; !ok { - manualHostNames[ipKey] = domain - } - } - } - } - - // create manual host groups - var ipKeys []string - for k := range manualHostGroups { - ipKeys = append(ipKeys, k) - } - sort.Strings(ipKeys) - - for _, k := range ipKeys { - tag := manualHostNames[k] - geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualHostGroups[k]}) - hostDeps = append(hostDeps, tag) - hostIPs[tag] = manualHostIPs[k] - - // record tag _ORDER links the matcher to IP addresses - hostPatterns = append(hostPatterns, tag) - } - - deps["HOSTS"] = hostDeps - hostIPs["_ORDER"] = hostPatterns - } - - f, err := os.Create(matcherFilePath) - if err != nil { - return err - } - defer f.Close() - - var buf bytes.Buffer - - if err := router.SerializeGeoSiteList(geosite, deps, hostIPs, &buf); err != nil { - return err - } - - if _, err := f.Write(buf.Bytes()); err != nil { - return err - } - - return nil -} - // Convert string to Address. func ParseSendThough(Addr *string) *Address { var addr Address diff --git a/infra/conf/xray_test.go b/infra/conf/xray_test.go index c1349fd2..a3a767f2 100644 --- a/infra/conf/xray_test.go +++ b/infra/conf/xray_test.go @@ -11,6 +11,7 @@ import ( "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" clog "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" @@ -95,10 +96,10 @@ func TestXrayConfig(t *testing.T) { DomainStrategy: router.Config_AsIs, Rule: []*router.RoutingRule{ { - Geoip: []*router.GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{10, 0, 0, 0}, Prefix: 8, }, diff --git a/main/commands/all/buildmphcache.go b/main/commands/all/buildmphcache.go deleted file mode 100644 index 6c45205e..00000000 --- a/main/commands/all/buildmphcache.go +++ /dev/null @@ -1,52 +0,0 @@ -package all - -import ( - "os" - - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/infra/conf/serial" - "github.com/xtls/xray-core/main/commands/base" -) - -var cmdBuildMphCache = &base.Command{ - UsageLine: `{{.Exec}} buildMphCache [-c config.json] [-o domain.cache]`, - Short: `Build domain matcher cache`, - Long: ` -Build domain matcher cache from a configuration file. - -Example: {{.Exec}} buildMphCache -c config.json -o domain.cache -`, -} - -func init() { - cmdBuildMphCache.Run = executeBuildMphCache -} - -var ( - configPath = cmdBuildMphCache.Flag.String("c", "config.json", "Config file path") - outputPath = cmdBuildMphCache.Flag.String("o", "domain.cache", "Output cache file path") -) - -func executeBuildMphCache(cmd *base.Command, args []string) { - cf, err := os.Open(*configPath) - if err != nil { - base.Fatalf("failed to open config file: %v", err) - } - defer cf.Close() - - // prevent using existing cache - domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) - if domainMatcherPath != "" { - os.Setenv("XRAY_MPH_CACHE", "") - defer os.Setenv("XRAY_MPH_CACHE", domainMatcherPath) - } - - config, err := serial.DecodeJSONConfig(cf) - if err != nil { - base.Fatalf("failed to decode config file: %v", err) - } - - if err := config.BuildMPHCache(outputPath); err != nil { - base.Fatalf("failed to build MPH cache: %v", err) - } -} diff --git a/main/commands/all/commands.go b/main/commands/all/commands.go index 20b92bb0..fba3a4b8 100644 --- a/main/commands/all/commands.go +++ b/main/commands/all/commands.go @@ -19,6 +19,5 @@ func init() { cmdMLDSA65, cmdMLKEM768, cmdVLESSEnc, - cmdBuildMphCache, ) } diff --git a/main/run.go b/main/run.go index 935dcade..2655f783 100644 --- a/main/run.go +++ b/main/run.go @@ -28,15 +28,15 @@ var cmdRun = &base.Command{ Long: ` Run Xray with config, the default command. -The -config=file, -c=file flags set the config files for +The -config=file, -c=file flags set the config files for Xray. Multiple assign is accepted. The -confdir=dir flag sets a dir with multiple json config -The -format=json flag sets the format of config files. +The -format=json flag sets the format of config files. Default "auto". -The -test flag tells Xray to test config files only, +The -test flag tells Xray to test config files only, without launching the server. The -dump flag tells Xray to print the merged config. @@ -93,12 +93,6 @@ func executeRun(cmd *base.Command, args []string) { } defer server.Close() - /* - conf.FileCache = nil - conf.IPCache = nil - conf.SiteCache = nil - */ - // Explicitly triggering GC to remove garbage from config loading. runtime.GC() debug.FreeOSMemory() @@ -218,8 +212,6 @@ func getConfigFormat() string { func startXray() (core.Server, error) { configFiles := getConfigFilePath(true) - // config, err := core.LoadConfig(getConfigFormat(), configFiles[0], configFiles) - c, err := core.LoadConfig(getConfigFormat(), configFiles) if err != nil { return nil, errors.New("failed to load config files: [", configFiles.String(), "]").Base(err) diff --git a/testing/scenarios/dns_test.go b/testing/scenarios/dns_test.go index 4de2fe7e..4f7fad5d 100644 --- a/testing/scenarios/dns_test.go +++ b/testing/scenarios/dns_test.go @@ -9,6 +9,7 @@ import ( "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/core" @@ -34,8 +35,7 @@ func TestResolveIP(t *testing.T) { serial.ToTypedMessage(&dns.Config{ StaticHosts: []*dns.Config_HostMapping{ { - Type: dns.DomainMatchingType_Full, - Domain: "google.com", + Domain: &geodata.DomainRule{Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "google.com"}}}, Ip: [][]byte{dest.Address.IP()}, }, }, @@ -44,10 +44,10 @@ func TestResolveIP(t *testing.T) { DomainStrategy: router.Config_IpIfNonMatch, Rule: []*router.RoutingRule{ { - Geoip: []*router.GeoIP{ + Ip: []*geodata.IPRule{ { - Cidr: []*router.CIDR{ - { + Value: &geodata.IPRule_Custom{ + Custom: &geodata.CIDR{ Ip: []byte{127, 0, 0, 0}, Prefix: 8, }, diff --git a/testing/scenarios/reverse_test.go b/testing/scenarios/reverse_test.go index 1e41d2a7..430856dc 100644 --- a/testing/scenarios/reverse_test.go +++ b/testing/scenarios/reverse_test.go @@ -10,6 +10,7 @@ import ( "github.com/xtls/xray-core/app/reverse" "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/geodata" clog "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" @@ -52,8 +53,8 @@ func TestReverseProxy(t *testing.T) { serial.ToTypedMessage(&router.Config{ Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - {Type: router.Domain_Full, Value: "test.example.com"}, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "test.example.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "portal", @@ -118,8 +119,8 @@ func TestReverseProxy(t *testing.T) { serial.ToTypedMessage(&router.Config{ Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - {Type: router.Domain_Full, Value: "test.example.com"}, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "test.example.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "reverse", @@ -158,7 +159,7 @@ func TestReverseProxy(t *testing.T) { Receiver: &protocol.ServerEndpoint{ Address: net.NewIPOrDomain(net.LocalHostIP), Port: uint32(reversePort), - User: &protocol.User{ + User: &protocol.User{ Account: serial.ToTypedMessage(&vmess.Account{ Id: userID.String(), SecuritySettings: &protocol.SecurityConfig{ @@ -227,8 +228,8 @@ func TestReverseProxyLongRunning(t *testing.T) { serial.ToTypedMessage(&router.Config{ Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - {Type: router.Domain_Full, Value: "test.example.com"}, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "test.example.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "portal", @@ -307,8 +308,8 @@ func TestReverseProxyLongRunning(t *testing.T) { serial.ToTypedMessage(&router.Config{ Rule: []*router.RoutingRule{ { - Domain: []*router.Domain{ - {Type: router.Domain_Full, Value: "test.example.com"}, + Domain: []*geodata.DomainRule{ + {Value: &geodata.DomainRule_Custom{Custom: &geodata.Domain{Type: geodata.Domain_Full, Value: "test.example.com"}}}, }, TargetTag: &router.RoutingRule_Tag{ Tag: "reverse", @@ -347,7 +348,7 @@ func TestReverseProxyLongRunning(t *testing.T) { Receiver: &protocol.ServerEndpoint{ Address: net.NewIPOrDomain(net.LocalHostIP), Port: uint32(reversePort), - User: &protocol.User{ + User: &protocol.User{ Account: serial.ToTypedMessage(&vmess.Account{ Id: userID.String(), SecuritySettings: &protocol.SecurityConfig{