diff --git a/common/geodata/domain_matcher.go b/common/geodata/domain_matcher.go index c9e67702..731e2c16 100644 --- a/common/geodata/domain_matcher.go +++ b/common/geodata/domain_matcher.go @@ -95,48 +95,43 @@ func (f *CompactDomainMatcherFactory) BuildMatcher(rules []*DomainRule) (DomainM matchers: make([]strmatcher.MatcherGroup, 0, len(rules)), values: make([]uint32, 0, len(rules)), } - custom := strmatcher.NewLinearValueMatcher() - var idx uint32 - for _, r := range rules { + for i, r := range rules { switch v := r.Value.(type) { case *DomainRule_Custom: m, err := parseDomain(v.Custom) if err != nil { return nil, err } - custom.Add(m, 0) + if compact.custom == nil { + compact.custom = strmatcher.NewLinearValueMatcher() + } + compact.custom.Add(m, uint32(i)) case *DomainRule_Geosite: m, err := f.getOrCreateFrom(v.Geosite) if err != nil { return nil, err } compact.matchers = append(compact.matchers, m) - compact.values = append(compact.values, idx) - idx++ + compact.values = append(compact.values, uint32(i)) default: panic("unknown domain rule type") } } - if len(compact.matchers) != len(rules) { - compact.matchers = append(compact.matchers, custom) - compact.values = append(compact.values, idx+1) - } return compact, nil } type CompactDomainMatcher struct { + custom strmatcher.ValueMatcher matchers []strmatcher.MatcherGroup values []uint32 } -func (c *CompactDomainMatcher) Add(matcher strmatcher.MatcherGroup, value uint32) { - c.matchers = append(c.matchers, matcher) - c.values = append(c.values, value) -} - // Match implements DomainMatcher. func (c *CompactDomainMatcher) Match(input string) []uint32 { - result := make([]uint32, 0) + var result []uint32 + if c.custom != nil { + result = append(result, c.custom.Match(input)...) + } for i, m := range c.matchers { if m.MatchAny(input) { result = append(result, c.values[i]) @@ -147,6 +142,9 @@ func (c *CompactDomainMatcher) Match(input string) []uint32 { // MatchAny implements DomainMatcher. func (c *CompactDomainMatcher) MatchAny(input string) bool { + if c.custom != nil && c.custom.MatchAny(input) { + return true + } for _, m := range c.matchers { if m.MatchAny(input) { return true diff --git a/common/geodata/domain_matcher_test.go b/common/geodata/domain_matcher_test.go new file mode 100644 index 00000000..775b3371 --- /dev/null +++ b/common/geodata/domain_matcher_test.go @@ -0,0 +1,50 @@ +package geodata + +import ( + "path/filepath" + "reflect" + "slices" + "testing" + + "github.com/xtls/xray-core/common/geodata/strmatcher" +) + +func TestCompactDomainMatcher_PreservesCustomRuleIndices(t *testing.T) { + factory := &CompactDomainMatcherFactory{shared: make(map[string]strmatcher.MatcherGroup)} + matcher, err := factory.BuildMatcher([]*DomainRule{ + {Value: &DomainRule_Custom{Custom: &Domain{Type: Domain_Full, Value: "example.com"}}}, + {Value: &DomainRule_Custom{Custom: &Domain{Type: Domain_Domain, Value: "example.com"}}}, + }) + if err != nil { + t.Fatalf("BuildMatcher() failed: %v", err) + } + + got := matcher.Match("example.com") + slices.Sort(got) + + want := []uint32{0, 1} + if !reflect.DeepEqual(got, want) { + t.Fatalf("Match() = %v, want %v", got, want) + } +} + +func TestCompactDomainMatcher_PreservesMixedRuleIndices(t *testing.T) { + t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources")) + + factory := &CompactDomainMatcherFactory{shared: make(map[string]strmatcher.MatcherGroup)} + matcher, err := factory.BuildMatcher([]*DomainRule{ + {Value: &DomainRule_Geosite{Geosite: &GeoSiteRule{File: DefaultGeoSiteDat, Code: "CN"}}}, + {Value: &DomainRule_Custom{Custom: &Domain{Type: Domain_Full, Value: "163.com"}}}, + }) + if err != nil { + t.Fatalf("BuildMatcher() failed: %v", err) + } + + got := matcher.Match("163.com") + slices.Sort(got) + + want := []uint32{0, 1} + if !reflect.DeepEqual(got, want) { + t.Fatalf("Match() = %v, want %v", got, want) + } +}