Meow
2026-04-14 00:42:29 +08:00
committed by GitHub
parent e9f7d61c2e
commit 82624bcaf0
73 changed files with 5432 additions and 4455 deletions

View File

@@ -0,0 +1,66 @@
package geodata
import (
"context"
"strings"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/geodata/strmatcher"
)
// DomainMatcher is the matcher type returned by buildDomainMatcher for a list
// of domain rules.
type DomainMatcher interface {
// Match returns a list of uint32 values for input.
// NOTE(review): presumably these are the values registered for the matchers
// that hit, i.e. rule indices as assigned by buildDomainMatcher — confirm
// against the strmatcher implementation.
Match(input string) []uint32
// MatchAny returns a boolean for input.
// NOTE(review): presumably true when at least one registered matcher hits —
// confirm against the strmatcher implementation.
MatchAny(input string) bool
}
// buildDomainMatcher compiles rules into a single DomainMatcher.
//
// Every matcher produced for rules[i] is registered under the value uint32(i),
// so a geosite rule contributes many matchers that all share its rule index.
// An invalid custom domain aborts the build with an error, whereas an invalid
// entry inside a geosite list is only logged and skipped.
//
// A nil rule, a rule whose oneof value is unset or of an unknown kind, and a
// geosite rule without a payload are all reported as errors instead of
// crashing the process.
func buildDomainMatcher(rules []*DomainRule) (DomainMatcher, error) {
	g := strmatcher.NewMphValueMatcher()
	for i, r := range rules {
		// GetValue is nil-safe, so a nil rule ends up in the default case
		// rather than dereferencing a nil pointer.
		switch v := r.GetValue().(type) {
		case *DomainRule_Custom:
			m, err := parseDomain(v.Custom)
			if err != nil {
				return nil, err
			}
			g.Add(m, uint32(i))
		case *DomainRule_Geosite:
			site := v.Geosite
			if site == nil {
				return nil, errors.New("geosite rule must not be nil")
			}
			domains, err := loadSiteWithAttrs(site.File, site.Code, site.Attrs)
			if err != nil {
				return nil, err
			}
			for j, d := range domains {
				// Drop the slice's reference to the entry before converting
				// it (peak mem); d still holds it for this iteration.
				domains[j] = nil
				m, err := parseDomain(d)
				if err != nil {
					errors.LogError(context.Background(), "ignore invalid geosite entry in ", site.File, ":", site.Code, " at index ", j, ", ", err)
					continue
				}
				g.Add(m, uint32(i))
			}
		default:
			return nil, errors.New("unknown domain rule type at index ", i)
		}
	}
	if err := g.Build(); err != nil {
		return nil, err
	}
	return g, nil
}
// parseDomain converts a Domain message into the strmatcher.Matcher selected
// by its Type. Substr and Full values are lower-cased before the matcher is
// built; Regex and Domain values are passed through unchanged.
//
// It returns an error for a nil domain, for an unknown Type, and whatever
// error the matcher constructor reports.
func parseDomain(d *Domain) (strmatcher.Matcher, error) {
	if d == nil {
		return nil, errors.New("domain must not be nil")
	}

	switch kind := d.Type; kind {
	case Domain_Full:
		lowered := strings.ToLower(d.Value)
		return strmatcher.Full.New(lowered)
	case Domain_Domain:
		return strmatcher.Domain.New(d.Value)
	case Domain_Regex:
		return strmatcher.Regex.New(d.Value)
	case Domain_Substr:
		lowered := strings.ToLower(d.Value)
		return strmatcher.Substr.New(lowered)
	default:
		return nil, errors.New("unknown domain type: ", kind)
	}
}

View File

@@ -0,0 +1,13 @@
package geodata
// DomainRegistry is a stateless entry point for building domain matchers.
type DomainRegistry struct{}

// BuildDomainMatcher builds a DomainMatcher for rules by delegating to
// buildDomainMatcher.
func (*DomainRegistry) BuildDomainMatcher(rules []*DomainRule) (DomainMatcher, error) {
	matcher, err := buildDomainMatcher(rules)
	return matcher, err
}

// newDomainRegistry returns a fresh, empty DomainRegistry.
func newDomainRegistry() *DomainRegistry {
	return new(DomainRegistry)
}

// DomainReg is the package-level DomainRegistry instance.
var DomainReg = newDomainRegistry()

908
common/geodata/geodat.pb.go Normal file
View File

@@ -0,0 +1,908 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v6.33.5
// source: common/geodata/geodat.proto
package geodata
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Type of domain value.
type Domain_Type int32
const (
// The value is used as a sub string.
Domain_Substr Domain_Type = 0
// The value is used as a regular expression.
Domain_Regex Domain_Type = 1
// The value is a domain.
Domain_Domain Domain_Type = 2
// The value is a full domain.
Domain_Full Domain_Type = 3
)
// Enum value maps for Domain_Type.
var (
Domain_Type_name = map[int32]string{
0: "Substr",
1: "Regex",
2: "Domain",
3: "Full",
}
Domain_Type_value = map[string]int32{
"Substr": 0,
"Regex": 1,
"Domain": 2,
"Full": 3,
}
)
func (x Domain_Type) Enum() *Domain_Type {
p := new(Domain_Type)
*p = x
return p
}
func (x Domain_Type) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Domain_Type) Descriptor() protoreflect.EnumDescriptor {
return file_common_geodata_geodat_proto_enumTypes[0].Descriptor()
}
func (Domain_Type) Type() protoreflect.EnumType {
return &file_common_geodata_geodat_proto_enumTypes[0]
}
func (x Domain_Type) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Domain_Type.Descriptor instead.
func (Domain_Type) EnumDescriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0, 0}
}
type Domain struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Domain matching type.
Type Domain_Type `protobuf:"varint,1,opt,name=type,proto3,enum=xray.common.geodata.Domain_Type" json:"type,omitempty"`
// Domain value.
Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
// Attributes of this domain. May be used for filtering.
Attribute []*Domain_Attribute `protobuf:"bytes,3,rep,name=attribute,proto3" json:"attribute,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Domain) Reset() {
*x = Domain{}
mi := &file_common_geodata_geodat_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Domain) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Domain) ProtoMessage() {}
func (x *Domain) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Domain.ProtoReflect.Descriptor instead.
func (*Domain) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0}
}
func (x *Domain) GetType() Domain_Type {
if x != nil {
return x.Type
}
return Domain_Substr
}
func (x *Domain) GetValue() string {
if x != nil {
return x.Value
}
return ""
}
func (x *Domain) GetAttribute() []*Domain_Attribute {
if x != nil {
return x.Attribute
}
return nil
}
type GeoSite struct {
state protoimpl.MessageState `protogen:"open.v1"`
Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"`
Domain []*Domain `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoSite) Reset() {
*x = GeoSite{}
mi := &file_common_geodata_geodat_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoSite) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoSite) ProtoMessage() {}
func (x *GeoSite) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoSite.ProtoReflect.Descriptor instead.
func (*GeoSite) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{1}
}
func (x *GeoSite) GetCode() string {
if x != nil {
return x.Code
}
return ""
}
func (x *GeoSite) GetDomain() []*Domain {
if x != nil {
return x.Domain
}
return nil
}
type GeoSiteList struct {
state protoimpl.MessageState `protogen:"open.v1"`
Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoSiteList) Reset() {
*x = GeoSiteList{}
mi := &file_common_geodata_geodat_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoSiteList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoSiteList) ProtoMessage() {}
func (x *GeoSiteList) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoSiteList.ProtoReflect.Descriptor instead.
func (*GeoSiteList) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{2}
}
func (x *GeoSiteList) GetEntry() []*GeoSite {
if x != nil {
return x.Entry
}
return nil
}
type GeoSiteRule struct {
state protoimpl.MessageState `protogen:"open.v1"`
File string `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
Attrs string `protobuf:"bytes,3,opt,name=attrs,proto3" json:"attrs,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoSiteRule) Reset() {
*x = GeoSiteRule{}
mi := &file_common_geodata_geodat_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoSiteRule) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoSiteRule) ProtoMessage() {}
func (x *GeoSiteRule) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoSiteRule.ProtoReflect.Descriptor instead.
func (*GeoSiteRule) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{3}
}
func (x *GeoSiteRule) GetFile() string {
if x != nil {
return x.File
}
return ""
}
func (x *GeoSiteRule) GetCode() string {
if x != nil {
return x.Code
}
return ""
}
func (x *GeoSiteRule) GetAttrs() string {
if x != nil {
return x.Attrs
}
return ""
}
type DomainRule struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Value:
//
// *DomainRule_Geosite
// *DomainRule_Custom
Value isDomainRule_Value `protobuf_oneof:"value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *DomainRule) Reset() {
*x = DomainRule{}
mi := &file_common_geodata_geodat_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *DomainRule) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DomainRule) ProtoMessage() {}
func (x *DomainRule) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DomainRule.ProtoReflect.Descriptor instead.
func (*DomainRule) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{4}
}
func (x *DomainRule) GetValue() isDomainRule_Value {
if x != nil {
return x.Value
}
return nil
}
func (x *DomainRule) GetGeosite() *GeoSiteRule {
if x != nil {
if x, ok := x.Value.(*DomainRule_Geosite); ok {
return x.Geosite
}
}
return nil
}
func (x *DomainRule) GetCustom() *Domain {
if x != nil {
if x, ok := x.Value.(*DomainRule_Custom); ok {
return x.Custom
}
}
return nil
}
type isDomainRule_Value interface {
isDomainRule_Value()
}
type DomainRule_Geosite struct {
Geosite *GeoSiteRule `protobuf:"bytes,1,opt,name=geosite,proto3,oneof"`
}
type DomainRule_Custom struct {
Custom *Domain `protobuf:"bytes,2,opt,name=custom,proto3,oneof"`
}
func (*DomainRule_Geosite) isDomainRule_Value() {}
func (*DomainRule_Custom) isDomainRule_Value() {}
type CIDR struct {
state protoimpl.MessageState `protogen:"open.v1"`
// IP address, should be either 4 or 16 bytes.
Ip []byte `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"`
// Number of leading ones in the network mask.
Prefix uint32 `protobuf:"varint,2,opt,name=prefix,proto3" json:"prefix,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *CIDR) Reset() {
*x = CIDR{}
mi := &file_common_geodata_geodat_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *CIDR) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CIDR) ProtoMessage() {}
func (x *CIDR) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CIDR.ProtoReflect.Descriptor instead.
func (*CIDR) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{5}
}
func (x *CIDR) GetIp() []byte {
if x != nil {
return x.Ip
}
return nil
}
func (x *CIDR) GetPrefix() uint32 {
if x != nil {
return x.Prefix
}
return 0
}
type GeoIP struct {
state protoimpl.MessageState `protogen:"open.v1"`
Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"`
Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr,proto3" json:"cidr,omitempty"`
ReverseMatch bool `protobuf:"varint,3,opt,name=reverse_match,json=reverseMatch,proto3" json:"reverse_match,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoIP) Reset() {
*x = GeoIP{}
mi := &file_common_geodata_geodat_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoIP) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoIP) ProtoMessage() {}
func (x *GeoIP) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoIP.ProtoReflect.Descriptor instead.
func (*GeoIP) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{6}
}
func (x *GeoIP) GetCode() string {
if x != nil {
return x.Code
}
return ""
}
func (x *GeoIP) GetCidr() []*CIDR {
if x != nil {
return x.Cidr
}
return nil
}
func (x *GeoIP) GetReverseMatch() bool {
if x != nil {
return x.ReverseMatch
}
return false
}
type GeoIPList struct {
state protoimpl.MessageState `protogen:"open.v1"`
Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoIPList) Reset() {
*x = GeoIPList{}
mi := &file_common_geodata_geodat_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoIPList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoIPList) ProtoMessage() {}
func (x *GeoIPList) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoIPList.ProtoReflect.Descriptor instead.
func (*GeoIPList) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{7}
}
func (x *GeoIPList) GetEntry() []*GeoIP {
if x != nil {
return x.Entry
}
return nil
}
type GeoIPRule struct {
state protoimpl.MessageState `protogen:"open.v1"`
File string `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
ReverseMatch bool `protobuf:"varint,3,opt,name=reverse_match,json=reverseMatch,proto3" json:"reverse_match,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GeoIPRule) Reset() {
*x = GeoIPRule{}
mi := &file_common_geodata_geodat_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GeoIPRule) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoIPRule) ProtoMessage() {}
func (x *GeoIPRule) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[8]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoIPRule.ProtoReflect.Descriptor instead.
func (*GeoIPRule) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{8}
}
func (x *GeoIPRule) GetFile() string {
if x != nil {
return x.File
}
return ""
}
func (x *GeoIPRule) GetCode() string {
if x != nil {
return x.Code
}
return ""
}
func (x *GeoIPRule) GetReverseMatch() bool {
if x != nil {
return x.ReverseMatch
}
return false
}
type IPRule struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Value:
//
// *IPRule_Geoip
// *IPRule_Custom
Value isIPRule_Value `protobuf_oneof:"value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *IPRule) Reset() {
*x = IPRule{}
mi := &file_common_geodata_geodat_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *IPRule) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*IPRule) ProtoMessage() {}
func (x *IPRule) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[9]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IPRule.ProtoReflect.Descriptor instead.
func (*IPRule) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{9}
}
func (x *IPRule) GetValue() isIPRule_Value {
if x != nil {
return x.Value
}
return nil
}
func (x *IPRule) GetGeoip() *GeoIPRule {
if x != nil {
if x, ok := x.Value.(*IPRule_Geoip); ok {
return x.Geoip
}
}
return nil
}
func (x *IPRule) GetCustom() *CIDR {
if x != nil {
if x, ok := x.Value.(*IPRule_Custom); ok {
return x.Custom
}
}
return nil
}
type isIPRule_Value interface {
isIPRule_Value()
}
type IPRule_Geoip struct {
Geoip *GeoIPRule `protobuf:"bytes,1,opt,name=geoip,proto3,oneof"`
}
type IPRule_Custom struct {
Custom *CIDR `protobuf:"bytes,2,opt,name=custom,proto3,oneof"`
}
func (*IPRule_Geoip) isIPRule_Value() {}
func (*IPRule_Custom) isIPRule_Value() {}
type Domain_Attribute struct {
state protoimpl.MessageState `protogen:"open.v1"`
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// Types that are valid to be assigned to TypedValue:
//
// *Domain_Attribute_BoolValue
// *Domain_Attribute_IntValue
TypedValue isDomain_Attribute_TypedValue `protobuf_oneof:"typed_value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Domain_Attribute) Reset() {
*x = Domain_Attribute{}
mi := &file_common_geodata_geodat_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Domain_Attribute) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Domain_Attribute) ProtoMessage() {}
func (x *Domain_Attribute) ProtoReflect() protoreflect.Message {
mi := &file_common_geodata_geodat_proto_msgTypes[10]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Domain_Attribute.ProtoReflect.Descriptor instead.
func (*Domain_Attribute) Descriptor() ([]byte, []int) {
return file_common_geodata_geodat_proto_rawDescGZIP(), []int{0, 0}
}
func (x *Domain_Attribute) GetKey() string {
if x != nil {
return x.Key
}
return ""
}
func (x *Domain_Attribute) GetTypedValue() isDomain_Attribute_TypedValue {
if x != nil {
return x.TypedValue
}
return nil
}
func (x *Domain_Attribute) GetBoolValue() bool {
if x != nil {
if x, ok := x.TypedValue.(*Domain_Attribute_BoolValue); ok {
return x.BoolValue
}
}
return false
}
func (x *Domain_Attribute) GetIntValue() int64 {
if x != nil {
if x, ok := x.TypedValue.(*Domain_Attribute_IntValue); ok {
return x.IntValue
}
}
return 0
}
type isDomain_Attribute_TypedValue interface {
isDomain_Attribute_TypedValue()
}
type Domain_Attribute_BoolValue struct {
BoolValue bool `protobuf:"varint,2,opt,name=bool_value,json=boolValue,proto3,oneof"`
}
type Domain_Attribute_IntValue struct {
IntValue int64 `protobuf:"varint,3,opt,name=int_value,json=intValue,proto3,oneof"`
}
func (*Domain_Attribute_BoolValue) isDomain_Attribute_TypedValue() {}
func (*Domain_Attribute_IntValue) isDomain_Attribute_TypedValue() {}
var File_common_geodata_geodat_proto protoreflect.FileDescriptor
const file_common_geodata_geodat_proto_rawDesc = "" +
"\n" +
"\x1bcommon/geodata/geodat.proto\x12\x13xray.common.geodata\"\xbc\x02\n" +
"\x06Domain\x124\n" +
"\x04type\x18\x01 \x01(\x0e2 .xray.common.geodata.Domain.TypeR\x04type\x12\x14\n" +
"\x05value\x18\x02 \x01(\tR\x05value\x12C\n" +
"\tattribute\x18\x03 \x03(\v2%.xray.common.geodata.Domain.AttributeR\tattribute\x1al\n" +
"\tAttribute\x12\x10\n" +
"\x03key\x18\x01 \x01(\tR\x03key\x12\x1f\n" +
"\n" +
"bool_value\x18\x02 \x01(\bH\x00R\tboolValue\x12\x1d\n" +
"\tint_value\x18\x03 \x01(\x03H\x00R\bintValueB\r\n" +
"\vtyped_value\"3\n" +
"\x04Type\x12\n" +
"\n" +
"\x06Substr\x10\x00\x12\t\n" +
"\x05Regex\x10\x01\x12\n" +
"\n" +
"\x06Domain\x10\x02\x12\b\n" +
"\x04Full\x10\x03\"R\n" +
"\aGeoSite\x12\x12\n" +
"\x04code\x18\x01 \x01(\tR\x04code\x123\n" +
"\x06domain\x18\x02 \x03(\v2\x1b.xray.common.geodata.DomainR\x06domain\"A\n" +
"\vGeoSiteList\x122\n" +
"\x05entry\x18\x01 \x03(\v2\x1c.xray.common.geodata.GeoSiteR\x05entry\"K\n" +
"\vGeoSiteRule\x12\x12\n" +
"\x04file\x18\x01 \x01(\tR\x04file\x12\x12\n" +
"\x04code\x18\x02 \x01(\tR\x04code\x12\x14\n" +
"\x05attrs\x18\x03 \x01(\tR\x05attrs\"\x8a\x01\n" +
"\n" +
"DomainRule\x12<\n" +
"\ageosite\x18\x01 \x01(\v2 .xray.common.geodata.GeoSiteRuleH\x00R\ageosite\x125\n" +
"\x06custom\x18\x02 \x01(\v2\x1b.xray.common.geodata.DomainH\x00R\x06customB\a\n" +
"\x05value\".\n" +
"\x04CIDR\x12\x0e\n" +
"\x02ip\x18\x01 \x01(\fR\x02ip\x12\x16\n" +
"\x06prefix\x18\x02 \x01(\rR\x06prefix\"o\n" +
"\x05GeoIP\x12\x12\n" +
"\x04code\x18\x01 \x01(\tR\x04code\x12-\n" +
"\x04cidr\x18\x02 \x03(\v2\x19.xray.common.geodata.CIDRR\x04cidr\x12#\n" +
"\rreverse_match\x18\x03 \x01(\bR\freverseMatch\"=\n" +
"\tGeoIPList\x120\n" +
"\x05entry\x18\x01 \x03(\v2\x1a.xray.common.geodata.GeoIPR\x05entry\"X\n" +
"\tGeoIPRule\x12\x12\n" +
"\x04file\x18\x01 \x01(\tR\x04file\x12\x12\n" +
"\x04code\x18\x02 \x01(\tR\x04code\x12#\n" +
"\rreverse_match\x18\x03 \x01(\bR\freverseMatch\"~\n" +
"\x06IPRule\x126\n" +
"\x05geoip\x18\x01 \x01(\v2\x1e.xray.common.geodata.GeoIPRuleH\x00R\x05geoip\x123\n" +
"\x06custom\x18\x02 \x01(\v2\x19.xray.common.geodata.CIDRH\x00R\x06customB\a\n" +
"\x05valueB[\n" +
"\x17com.xray.common.geodataP\x01Z(github.com/xtls/xray-core/common/geodata\xaa\x02\x13Xray.Common.Geodatab\x06proto3"
var (
file_common_geodata_geodat_proto_rawDescOnce sync.Once
file_common_geodata_geodat_proto_rawDescData []byte
)
func file_common_geodata_geodat_proto_rawDescGZIP() []byte {
file_common_geodata_geodat_proto_rawDescOnce.Do(func() {
file_common_geodata_geodat_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_common_geodata_geodat_proto_rawDesc), len(file_common_geodata_geodat_proto_rawDesc)))
})
return file_common_geodata_geodat_proto_rawDescData
}
var file_common_geodata_geodat_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_common_geodata_geodat_proto_msgTypes = make([]protoimpl.MessageInfo, 11)
var file_common_geodata_geodat_proto_goTypes = []any{
(Domain_Type)(0), // 0: xray.common.geodata.Domain.Type
(*Domain)(nil), // 1: xray.common.geodata.Domain
(*GeoSite)(nil), // 2: xray.common.geodata.GeoSite
(*GeoSiteList)(nil), // 3: xray.common.geodata.GeoSiteList
(*GeoSiteRule)(nil), // 4: xray.common.geodata.GeoSiteRule
(*DomainRule)(nil), // 5: xray.common.geodata.DomainRule
(*CIDR)(nil), // 6: xray.common.geodata.CIDR
(*GeoIP)(nil), // 7: xray.common.geodata.GeoIP
(*GeoIPList)(nil), // 8: xray.common.geodata.GeoIPList
(*GeoIPRule)(nil), // 9: xray.common.geodata.GeoIPRule
(*IPRule)(nil), // 10: xray.common.geodata.IPRule
(*Domain_Attribute)(nil), // 11: xray.common.geodata.Domain.Attribute
}
var file_common_geodata_geodat_proto_depIdxs = []int32{
0, // 0: xray.common.geodata.Domain.type:type_name -> xray.common.geodata.Domain.Type
11, // 1: xray.common.geodata.Domain.attribute:type_name -> xray.common.geodata.Domain.Attribute
1, // 2: xray.common.geodata.GeoSite.domain:type_name -> xray.common.geodata.Domain
2, // 3: xray.common.geodata.GeoSiteList.entry:type_name -> xray.common.geodata.GeoSite
4, // 4: xray.common.geodata.DomainRule.geosite:type_name -> xray.common.geodata.GeoSiteRule
1, // 5: xray.common.geodata.DomainRule.custom:type_name -> xray.common.geodata.Domain
6, // 6: xray.common.geodata.GeoIP.cidr:type_name -> xray.common.geodata.CIDR
7, // 7: xray.common.geodata.GeoIPList.entry:type_name -> xray.common.geodata.GeoIP
9, // 8: xray.common.geodata.IPRule.geoip:type_name -> xray.common.geodata.GeoIPRule
6, // 9: xray.common.geodata.IPRule.custom:type_name -> xray.common.geodata.CIDR
10, // [10:10] is the sub-list for method output_type
10, // [10:10] is the sub-list for method input_type
10, // [10:10] is the sub-list for extension type_name
10, // [10:10] is the sub-list for extension extendee
0, // [0:10] is the sub-list for field type_name
}
func init() { file_common_geodata_geodat_proto_init() }
func file_common_geodata_geodat_proto_init() {
if File_common_geodata_geodat_proto != nil {
return
}
file_common_geodata_geodat_proto_msgTypes[4].OneofWrappers = []any{
(*DomainRule_Geosite)(nil),
(*DomainRule_Custom)(nil),
}
file_common_geodata_geodat_proto_msgTypes[9].OneofWrappers = []any{
(*IPRule_Geoip)(nil),
(*IPRule_Custom)(nil),
}
file_common_geodata_geodat_proto_msgTypes[10].OneofWrappers = []any{
(*Domain_Attribute_BoolValue)(nil),
(*Domain_Attribute_IntValue)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_common_geodata_geodat_proto_rawDesc), len(file_common_geodata_geodat_proto_rawDesc)),
NumEnums: 1,
NumMessages: 11,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_geodata_geodat_proto_goTypes,
DependencyIndexes: file_common_geodata_geodat_proto_depIdxs,
EnumInfos: file_common_geodata_geodat_proto_enumTypes,
MessageInfos: file_common_geodata_geodat_proto_msgTypes,
}.Build()
File_common_geodata_geodat_proto = out.File
file_common_geodata_geodat_proto_goTypes = nil
file_common_geodata_geodat_proto_depIdxs = nil
}

View File

@@ -0,0 +1,90 @@
syntax = "proto3";
package xray.common.geodata;
option csharp_namespace = "Xray.Common.Geodata";
option go_package = "github.com/xtls/xray-core/common/geodata";
option java_package = "com.xray.common.geodata";
option java_multiple_files = true;
// Domain is a single domain matching rule: a matching type, the value to
// match with, and optional attributes.
message Domain {
// Type of domain value.
enum Type {
// The value is used as a sub string.
Substr = 0;
// The value is used as a regular expression.
Regex = 1;
// The value is a domain.
Domain = 2;
// The value is a full domain.
Full = 3;
}
// Domain matching type.
Type type = 1;
// Domain value.
string value = 2;
// Attribute is a keyed tag attached to a domain, optionally carrying a
// boolean or integer value.
message Attribute {
// Attribute name.
string key = 1;
// Optional typed value of the attribute.
oneof typed_value {
bool bool_value = 2;
int64 int_value = 3;
}
}
// Attributes of this domain. May be used for filtering.
repeated Attribute attribute = 3;
}
// GeoSite is a named list of domains.
message GeoSite {
// Code identifying this list.
string code = 1;
// Domains that belong to this list.
repeated Domain domain = 2;
}
// GeoSiteList is a collection of GeoSite entries.
message GeoSiteList {
// The GeoSite entries of this collection.
repeated GeoSite entry = 1;
}
// GeoSiteRule refers to one GeoSite entry stored in an asset file.
message GeoSiteRule {
// Name of the asset file to load the entry from.
string file = 1;
// Code of the entry to load from the file.
string code = 2;
// Attribute keys separated by "@". When non-empty, only domains that carry
// every listed attribute are used; when empty, no filtering is applied.
string attrs = 3;
}
// DomainRule is one domain rule: either a reference to a GeoSite entry or a
// single custom Domain.
message DomainRule {
oneof value {
// Reference to a GeoSite entry in an asset file.
GeoSiteRule geosite = 1;
// A single, directly specified domain.
Domain custom = 2;
}
}
// CIDR is an IP network given as an address and a prefix length.
message CIDR {
// IP address, should be either 4 or 16 bytes.
bytes ip = 1;
// Number of leading ones in the network mask.
uint32 prefix = 2;
}
// GeoIP is a named list of CIDR networks.
message GeoIP {
// Code identifying this list.
string code = 1;
// Networks that belong to this list.
repeated CIDR cidr = 2;
// NOTE(review): presumably inverts the match result for this list — confirm
// against the consumers of this field.
bool reverse_match = 3;
}
// GeoIPList is a collection of GeoIP entries.
message GeoIPList {
// The GeoIP entries of this collection.
repeated GeoIP entry = 1;
}
// GeoIPRule refers to one GeoIP entry stored in an asset file.
message GeoIPRule {
// Name of the asset file to load the entry from.
string file = 1;
// Code of the entry to load from the file.
string code = 2;
// NOTE(review): presumably inverts the match result of the referenced
// entry — confirm against the consumers of this field.
bool reverse_match = 3;
}
// IPRule is one IP rule: either a reference to a GeoIP entry or a single
// custom CIDR.
message IPRule {
oneof value {
// Reference to a GeoIP entry in an asset file.
GeoIPRule geoip = 1;
// A single, directly specified network.
CIDR custom = 2;
}
}

View File

@@ -0,0 +1,207 @@
package geodata
import (
"bufio"
"bytes"
"io"
"runtime"
"strings"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
)
// checkFile verifies that the asset named file can be opened and that find
// locates an entry for code in it. The entry body itself is not read.
func checkFile(file, code string) error {
	asset, err := filesystem.OpenAsset(file)
	if err != nil {
		return errors.New("failed to open ", file).Base(err)
	}
	defer asset.Close()

	_, err = find(asset, []byte(code), false)
	if err == nil {
		return nil
	}
	return errors.New("failed to check code ", code, " from ", file).Base(err)
}
// loadFile opens the asset named file and returns the raw bytes of the entry
// that find locates for code. A garbage collection is forced before the asset
// is opened (peak mem).
func loadFile(file, code string) ([]byte, error) {
	runtime.GC() // peak mem

	asset, openErr := filesystem.OpenAsset(file)
	if openErr != nil {
		return nil, errors.New("failed to open ", file).Base(openErr)
	}
	defer asset.Close()

	body, findErr := find(asset, []byte(code), true)
	if findErr != nil {
		return nil, errors.New("failed to load code ", code, " from ", file).Base(findErr)
	}
	return body, nil
}
// loadIP loads the entry for code from the asset named file, decodes it as a
// GeoIP message and returns its CIDR list. A garbage collection is forced on
// return once the raw bytes have been loaded (peak mem).
func loadIP(file, code string) ([]*CIDR, error) {
	raw, err := loadFile(file, code)
	if err != nil {
		return nil, err
	}
	defer runtime.GC() // peak mem

	entry := new(GeoIP)
	if err = proto.Unmarshal(raw, entry); err != nil {
		return nil, errors.New("error unmarshal IP in ", file, ":", code).Base(err)
	}
	return entry.Cidr, nil
}
// loadSite loads the entry for code from the asset named file, decodes it as a
// GeoSite message and returns its domain list. A garbage collection is forced
// on return once the raw bytes have been loaded (peak mem).
func loadSite(file, code string) ([]*Domain, error) {
	raw, err := loadFile(file, code)
	if err != nil {
		return nil, err
	}
	defer runtime.GC() // peak mem

	entry := new(GeoSite)
	if err = proto.Unmarshal(raw, entry); err != nil {
		return nil, errors.New("error unmarshal Site in ", file, ":", code).Base(err)
	}
	return entry.Domain, nil
}
// decodeVarint reads one base-128 varint from br, least significant group
// first, and returns its value.
//
// It returns the reader's error if the input ends before the varint is
// complete, and a "varint overflow" error if the encoding does not fit in 64
// bits: either it runs past ten bytes, or its tenth byte carries more than the
// single bit that is still free.
func decodeVarint(br *bufio.Reader) (uint64, error) {
	var x uint64
	for shift := uint(0); shift < 64; shift += 7 {
		b, err := br.ReadByte()
		if err != nil {
			return 0, err
		}
		// At shift 63 only bit 63 is left. Any larger payload, or a
		// continuation flag, would be silently cut off by the shift below,
		// so it is rejected here.
		if shift == 63 && b > 1 {
			return 0, errors.New("varint overflow")
		}
		x |= (uint64(b) & 0x7F) << shift
		if (b & 0x80) == 0 {
			return x, nil
		}
	}
	// The number is too large to represent in a 64-bit value.
	return 0, errors.New("varint overflow")
}
// find scans r for the record that starts with code and returns its body.
//
// r is read as a sequence of records, each laid out as one tag byte (read and
// ignored), a varint body length, and the body itself. A record matches when
// the second byte of its body equals len(code) and the bytes that follow equal
// code; the first body byte is not inspected. Because that length is compared
// against a single byte, a code longer than 255 bytes can never match (see the
// TODO below).
//
// When readBody is false, find returns (nil, nil) as soon as a match is seen.
// When readBody is true, it returns a copy of the whole matching body; the
// loaders in this file decode that body as a GeoIP or GeoSite message.
//
// Errors are returned for an empty code, a non-positive body length, and any
// read or varint failure. If the input ends cleanly at a record boundary
// without a match, a "code not found" error is returned instead of a bare
// io.EOF, so a missing code is distinguishable from a read problem.
func find(r io.Reader, code []byte, readBody bool) ([]byte, error) {
	codeL := len(code)
	if codeL == 0 {
		return nil, errors.New("empty code")
	}

	br := bufio.NewReaderSize(r, 64*1024)
	need := 2 + codeL // TODO: if code too long
	prefixBuf := make([]byte, need)

	for {
		// Record tag byte; its value is not checked.
		if _, err := br.ReadByte(); err != nil {
			if err == io.EOF {
				// No more records: every entry was examined without a match.
				return nil, errors.New("code not found")
			}
			return nil, err
		}

		x, err := decodeVarint(br)
		if err != nil {
			return nil, err
		}
		bodyL := int(x)
		if bodyL <= 0 {
			return nil, errors.New("invalid body length: ", bodyL)
		}

		// Read only as much of the body as is needed to compare the code.
		prefixL := bodyL
		if prefixL > need {
			prefixL = need
		}
		prefix := prefixBuf[:prefixL]
		if _, err := io.ReadFull(br, prefix); err != nil {
			return nil, err
		}

		match := false
		if bodyL >= need {
			if int(prefix[1]) == codeL && bytes.Equal(prefix[2:need], code) {
				if !readBody {
					return nil, nil
				}
				match = true
			}
		}

		remain := bodyL - prefixL
		if match {
			// Reassemble the full body: the prefix already consumed plus
			// whatever is left in the stream.
			out := make([]byte, bodyL)
			copy(out, prefix)
			if remain > 0 {
				if _, err := io.ReadFull(br, out[prefixL:]); err != nil {
					return nil, err
				}
			}
			return out, nil
		}

		// Not the record we want: skip the rest of its body.
		if remain > 0 {
			if _, err := br.Discard(remain); err != nil {
				return nil, err
			}
		}
	}
}
// AttributeMatcher decides whether a Domain passes an attribute-based filter.
type AttributeMatcher interface {
// Match reports whether the given domain satisfies this matcher.
Match(*Domain) bool
}
// HasAttrMatcher is an AttributeMatcher that looks for one attribute key; the
// string value of the matcher is the key to look for.
type HasAttrMatcher string

// Match reports whether domain carries an attribute whose key equals m.
// The comparison is an exact string comparison.
func (m HasAttrMatcher) Match(domain *Domain) bool {
	want := string(m)
	attrs := domain.Attribute
	for i := range attrs {
		if attrs[i].Key == want {
			return true
		}
	}
	return false
}
// AllAttrsMatcher is an AttributeMatcher that combines several matchers and
// requires all of them to match.
type AllAttrsMatcher struct {
	matchers []AttributeMatcher
}

// Match reports whether domain is accepted by every matcher held by m.
// It stops at the first matcher that rejects the domain, and returns true
// when m holds no matchers.
func (m *AllAttrsMatcher) Match(domain *Domain) bool {
	for i := 0; i < len(m.matchers); i++ {
		if m.matchers[i].Match(domain) {
			continue
		}
		return false
	}
	return true
}
// NewAllAttrsMatcher builds a matcher that requires every attribute key listed
// in attrs, a string of keys separated by "@".
//
// Empty segments, as produced by a leading, trailing or doubled "@", are
// skipped; otherwise each would become a matcher for the empty key and
// silently reject every domain. When attrs names no key at all, the function
// returns a nil AttributeMatcher, which callers treat as "no filtering".
func NewAllAttrsMatcher(attrs string) AttributeMatcher {
	if attrs == "" {
		return nil
	}
	parts := strings.Split(attrs, "@")
	matchers := make([]AttributeMatcher, 0, len(parts))
	for _, attr := range parts {
		if attr == "" {
			continue
		}
		matchers = append(matchers, HasAttrMatcher(attr))
	}
	if len(matchers) == 0 {
		// Return an untyped nil so that a `== nil` check on the interface
		// value holds.
		return nil
	}
	return &AllAttrsMatcher{matchers: matchers}
}
// loadSiteWithAttrs loads the domains of the entry code from the asset named
// file and keeps only those accepted by the attribute filter built from
// attrs. When NewAllAttrsMatcher yields no filter, the loaded list is
// returned as is.
func loadSiteWithAttrs(file, code, attrs string) ([]*Domain, error) {
	all, err := loadSite(file, code)
	if err != nil {
		return nil, err
	}

	filter := NewAllAttrsMatcher(attrs)
	if filter == nil {
		return all, nil
	}

	kept := make([]*Domain, 0, len(all))
	for i := range all {
		if d := all[i]; filter.Match(d) {
			kept = append(kept, d)
		}
	}
	return kept, nil
}

View File

@@ -0,0 +1,996 @@
package geodata
import (
"context"
"net/netip"
"runtime"
"slices"
"sort"
"strings"
"sync"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"go4.org/netipx"
)
// IPMatcher tests IP addresses against a set of IP rules.
type IPMatcher interface {
// TODO: (PERF) all net.IP -> netip.Addr
// Invalid IP always return false.
Match(ip net.IP) bool
// Returns true if *any* IP is valid and match.
AnyMatch(ips []net.IP) bool
// Returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false.
Matches(ips []net.IP) bool
// Filters IPs. Invalid IPs are silently dropped and not included in either result.
FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)
// NOTE(review): presumably flips the matcher's reverse setting — confirm
// against the implementations.
ToggleReverse()
// NOTE(review): presumably sets the matcher's reverse setting to the given
// value — confirm against the implementations.
SetReverse(reverse bool)
}
// IPSet holds one netipx.IPSet per address family together with a per-family
// prefix-length summary.
type IPSet struct {
// ipv4 and ipv6 are the address sets consulted by matchAddr.
ipv4, ipv6 *netipx.IPSet
// max4 and max6 hold the largest prefix length found in the corresponding
// set. The value 0xff is a sentinel: matchAddr rejects every address of a
// family whose max is 0xff. Values <= 24 (IPv4) or <= 64 (IPv6) enable the
// per-/24 and per-/64 bucketing used by the heuristic matchers.
max4, max6 uint8
}
// HeuristicIPMatcher matches addresses against a single IPSet, optionally
// with the result negated.
type HeuristicIPMatcher struct {
// ipset is the set of prefixes being matched.
ipset *IPSet
// reverse negates the containment result in matchAddr.
reverse bool
}
// ipBucket groups the input IPs that share one prefix key (see
// prefixKeyFromIP).
type ipBucket struct {
// rep is the representative address that is actually tested for the bucket.
rep netip.Addr
// ips are all input IPs that fell into the bucket.
ips []net.IP
}
// Match implements IPMatcher. It converts ip with netipx.FromStdIP and
// delegates to matchAddr; an IP that cannot be converted never matches.
func (m *HeuristicIPMatcher) Match(ip net.IP) bool {
	if addr, valid := netipx.FromStdIP(ip); valid {
		return m.matchAddr(addr)
	}
	return false
}
// matchAddr reports whether ipx is accepted by this matcher. The address is
// looked up in the set of its own family and the result is negated when
// reverse is set. A family whose max prefix length is the 0xff sentinel never
// matches, regardless of reverse. Addresses that are neither IPv4 nor IPv6
// never match.
func (m *HeuristicIPMatcher) matchAddr(ipx netip.Addr) bool {
	set := m.ipset
	switch {
	case ipx.Is4():
		return set.max4 != 0xff && set.ipv4.Contains(ipx) != m.reverse
	case ipx.Is6():
		return set.max6 != 0xff && set.ipv6.Contains(ipx) != m.reverse
	default:
		return false
	}
}
// AnyMatch implements IPMatcher. It returns true as soon as one valid IP
// matches; invalid IPs are skipped.
//
// When every prefix of a family is at most /24 (IPv4) or /64 (IPv6), all
// addresses sharing that /24 or /64 get the same verdict, so only the first
// address of each such bucket is tested and later ones are skipped.
func (m *HeuristicIPMatcher) AnyMatch(ips []net.IP) bool {
	switch len(ips) {
	case 0:
		return false
	case 1:
		return m.Match(ips[0])
	}

	coarse4, coarse6 := m.ipset.max4 <= 24, m.ipset.max6 <= 64
	if !(coarse4 || coarse6) {
		// No family qualifies for bucketing: test every valid address.
		for _, ip := range ips {
			addr, valid := netipx.FromStdIP(ip)
			if valid && m.matchAddr(addr) {
				return true
			}
		}
		return false
	}

	// rejected records buckets whose representative has already failed.
	rejected := make(map[[9]byte]struct{}, len(ips))
	for _, ip := range ips {
		key, valid := prefixKeyFromIP(ip)
		if !valid {
			continue
		}
		coarse := (coarse4 && key[0] == 4) || (coarse6 && key[0] == 6)
		if coarse {
			if _, seen := rejected[key]; seen {
				continue
			}
		}
		addr, valid := netipx.FromStdIP(ip)
		if !valid {
			continue
		}
		if m.matchAddr(addr) {
			return true
		}
		if coarse {
			rejected[key] = struct{}{}
		}
	}
	return false
}
// Matches implements IPMatcher. It returns true only when the list is
// non-empty, every IP is valid, and every IP matches.
//
// For a family whose prefixes are all at most /24 (IPv4) or /64 (IPv6), one
// address per /24 or /64 bucket is tested on behalf of the whole bucket.
func (m *HeuristicIPMatcher) Matches(ips []net.IP) bool {
	switch len(ips) {
	case 0:
		return false
	case 1:
		return m.Match(ips[0])
	}

	coarse4, coarse6 := m.ipset.max4 <= 24, m.ipset.max6 <= 64
	if !(coarse4 || coarse6) {
		for _, ip := range ips {
			addr, valid := netipx.FromStdIP(ip)
			if !valid || !m.matchAddr(addr) {
				return false
			}
		}
		return true
	}

	// accepted records buckets whose representative has already matched.
	accepted := make(map[[9]byte]struct{}, len(ips))
	for _, ip := range ips {
		key, valid := prefixKeyFromIP(ip)
		if !valid {
			return false
		}
		coarse := (coarse4 && key[0] == 4) || (coarse6 && key[0] == 6)
		if coarse {
			if _, seen := accepted[key]; seen {
				continue
			}
		}
		addr, valid := netipx.FromStdIP(ip)
		if !valid {
			return false
		}
		if !m.matchAddr(addr) {
			return false
		}
		if coarse {
			accepted[key] = struct{}{}
		}
	}
	return true
}
// prefixKeyFromIP derives the bucket key of ip. Byte 0 is the family tag (4
// or 6); for IPv4 the next three bytes are the /24 prefix, for IPv6 the next
// eight bytes are the /64 prefix. Unused bytes stay zero. ok is false when ip
// is neither a valid IPv4 nor a valid IPv6 address.
func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
	if v4 := ip.To4(); v4 != nil {
		key[0] = 4
		copy(key[1:4], v4[:3]) // /24
		return key, true
	}
	if v6 := ip.To16(); v6 != nil {
		key[0] = 6
		copy(key[1:], v6[:8]) // /64
		return key, true
	}
	return key, false // illegal
}
// FilterIPs implements IPMatcher. It splits ips into those that match and
// those that do not; invalid IPs appear in neither result. Both results are
// always non-nil.
//
// For a family whose prefixes are all at most /24 (IPv4) or /64 (IPv6), the
// IPs are grouped per /24 or /64 and one representative decides for the whole
// group. Grouped results are emitted first, in map iteration order, followed
// by the individually tested IPs in input order.
func (m *HeuristicIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
	total := len(ips)
	switch total {
	case 0:
		return []net.IP{}, []net.IP{}
	case 1:
		addr, valid := netipx.FromStdIP(ips[0])
		switch {
		case !valid:
			return []net.IP{}, []net.IP{}
		case m.matchAddr(addr):
			return ips, []net.IP{}
		default:
			return []net.IP{}, ips
		}
	}

	coarse4, coarse6 := m.ipset.max4 <= 24, m.ipset.max6 <= 64
	matched = make([]net.IP, 0, total)
	unmatched = make([]net.IP, 0, total)
	// place appends group to matched or unmatched depending on hit.
	place := func(hit bool, group ...net.IP) {
		if hit {
			matched = append(matched, group...)
		} else {
			unmatched = append(unmatched, group...)
		}
	}

	if !coarse4 && !coarse6 {
		for _, ip := range ips {
			if addr, valid := netipx.FromStdIP(ip); valid {
				place(m.matchAddr(addr), ip)
			}
			// illegal ip, ignore
		}
		return matched, unmatched
	}

	groups := make(map[[9]byte]*ipBucket, total)
	exact := make([]net.IP, 0, total)
	for _, ip := range ips {
		key, valid := prefixKeyFromIP(ip)
		if !valid {
			continue // illegal ip, ignore
		}
		if (key[0] == 4 && !coarse4) || (key[0] == 6 && !coarse6) {
			exact = append(exact, ip)
			continue
		}
		group, seen := groups[key]
		if !seen {
			rep, valid := netipx.FromStdIP(ip)
			if !valid {
				continue // illegal ip, ignore
			}
			group = &ipBucket{
				rep: rep,
				ips: make([]net.IP, 0, 4), // for dns answer
			}
			groups[key] = group
		}
		group.ips = append(group.ips, ip)
	}

	for _, group := range groups {
		place(m.matchAddr(group.rep), group.ips...)
	}
	for _, ip := range exact {
		if addr, valid := netipx.FromStdIP(ip); valid {
			place(m.matchAddr(addr), ip)
		}
	}
	return matched, unmatched
}
// ToggleReverse implements IPMatcher. It flips the reverse flag, which
// matchAddr uses to negate the containment result.
func (m *HeuristicIPMatcher) ToggleReverse() {
m.reverse = !m.reverse
}
// SetReverse implements IPMatcher. It sets the reverse flag to the given
// value.
func (m *HeuristicIPMatcher) SetReverse(reverse bool) {
m.reverse = reverse
}
// GeneralMultiIPMatcher combines arbitrary IPMatcher implementations and
// consults them in order.
type GeneralMultiIPMatcher struct {
// matchers are the sub-matchers, consulted in slice order.
matchers []IPMatcher
}
// Match implements IPMatcher. It reports whether at least one sub-matcher
// matches ip.
func (mm *GeneralMultiIPMatcher) Match(ip net.IP) bool {
	return slices.ContainsFunc(mm.matchers, func(sub IPMatcher) bool {
		return sub.Match(ip)
	})
}
// AnyMatch implements IPMatcher. It reports whether at least one sub-matcher
// returns true from its own AnyMatch for ips.
func (mm *GeneralMultiIPMatcher) AnyMatch(ips []net.IP) bool {
	return slices.ContainsFunc(mm.matchers, func(sub IPMatcher) bool {
		return sub.AnyMatch(ips)
	})
}
// Matches implements IPMatcher. It reports whether at least one single
// sub-matcher accepts the whole list, i.e. its own Matches returns true.
// A list whose IPs are each accepted by a different sub-matcher does not
// satisfy this.
func (mm *GeneralMultiIPMatcher) Matches(ips []net.IP) bool {
	return slices.ContainsFunc(mm.matchers, func(sub IPMatcher) bool {
		return sub.Matches(ips)
	})
}
// FilterIPs implements IPMatcher. Each sub-matcher in turn filters whatever
// the previous ones left unmatched; its matches are accumulated and its
// leftovers are handed to the next sub-matcher. Processing stops early once
// nothing is left unmatched.
func (mm *GeneralMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
	matched = make([]net.IP, 0, len(ips))
	unmatched = ips
	for i := 0; i < len(mm.matchers) && len(unmatched) > 0; i++ {
		hits, rest := mm.matchers[i].FilterIPs(unmatched)
		matched = append(matched, hits...)
		unmatched = rest
	}
	return matched, unmatched
}
// ToggleReverse implements IPMatcher. It calls ToggleReverse on every
// sub-matcher.
func (mm *GeneralMultiIPMatcher) ToggleReverse() {
for _, m := range mm.matchers {
m.ToggleReverse()
}
}
// SetReverse implements IPMatcher. It calls SetReverse with the given value
// on every sub-matcher.
func (mm *GeneralMultiIPMatcher) SetReverse(reverse bool) {
for _, m := range mm.matchers {
m.SetReverse(reverse)
}
}
// HeuristicMultiIPMatcher combines several HeuristicIPMatcher values and
// consults them in order.
type HeuristicMultiIPMatcher struct {
// matchers are the sub-matchers, consulted in slice order.
matchers []*HeuristicIPMatcher
}
// Match implements IPMatcher. The IP is converted once and then tested
// against each sub-matcher; it matches when any sub-matcher accepts it. An IP
// that cannot be converted never matches.
func (mm *HeuristicMultiIPMatcher) Match(ip net.IP) bool {
	addr, valid := netipx.FromStdIP(ip)
	if !valid {
		return false
	}
	return slices.ContainsFunc(mm.matchers, func(sub *HeuristicIPMatcher) bool {
		return sub.matchAddr(addr)
	})
}
// AnyMatch implements IPMatcher. It returns true as soon as any sub-matcher
// accepts any valid IP; invalid IPs are skipped.
//
// A single set of prefix buckets (see prefixKeyFromIP) is shared by all
// sub-matchers. The first IP seen in a bucket is tested against every
// sub-matcher. A later IP from the same bucket is skipped for sub-matchers
// that are bucket-eligible for its family (max4 <= 24 or max6 <= 64) and is
// still tested against the others.
func (mm *HeuristicMultiIPMatcher) AnyMatch(ips []net.IP) bool {
n := len(ips)
if n == 0 {
return false
}
if n == 1 {
return mm.Match(ips[0])
}
buckets := make(map[[9]byte]struct{}, n)
for _, ip := range ips {
// ipx is converted lazily, at most once per IP.
var ipx netip.Addr
state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
for _, m := range mm.matchers {
heur4 := m.ipset.max4 <= 24
heur6 := m.ipset.max6 <= 64
// The bucket key is computed only once per IP, at the first
// sub-matcher that is bucket-eligible for at least one family.
if state == 0 && (heur4 || heur6) {
key, ok := prefixKeyFromIP(ip)
if !ok {
// Invalid IP: give up on it for all sub-matchers.
break
}
if _, exists := buckets[key]; exists {
// Bucket already seen: remember its family so that
// eligible sub-matchers can skip this IP.
state = key[0]
} else {
buckets[key] = struct{}{}
state = 1
}
}
if (heur4 && state == 4) || (heur6 && state == 6) {
continue
}
if !ipx.IsValid() {
nipx, ok := netipx.FromStdIP(ip)
if !ok {
break
}
ipx = nipx
}
if m.matchAddr(ipx) {
return true
}
}
}
return false
}
// Matches implements IPMatcher. It returns true when at least one single
// sub-matcher accepts every IP in the list; it returns false for an empty
// list or when any IP is invalid.
//
// Converted views of the input are built lazily by ipViews.ensureForMatcher
// and reused across sub-matchers. A sub-matcher checks the bucketed view of a
// family when its max prefix length allows it (max4 <= 24, max6 <= 64) and
// the precise view otherwise.
func (mm *HeuristicMultiIPMatcher) Matches(ips []net.IP) bool {
	switch len(ips) {
	case 0:
		return false
	case 1:
		return mm.Match(ips[0])
	}

	var views ipViews
	for _, m := range mm.matchers {
		if !views.ensureForMatcher(m, ips) {
			return false
		}
		// accepts reports whether m matches every address of one family,
		// using the bucket representatives when coarse is true.
		accepts := func(coarse bool, reps map[[9]byte]netip.Addr, exact []netip.Addr) bool {
			if coarse {
				for _, addr := range reps {
					if !m.matchAddr(addr) {
						return false
					}
				}
				return true
			}
			for _, addr := range exact {
				if !m.matchAddr(addr) {
					return false
				}
			}
			return true
		}
		if accepts(m.ipset.max4 <= 24, views.buckets4, views.precise4) &&
			accepts(m.ipset.max6 <= 64, views.buckets6, views.precise6) {
			return true
		}
	}
	return false
}
// ipViews holds lazily built, converted views of one input IP list, as used
// by HeuristicMultiIPMatcher.Matches. A nil field means that view has not
// been built yet.
type ipViews struct {
// buckets4 and buckets6 map a prefix key to the first address seen with it.
buckets4, buckets6 map[[9]byte]netip.Addr
// precise4 and precise6 list every address of the family.
precise4, precise6 []netip.Addr
}
// ensureForMatcher builds whichever views matcher m needs and that do not
// exist yet: the bucketed view of a family when m's max prefix length allows
// bucketing (max4 <= 24, max6 <= 64), the precise view otherwise. Views that
// already exist are left untouched.
//
// It returns false as soon as an IP in ips is invalid; the views may then be
// partially filled.
func (v *ipViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) bool {
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
return true
}
// Allocate the missing views non-nil so they count as built afterwards.
if needHeur4 {
v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
}
if needHeur6 {
v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
}
if needPrec4 {
v.precise4 = make([]netip.Addr, 0, len(ips))
}
if needPrec6 {
v.precise6 = make([]netip.Addr, 0, len(ips))
}
for _, ip := range ips {
key, ok := prefixKeyFromIP(ip)
if !ok {
return false
}
switch key[0] {
case 4:
var ipx netip.Addr
if needHeur4 {
// Only the first address of each bucket is stored.
if _, exists := v.buckets4[key]; !exists {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
return false
}
v.buckets4[key] = ipx
}
}
if needPrec4 {
// Convert only if the bucket branch has not done so already.
if !ipx.IsValid() {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
return false
}
}
v.precise4 = append(v.precise4, ipx)
}
case 6:
var ipx netip.Addr
if needHeur6 {
if _, exists := v.buckets6[key]; !exists {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
return false
}
v.buckets6[key] = ipx
}
}
if needPrec6 {
if !ipx.IsValid() {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
return false
}
}
v.precise6 = append(v.precise6, ipx)
}
default:
return false
}
}
return true
}
// FilterIPs implements IPMatcher. It splits ips into those accepted by at
// least one sub-matcher (matched) and those accepted by none (unmatched).
// Invalid IPs appear in neither result. Both results are always non-nil,
// every valid input IP appears in exactly one of them, and input order is
// preserved.
//
// A sub-matcher whose prefixes for a family are all at most /24 (IPv4) or /64
// (IPv6) gives the same verdict to every address of one /24 or /64, so the
// combined verdict of those sub-matchers is computed once per bucket and
// cached. Sub-matchers that do not qualify are consulted per address, and
// only when the cached verdict was a miss.
func (mm *HeuristicMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
	n := len(ips)
	if n == 0 {
		return []net.IP{}, []net.IP{}
	}
	if n == 1 {
		ipx, ok := netipx.FromStdIP(ips[0])
		if !ok {
			return []net.IP{}, []net.IP{}
		}
		for _, m := range mm.matchers {
			if m.matchAddr(ipx) {
				return ips, []net.IP{}
			}
		}
		return []net.IP{}, ips
	}

	// coarse reports whether m may be evaluated once per bucket for the
	// given family tag (4 or 6, as produced by prefixKeyFromIP).
	coarse := func(m *HeuristicIPMatcher, family byte) bool {
		if family == 4 {
			return m.ipset.max4 <= 24
		}
		return m.ipset.max6 <= 64
	}
	// anyMatch tests ipx against the sub-matchers whose bucket eligibility
	// for family equals wantCoarse.
	anyMatch := func(ipx netip.Addr, family byte, wantCoarse bool) bool {
		for _, m := range mm.matchers {
			if coarse(m, family) == wantCoarse && m.matchAddr(ipx) {
				return true
			}
		}
		return false
	}

	matched = make([]net.IP, 0, n)
	unmatched = make([]net.IP, 0, n)
	// verdicts caches, per bucket, whether any bucket-eligible sub-matcher
	// accepted the bucket's first address.
	verdicts := make(map[[9]byte]bool, n)
	for _, ip := range ips {
		key, ok := prefixKeyFromIP(ip)
		if !ok {
			continue // illegal ip, ignore
		}
		ipx, ok := netipx.FromStdIP(ip)
		if !ok {
			continue // illegal ip, ignore
		}
		hit, cached := verdicts[key]
		if !cached {
			hit = anyMatch(ipx, key[0], true)
			verdicts[key] = hit
		}
		if !hit {
			hit = anyMatch(ipx, key[0], false)
		}
		if hit {
			matched = append(matched, ip)
		} else {
			unmatched = append(unmatched, ip)
		}
	}
	return matched, unmatched
}
// ipBucketViews holds lazily built views of one input IP list that keep the
// original net.IP values. A nil field means that view has not been built yet.
type ipBucketViews struct {
// buckets4 and buckets6 map a prefix key to the bucket of IPs sharing it.
buckets4, buckets6 map[[9]byte]*ipBucket
// precise4 and precise6 map each converted address to its original IP, so
// repeated addresses collapse into one entry.
precise4, precise6 map[netip.Addr]net.IP
}
// ensureForMatcher builds whichever views matcher m needs and that do not
// exist yet: the bucketed view of a family when m's max prefix length allows
// bucketing (max4 <= 24, max6 <= 64), the precise view otherwise. Views that
// already exist are left untouched. Invalid IPs are skipped.
//
// NOTE(review): every view is filled from the full ips slice. When one
// matcher needs the bucketed view of a family and another needs the precise
// view, the same IP is present in both views.
func (v *ipBucketViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) {
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
return
}
// Allocate the missing views non-nil so they count as built afterwards.
if needHeur4 {
v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
}
if needHeur6 {
v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
}
if needPrec4 {
v.precise4 = make(map[netip.Addr]net.IP, len(ips))
}
if needPrec6 {
v.precise6 = make(map[netip.Addr]net.IP, len(ips))
}
for _, ip := range ips {
key, ok := prefixKeyFromIP(ip)
if !ok {
continue // illegal ip, ignore
}
switch key[0] {
case 4:
var ipx netip.Addr
if needHeur4 {
b, exists := v.buckets4[key]
if !exists {
// build bucket; its first IP becomes the representative
ipx, ok = netipx.FromStdIP(ip)
if !ok {
continue // illegal ip, ignore
}
b = &ipBucket{
rep: ipx,
ips: make([]net.IP, 0, 4), // for dns answer
}
v.buckets4[key] = b
}
b.ips = append(b.ips, ip)
}
if needPrec4 {
// Convert only if the bucket branch has not done so already.
if !ipx.IsValid() {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
continue // illegal ip, ignore
}
}
v.precise4[ipx] = ip
}
case 6:
var ipx netip.Addr
if needHeur6 {
b, exists := v.buckets6[key]
if !exists {
// build bucket; its first IP becomes the representative
ipx, ok = netipx.FromStdIP(ip)
if !ok {
continue // illegal ip, ignore
}
b = &ipBucket{
rep: ipx,
ips: make([]net.IP, 0, 4), // for dns answer
}
v.buckets6[key] = b
}
b.ips = append(b.ips, ip)
}
if needPrec6 {
if !ipx.IsValid() {
ipx, ok = netipx.FromStdIP(ip)
if !ok {
continue // illegal ip, ignore
}
}
v.precise6[ipx] = ip
}
}
}
}
// ToggleReverse implements IPMatcher. It calls ToggleReverse on every
// sub-matcher.
func (mm *HeuristicMultiIPMatcher) ToggleReverse() {
for _, m := range mm.matchers {
m.ToggleReverse()
}
}
// SetReverse implements IPMatcher. It calls SetReverse with the given value
// on every sub-matcher.
func (mm *HeuristicMultiIPMatcher) SetReverse(reverse bool) {
for _, m := range mm.matchers {
m.SetReverse(reverse)
}
}
// IPSetFactory builds IPSet values and shares the ones built from geoip
// rules.
type IPSetFactory struct {
// The embedded mutex is held by GetOrCreateFromGeoIPRules while it reads
// and writes shared.
sync.Mutex
// shared maps a key from buildGeoIPRulesKey to the IPSet built for it.
shared map[string]*IPSet // TODO: cleanup
}
// GetOrCreateFromGeoIPRules returns the IPSet for the given geoip rules. Sets
// are shared: rule lists that produce the same key from buildGeoIPRulesKey
// get the same *IPSet. On a miss, the CIDRs of every rule are loaded with
// loadIP and merged into a new set, which is stored only when building
// succeeded.
//
// The factory's mutex is held for the whole call, including the build. The
// shared map is created on first use, so a zero-value IPSetFactory works.
func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, error) {
	key := buildGeoIPRulesKey(rules)

	f.Lock()
	defer f.Unlock()

	// Reading a nil map is safe; only the write below needs it allocated.
	if ipset := f.shared[key]; ipset != nil {
		return ipset, nil
	}

	ipset, err := f.createFrom(func(add func(*CIDR)) error {
		for _, r := range rules {
			cidrs, err := loadIP(r.File, r.Code)
			if err != nil {
				return err
			}
			for i, c := range cidrs {
				add(c)
				cidrs[i] = nil // peak mem
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	if f.shared == nil {
		f.shared = make(map[string]*IPSet)
	}
	f.shared[key] = ipset
	return ipset, nil
}
// buildGeoIPRulesKey derives a cache key from a list of geoip rules. The
// rules are sorted by (File, Code) on a copy, duplicates are dropped, and
// each remaining rule contributes "File:Code,". The key is therefore
// independent of rule order and repetition. ReverseMatch is not part of the
// key.
func buildGeoIPRulesKey(rules []*GeoIPRule) string {
	sorted := slices.Clone(rules)
	sort.Slice(sorted, func(a, b int) bool {
		if sorted[a].File == sorted[b].File {
			return sorted[a].Code < sorted[b].Code
		}
		return sorted[a].File < sorted[b].File
	})

	var key strings.Builder
	key.Grow(len(sorted) * 20) // geoip.dat:xx,

	// After sorting, equal rules are adjacent, so compaction removes them.
	unique := slices.CompactFunc(sorted, func(a, b *GeoIPRule) bool {
		return a.File == b.File && a.Code == b.Code
	})
	for _, r := range unique {
		key.WriteString(r.File)
		key.WriteString(":")
		key.WriteString(r.Code)
		key.WriteString(",")
	}
	return key.String()
}
// CreateFromCIDRs builds a new, unshared IPSet from the given CIDRs.
func (f *IPSetFactory) CreateFromCIDRs(cidrs []*CIDR) (*IPSet, error) {
	feed := func(add func(*CIDR)) error {
		for i := range cidrs {
			add(cidrs[i])
		}
		return nil
	}
	return f.createFrom(feed)
}
// createFrom builds an IPSet from the CIDRs that yield passes to its add
// callback. CIDRs with an invalid address or prefix length are logged and
// skipped. An error returned by yield aborts the build.
//
// For each family the resulting IPSet records the largest prefix length among
// the set's prefixes. A family with no prefixes at all gets the 0xff
// sentinel. A family consisting only of a /0 prefix keeps max 0, so it is not
// mistaken for an empty family.
func (f *IPSetFactory) createFrom(yield func(func(*CIDR)) error) (*IPSet, error) {
	var ipv4Builder, ipv6Builder netipx.IPSetBuilder

	err := yield(func(c *CIDR) {
		ipBytes := c.GetIp()
		prefixLen := int(c.GetPrefix())

		addr, ok := netip.AddrFromSlice(ipBytes)
		if !ok {
			errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
			return
		}

		prefix := netip.PrefixFrom(addr, prefixLen)
		if !prefix.IsValid() {
			errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
			return
		}

		if addr.Is4() {
			ipv4Builder.AddPrefix(prefix)
		} else if addr.Is6() {
			ipv6Builder.AddPrefix(prefix)
		}
	})
	if err != nil {
		return nil, err
	}

	// peak mem
	runtime.GC()
	defer runtime.GC()

	ipv4, err := ipv4Builder.IPSet()
	if err != nil {
		return nil, errors.New("failed to build IPv4 set").Base(err)
	}
	ipv6, err := ipv6Builder.IPSet()
	if err != nil {
		return nil, errors.New("failed to build IPv6 set").Base(err)
	}

	// maxBits returns the longest prefix length in set, or the 0xff sentinel
	// when the set has no prefixes.
	maxBits := func(set *netipx.IPSet) uint8 {
		prefixes := set.Prefixes()
		if len(prefixes) == 0 {
			return 0xff
		}
		longest := 0
		for _, p := range prefixes {
			if b := p.Bits(); b > longest {
				longest = b
			}
		}
		return uint8(longest)
	}

	return &IPSet{ipv4: ipv4, ipv6: ipv6, max4: maxBits(ipv4), max6: maxBits(ipv6)}, nil
}
// buildOptimizedIPMatcher turns a list of IP rules into a matcher. Custom
// CIDRs are merged into one set; geoip rules are split into non-reversed and
// reversed groups, each backed by a shared set from the factory, the reversed
// group getting a reversed matcher. The sub-matchers are ordered custom,
// non-reversed geoip, reversed geoip. A single sub-matcher is returned
// directly, several are wrapped in a HeuristicMultiIPMatcher, and no
// sub-matcher at all is an error. It panics on an unknown rule type.
func buildOptimizedIPMatcher(f *IPSetFactory, rules []*IPRule) (IPMatcher, error) {
	count := len(rules)
	custom := make([]*CIDR, 0, count)
	pos := make([]*GeoIPRule, 0, count)
	neg := make([]*GeoIPRule, 0, count)

	for _, r := range rules {
		switch v := r.Value.(type) {
		case *IPRule_Custom:
			custom = append(custom, v.Custom)
		case *IPRule_Geoip:
			if v.Geoip.ReverseMatch {
				neg = append(neg, v.Geoip)
			} else {
				pos = append(pos, v.Geoip)
			}
		default:
			panic("unknown ip rule type")
		}
	}

	subs := make([]*HeuristicIPMatcher, 0, 3)

	if len(custom) > 0 {
		ipset, err := f.CreateFromCIDRs(custom)
		if err != nil {
			return nil, err
		}
		subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: false})
	}

	geoGroups := [...]struct {
		rules   []*GeoIPRule
		reverse bool
	}{
		{rules: pos, reverse: false},
		{rules: neg, reverse: true},
	}
	for _, g := range geoGroups {
		if len(g.rules) == 0 {
			continue
		}
		ipset, err := f.GetOrCreateFromGeoIPRules(g.rules)
		if err != nil {
			return nil, err
		}
		subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: g.reverse})
	}

	if len(subs) == 0 {
		return nil, errors.New("no valid ip matcher")
	}
	if len(subs) == 1 {
		return subs[0], nil
	}
	return &HeuristicMultiIPMatcher{matchers: subs}, nil
}

View File

@@ -0,0 +1,325 @@
package geodata
import (
"net"
"path/filepath"
"reflect"
"slices"
"testing"
"github.com/xtls/xray-core/common"
xnet "github.com/xtls/xray-core/common/net"
)
// buildIPMatcher parses the raw rules and builds a matcher from a fresh IP
// registry. Any parse or build error aborts via common.Must.
func buildIPMatcher(rawRules ...string) IPMatcher {
	parsed, parseErr := ParseIPRules(rawRules)
	common.Must(parseErr)

	built, buildErr := newIPRegistry().BuildIPMatcher(parsed)
	common.Must(buildErr)

	return built
}
// sortIPStrings returns the string forms of ips in ascending lexical order.
// The result is never nil.
func sortIPStrings(ips []net.IP) []string {
	texts := make([]string, len(ips))
	for i := range ips {
		texts[i] = ips[i].String()
	}
	slices.Sort(texts)
	return texts
}
// TestIPMatcher checks single-address matching against a list of custom IPv4
// CIDRs, including an IPv6 input that must not match.
func TestIPMatcher(t *testing.T) {
	rules := []string{
		"0.0.0.0/8",
		"10.0.0.0/8",
		"100.64.0.0/10",
		"127.0.0.0/8",
		"169.254.0.0/16",
		"172.16.0.0/12",
		"192.0.0.0/24",
		"192.0.2.0/24",
		"192.168.0.0/16",
		"192.18.0.0/15",
		"198.51.100.0/24",
		"203.0.113.0/24",
		"8.8.8.8/32",
		"91.108.4.0/16",
	}
	matcher := buildIPMatcher(rules...)

	testCases := []struct {
		Input  string
		Output bool
	}{
		{Input: "192.168.1.1", Output: true},
		{Input: "192.0.0.0", Output: true},
		{Input: "192.0.1.0", Output: false},
		{Input: "0.1.0.0", Output: true},
		{Input: "1.0.0.1", Output: false},
		{Input: "8.8.8.7", Output: false},
		{Input: "8.8.8.8", Output: true},
		{Input: "2001:cdba::3257:9652", Output: false},
		{Input: "91.108.255.254", Output: true},
	}

	for _, test := range testCases {
		got := matcher.Match(xnet.ParseAddress(test.Input).IP())
		if got != test.Output {
			t.Error("unexpected output: ", got, " for test case ", test)
		}
	}
}
// TestIPMatcherRegression checks matching when a /22 and an overlapping /23
// with the same base address are both configured.
func TestIPMatcherRegression(t *testing.T) {
	matcher := buildIPMatcher("98.108.20.0/22", "98.108.20.0/23")

	testCases := []struct {
		Input  string
		Output bool
	}{
		{Input: "98.108.22.11", Output: true},
		{Input: "98.108.25.0", Output: false},
	}

	for _, test := range testCases {
		got := matcher.Match(xnet.ParseAddress(test.Input).IP())
		if got != test.Output {
			t.Error("unexpected output: ", got, " for test case ", test)
		}
	}
}
// TestIPReverseMatcher checks a reversed matcher built from IPv4 rules only:
// addresses inside the rules must not match, and an IPv6 address must not
// match either because no IPv6 rule is configured.
func TestIPReverseMatcher(t *testing.T) {
	matcher := buildIPMatcher("8.8.8.8/32", "91.108.4.0/16")
	matcher.SetReverse(true)

	testCases := []struct {
		Input  string
		Output bool
	}{
		{Input: "8.8.8.8", Output: false},
		{Input: "2001:cdba::3257:9652", Output: false},
		{Input: "91.108.255.254", Output: false},
	}

	for _, test := range testCases {
		got := matcher.Match(xnet.ParseAddress(test.Input).IP())
		if got != test.Output {
			t.Error("unexpected output: ", got, " for test case ", test)
		}
	}
}
// TestIPReverseMatcher2 checks a reversed matcher that also has an IPv6 rule,
// so an IPv6 address outside the rules is expected to match.
func TestIPReverseMatcher2(t *testing.T) {
	rules := []string{
		"8.8.8.8/32",
		"91.108.4.0/16",
		"fe80::", // Keep IPv6 family non-empty so reverse matching can evaluate IPv6 input.
	}
	matcher := buildIPMatcher(rules...)
	matcher.SetReverse(true)

	testCases := []struct {
		Input  string
		Output bool
	}{
		{Input: "8.8.8.8", Output: false},
		{Input: "2001:cdba::3257:9652", Output: true},
		{Input: "91.108.255.254", Output: false},
	}

	for _, test := range testCases {
		got := matcher.Match(xnet.ParseAddress(test.Input).IP())
		if got != test.Output {
			t.Error("unexpected output: ", got, " for test case ", test)
		}
	}
}
// TestIPMatcherAnyMatchAndMatches checks the list-based predicates: AnyMatch
// skips invalid IPs, Matches requires every IP to be valid and to match.
func TestIPMatcherAnyMatchAndMatches(t *testing.T) {
	matcher := buildIPMatcher("8.8.8.8/32", "2001:4860:4860::8888/128")
	parse := func(raw string) net.IP {
		return xnet.ParseAddress(raw).IP()
	}
	invalid := net.IP{}

	if matcher.AnyMatch(nil) {
		t.Fatal("expect AnyMatch(nil) to be false")
	}

	if ips := []net.IP{invalid, parse("1.1.1.1"), parse("8.8.8.8")}; !matcher.AnyMatch(ips) {
		t.Fatal("expect AnyMatch to ignore invalid IPs and return true when one valid IP matches")
	}

	if ips := []net.IP{parse("1.1.1.1"), parse("2001:db8::1")}; matcher.AnyMatch(ips) {
		t.Fatal("expect AnyMatch to be false when no valid IP matches")
	}

	if ips := []net.IP{parse("8.8.8.8"), parse("2001:4860:4860::8888")}; !matcher.Matches(ips) {
		t.Fatal("expect Matches to be true when all valid IPs match")
	}

	if ips := []net.IP{parse("8.8.8.8"), parse("1.1.1.1")}; matcher.Matches(ips) {
		t.Fatal("expect Matches to be false when one valid IP does not match")
	}

	if ips := []net.IP{parse("8.8.8.8"), invalid}; matcher.Matches(ips) {
		t.Fatal("expect Matches to be false when any IP is invalid")
	}
}
// TestIPMatcherFilterIPs checks that FilterIPs partitions a mixed list into
// matched and unmatched IPs and drops the invalid one. Results are compared
// after sorting, so output order is not asserted.
func TestIPMatcherFilterIPs(t *testing.T) {
	matcher := buildIPMatcher("8.8.8.8/32", "91.108.4.0/16", "2001:4860:4860::8888/128")
	parse := func(raw string) net.IP {
		return xnet.ParseAddress(raw).IP()
	}

	input := []net.IP{
		net.IP{},
		parse("8.8.8.8"),
		parse("91.108.255.254"),
		parse("1.1.1.1"),
		parse("2001:4860:4860::8888"),
		parse("2001:db8::1"),
	}
	matched, unmatched := matcher.FilterIPs(input)

	wantMatched := []string{"2001:4860:4860::8888", "8.8.8.8", "91.108.255.254"}
	slices.Sort(wantMatched)
	if got := sortIPStrings(matched); !reflect.DeepEqual(got, wantMatched) {
		t.Error("unexpected output: ", got, " want ", wantMatched)
	}

	wantUnmatched := []string{"1.1.1.1", "2001:db8::1"}
	slices.Sort(wantUnmatched)
	if got := sortIPStrings(unmatched); !reflect.DeepEqual(got, wantUnmatched) {
		t.Error("unexpected output: ", got, " want ", wantUnmatched)
	}
}
// TestIPMatcher4CN checks that the "cn" geoip list from the bundled resources
// does not contain 8.8.8.8.
func TestIPMatcher4CN(t *testing.T) {
	assetDir := filepath.Join("..", "..", "resources")
	t.Setenv("xray.location.asset", assetDir)

	cn := buildIPMatcher("geoip:cn")
	if hit := cn.Match([]byte{8, 8, 8, 8}); hit {
		t.Error("expect CN geoip doesn't contain 8.8.8.8, but actually does")
	}
}
// TestIPMatcher6US checks that the "us" geoip list from the bundled resources
// contains 2001:4860:4860::8888.
func TestIPMatcher6US(t *testing.T) {
	assetDir := filepath.Join("..", "..", "resources")
	t.Setenv("xray.location.asset", assetDir)

	us := buildIPMatcher("geoip:us")
	target := xnet.ParseAddress("2001:4860:4860::8888").IP()
	if hit := us.Match(target); !hit {
		t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not")
	}
}
// BenchmarkIPMatcher4CN measures a single IPv4 lookup against the "cn" geoip
// list. Building the matcher is excluded from the timing.
func BenchmarkIPMatcher4CN(b *testing.B) {
	b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources"))
	cn := buildIPMatcher("geoip:cn")
	probe := net.IP{8, 8, 8, 8}

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = cn.Match(probe)
	}
}
// BenchmarkIPMatcher6US measures a single IPv6 lookup against the "us" geoip
// list. Building the matcher is excluded from the timing.
func BenchmarkIPMatcher6US(b *testing.B) {
	b.Setenv("xray.location.asset", filepath.Join("..", "..", "resources"))
	us := buildIPMatcher("geoip:us")
	probe := xnet.ParseAddress("2001:4860:4860::8888").IP()

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = us.Match(probe)
	}
}

View File

@@ -0,0 +1,17 @@
package geodata
// IPRegistry builds IP matchers and owns the IPSetFactory they are built
// with.
type IPRegistry struct {
// ipsetFactory is passed to buildOptimizedIPMatcher for every build.
ipsetFactory *IPSetFactory
}
// BuildIPMatcher builds a matcher for the given rules using the registry's
// factory.
func (r *IPRegistry) BuildIPMatcher(rules []*IPRule) (IPMatcher, error) {
return buildOptimizedIPMatcher(r.ipsetFactory, rules)
}
// newIPRegistry returns a registry with a fresh factory whose shared-set map
// is already allocated.
func newIPRegistry() *IPRegistry {
return &IPRegistry{
ipsetFactory: &IPSetFactory{shared: make(map[string]*IPSet)},
}
}
// IPReg is the package-level IP registry.
var IPReg = newIPRegistry()

View File

@@ -0,0 +1,254 @@
package geodata
import (
"strconv"
"strings"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
)
const (
// DefaultGeoIPDat is the file name substituted for the "geoip:" rule prefix.
DefaultGeoIPDat = "geoip.dat"
// DefaultGeoSiteDat is the file name substituted for the "geosite:" rule prefix.
DefaultGeoSiteDat = "geosite.dat"
)
// ParseIPRules parses textual IP rules. A rule is one of:
//
//   - "geoip:CODE", shorthand for "ext:geoip.dat:CODE";
//   - "ext:FILE:CODE" or "ext-ip:FILE:CODE", a geoip rule from FILE;
//   - anything else, parsed as a custom IP or CIDR.
//
// The first rule that fails to parse aborts the call; the error names the
// rule exactly as it was passed in. The result is nil when rules is empty.
func ParseIPRules(rules []string) ([]*IPRule, error) {
	var parsed []*IPRule
	for _, raw := range rules {
		expanded := raw
		if code, isGeoIP := strings.CutPrefix(raw, "geoip:"); isGeoIP {
			expanded = "ext:" + DefaultGeoIPDat + ":" + code
		}

		var (
			value isIPRule_Value
			err   error
		)
		if rest, ok := strings.CutPrefix(expanded, "ext:"); ok {
			value, err = parseGeoIPRule(rest)
		} else if rest, ok := strings.CutPrefix(expanded, "ext-ip:"); ok {
			value, err = parseGeoIPRule(rest)
		} else {
			value, err = parseCustomIPRule(expanded)
		}
		if err != nil {
			return nil, errors.New("illegal ip rule: ", raw).Base(err)
		}
		parsed = append(parsed, &IPRule{Value: value})
	}
	return parsed, nil
}
// parseGeoIPRule parses "FILE:CODE" or "FILE:!CODE". A leading "!" on the
// code marks the rule as reverse-matching. The code is upper-cased and the
// pair is validated with checkFile before the rule is returned.
func parseGeoIPRule(rule string) (*IPRule_Geoip, error) {
	file, code, found := strings.Cut(rule, ":")
	switch {
	case !found:
		return nil, errors.New("syntax error")
	case file == "":
		return nil, errors.New("empty file")
	}

	code, reverse := strings.CutPrefix(code, "!")
	if code == "" {
		return nil, errors.New("empty code")
	}
	code = strings.ToUpper(code)

	if err := checkFile(file, code); err != nil {
		return nil, err
	}

	geoip := &GeoIPRule{
		File:         file,
		Code:         code,
		ReverseMatch: reverse,
	}
	return &IPRule_Geoip{Geoip: geoip}, nil
}
// parseCustomIPRule parses a plain IP or CIDR string into a custom IP rule.
func parseCustomIPRule(rule string) (*IPRule_Custom, error) {
	parsed, err := parseCIDR(rule)
	if err != nil {
		return nil, err
	}
	return &IPRule_Custom{Custom: parsed}, nil
}
// parseCIDR parses "IP" or "IP/PREFIX". Without a prefix, or with an empty
// one, the full address length is used (32 for IPv4, 128 for IPv6). It fails
// when the address part is not an IPv4 or IPv6 address, when the prefix is
// not an unsigned decimal number, or when it exceeds the address length.
func parseCIDR(s string) (*CIDR, error) {
	host, bitsText, _ := strings.Cut(s, "/")
	addr := net.ParseAddress(host)

	var limit uint32
	switch addr.Family() {
	case net.AddressFamilyIPv4:
		limit = 32
	case net.AddressFamilyIPv6:
		limit = 128
	default:
		return nil, errors.New("unsupported address family")
	}

	bits := limit
	if bitsText != "" {
		n, err := strconv.ParseUint(bitsText, 10, 32)
		if err != nil {
			return nil, errors.New("invalid CIDR prefix length: ", bitsText).Base(err)
		}
		bits = uint32(n)
		if bits > limit {
			return nil, errors.New("CIDR prefix length ", bits, " exceeds max ", limit)
		}
	}

	return &CIDR{
		Ip:     []byte(addr.IP()),
		Prefix: bits,
	}, nil
}
// ParseDomainRule parses one textual domain rule. A rule is one of:
//
//   - "geosite:CODE[@ATTR...]", shorthand for "ext:geosite.dat:CODE[@ATTR...]";
//   - "ext:FILE:CODE[@ATTR...]" or "ext-domain:FILE:CODE[@ATTR...]";
//   - anything else, parsed as a custom domain rule, where defaultType applies
//     when the rule carries no type prefix of its own.
//
// On failure the error names the rule exactly as it was passed in, not its
// expanded "ext:" form.
func ParseDomainRule(r string, defaultType Domain_Type) (*DomainRule, error) {
	expanded := r
	if strings.HasPrefix(expanded, "geosite:") {
		expanded = "ext:" + DefaultGeoSiteDat + ":" + expanded[len("geosite:"):]
	}
	prefix := 0
	for _, ext := range [...]string{"ext:", "ext-domain:"} {
		if strings.HasPrefix(expanded, ext) {
			prefix = len(ext)
			break
		}
	}

	var rule isDomainRule_Value
	var err error
	if prefix > 0 {
		rule, err = parseGeoSiteRule(expanded[prefix:])
	} else {
		rule, err = parseCustomDomainRule(expanded, defaultType)
	}
	if err != nil {
		// Report the caller's original text so the message matches the
		// configuration.
		return nil, errors.New("illegal domain rule: ", r).Base(err)
	}
	return &DomainRule{Value: rule}, nil
}
// ParseDomainRules parses a list of textual domain rules. A rule is one of:
//
//   - "geosite:CODE[@ATTR...]", shorthand for "ext:geosite.dat:CODE[@ATTR...]";
//   - "ext:FILE:CODE[@ATTR...]" or "ext-domain:FILE:CODE[@ATTR...]";
//   - anything else, parsed as a custom domain rule with defaultType as the
//     fallback type.
//
// The first rule that fails to parse aborts the call; the error names the
// rule exactly as it was passed in. The result is nil when rules is empty.
func ParseDomainRules(rules []string, defaultType Domain_Type) ([]*DomainRule, error) {
	var parsed []*DomainRule
	for _, raw := range rules {
		expanded := raw
		if code, isGeoSite := strings.CutPrefix(raw, "geosite:"); isGeoSite {
			expanded = "ext:" + DefaultGeoSiteDat + ":" + code
		}

		var (
			value isDomainRule_Value
			err   error
		)
		if rest, ok := strings.CutPrefix(expanded, "ext:"); ok {
			value, err = parseGeoSiteRule(rest)
		} else if rest, ok := strings.CutPrefix(expanded, "ext-domain:"); ok {
			value, err = parseGeoSiteRule(rest)
		} else {
			value, err = parseCustomDomainRule(expanded, defaultType)
		}
		if err != nil {
			return nil, errors.New("illegal domain rule: ", raw).Base(err)
		}
		parsed = append(parsed, &DomainRule{Value: value})
	}
	return parsed, nil
}
// parseGeoSiteRule parses "FILE:CODE" optionally followed by "@ATTR"
// elements. Empty attributes (a trailing "@" or "@@") are rejected. The code
// is upper-cased, the attribute part is lower-cased, and the file/code pair
// is validated with checkFile.
func parseGeoSiteRule(rule string) (*DomainRule_Geosite, error) {
	file, selector, found := strings.Cut(rule, ":")
	switch {
	case !found:
		return nil, errors.New("syntax error")
	case file == "":
		return nil, errors.New("empty file")
	case strings.HasSuffix(selector, "@"), strings.Contains(selector, "@@"):
		return nil, errors.New("empty attr")
	}

	code, attrs, _ := strings.Cut(selector, "@")
	if code == "" {
		return nil, errors.New("empty code")
	}
	code = strings.ToUpper(code)

	if err := checkFile(file, code); err != nil {
		return nil, err
	}

	geosite := &GeoSiteRule{
		File:  file,
		Code:  code,
		Attrs: strings.ToLower(attrs),
	}
	return &DomainRule_Geosite{Geosite: geosite}, nil
}
// parseCustomDomainRule parses a custom domain rule. Recognised prefixes are
// "regexp:", "domain:", "full:", "keyword:" and "dotless:"; a rule without
// one of them takes defaultType and is used verbatim as the value.
//
// "dotless:SUB" becomes a regular expression that matches names without a
// dot containing SUB (any dotless name when SUB is empty). SUB itself must
// not contain a dot.
func parseCustomDomainRule(rule string, defaultType Domain_Type) (*DomainRule_Custom, error) {
	var domain *Domain
	if v, ok := strings.CutPrefix(rule, "regexp:"); ok {
		domain = &Domain{Type: Domain_Regex, Value: v}
	} else if v, ok := strings.CutPrefix(rule, "domain:"); ok {
		domain = &Domain{Type: Domain_Domain, Value: v}
	} else if v, ok := strings.CutPrefix(rule, "full:"); ok {
		domain = &Domain{Type: Domain_Full, Value: v}
	} else if v, ok := strings.CutPrefix(rule, "keyword:"); ok {
		domain = &Domain{Type: Domain_Substr, Value: v}
	} else if sub, ok := strings.CutPrefix(rule, "dotless:"); ok {
		if strings.Contains(sub, ".") {
			return nil, errors.New("substr in dotless rule should not contain a dot")
		}
		pattern := "^[^.]*$"
		if sub != "" {
			pattern = "^[^.]*" + sub + "[^.]*$"
		}
		domain = &Domain{Type: Domain_Regex, Value: pattern}
	} else {
		domain = &Domain{Type: defaultType, Value: rule}
	}
	return &DomainRule_Custom{Custom: domain}, nil
}

View File

@@ -0,0 +1,51 @@
package geodata_test
import (
"path/filepath"
"testing"
"github.com/xtls/xray-core/common/geodata"
)
// TestParseIPRules checks that every supported IP rule form parses without
// error: geoip shorthands, ext/ext-ip references (plain and reversed), and
// custom IPv4/IPv6 addresses with and without a prefix length.
func TestParseIPRules(t *testing.T) {
	assetDir := filepath.Join("..", "..", "resources")
	t.Setenv("xray.location.asset", assetDir)

	inputs := []string{
		"geoip:us",
		"geoip:cn",
		"geoip:!cn",
		"ext:geoip.dat:!cn",
		"ext:geoip.dat:ca",
		"ext-ip:geoip.dat:!cn",
		"ext-ip:geoip.dat:!ca",
		"192.168.0.0/24",
		"192.168.0.1",
		"fe80::/64",
		"fe80::",
	}

	if _, err := geodata.ParseIPRules(inputs); err != nil {
		t.Fatalf("Failed to parse ip rules, got %s", err)
	}
}
// TestParseDomainRules checks that every supported domain rule form parses
// without error: geosite shorthands, ext/ext-domain references (with and
// without attributes), and a custom domain rule.
func TestParseDomainRules(t *testing.T) {
	t.Setenv("xray.location.asset", filepath.Join("..", "..", "resources"))

	rules := []string{
		"geosite:cn",
		"geosite:geolocation-!cn",
		"geosite:cn@!cn",
		"ext:geosite.dat:geolocation-!cn",
		"ext:geosite.dat:cn@!cn",
		// "ext-domain:" is the long prefix the parser recognises. Any other
		// prefix would be parsed as a plain custom domain and never reach
		// the geosite code path.
		"ext-domain:geosite.dat:geolocation-!cn",
		"ext-domain:geosite.dat:cn@!cn",
		"domain:google.com",
	}

	_, err := geodata.ParseDomainRules(rules, geodata.Domain_Domain)
	if err != nil {
		t.Fatalf("Failed to parse domain rules, got %s", err)
	}
}

View File

@@ -0,0 +1,58 @@
package strmatcher_test
import (
"testing"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// BenchmarkLinearIndexMatcher runs the shared index-matcher benchmark suite
// against matchers created by NewLinearIndexMatcher.
func BenchmarkLinearIndexMatcher(b *testing.B) {
benchmarkIndexMatcher(b, func() IndexMatcher {
return NewLinearIndexMatcher()
})
}
// BenchmarkMphIndexMatcher runs the shared index-matcher benchmark suite
// against matchers created by NewMphIndexMatcher.
func BenchmarkMphIndexMatcher(b *testing.B) {
benchmarkIndexMatcher(b, func() IndexMatcher {
return NewMphIndexMatcher()
})
}
// benchmarkIndexMatcher runs the Match and MatchAny benchmark scenarios
// against fresh matchers obtained from ctor. Each scenario gets its own
// matcher and its own map of enabled matcher types.
func benchmarkIndexMatcher(b *testing.B, ctor func() IndexMatcher) {
	type scenario struct {
		name    string
		enabled map[Type]bool
	}
	// runAll executes bench once per scenario as a sub-benchmark.
	runAll := func(b *testing.B, bench func(*testing.B, MatcherGroup, map[Type]bool), scenarios []scenario) {
		for _, sc := range scenarios {
			b.Run(sc.name, func(b *testing.B) {
				bench(b, ctor(), sc.enabled)
			})
		}
	}

	b.Run("Match", func(b *testing.B) {
		runAll(b, benchmarkMatch, []scenario{
			{"Domain------------", map[Type]bool{Domain: true}},
			{"Domain+Full-------", map[Type]bool{Domain: true, Full: true}},
			{"Domain+Full+Substr", map[Type]bool{Domain: true, Full: true, Substr: true}},
			{"All-Fail----------", map[Type]bool{Domain: false, Full: false, Substr: false}},
		})
	})
	b.Run("Match/Dotless", func(b *testing.B) { // Dotless domain matcher automatically inserted in DNS app when "localhost" DNS is used.
		runAll(b, benchmarkMatch, []scenario{
			{"All-Succ", map[Type]bool{Domain: true, Full: true, Substr: true, Regex: true}},
			{"All-Fail", map[Type]bool{Domain: false, Full: false, Substr: false, Regex: false}},
		})
	})
	b.Run("MatchAny", func(b *testing.B) {
		runAll(b, benchmarkMatchAny, []scenario{
			{"First-Full--", map[Type]bool{Full: true, Domain: true, Substr: true}},
			{"First-Domain", map[Type]bool{Full: false, Domain: true, Substr: true}},
			{"First-Substr", map[Type]bool{Full: false, Domain: false, Substr: true}},
			{"All-Fail----", map[Type]bool{Full: false, Domain: false, Substr: false}},
		})
	})
}

View File

@@ -0,0 +1,149 @@
package strmatcher_test
import (
"strconv"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
func BenchmarkFullMatcher(b *testing.B) {
b.Run("SimpleMatcherGroup------", func(b *testing.B) {
benchmarkMatcherType(b, Full, func() MatcherGroup {
return new(SimpleMatcherGroup)
})
})
b.Run("FullMatcherGroup--------", func(b *testing.B) {
benchmarkMatcherType(b, Full, func() MatcherGroup {
return NewFullMatcherGroup()
})
})
b.Run("ACAutomationMatcherGroup", func(b *testing.B) {
benchmarkMatcherType(b, Full, func() MatcherGroup {
return NewACAutomatonMatcherGroup()
})
})
b.Run("MphMatcherGroup---------", func(b *testing.B) {
benchmarkMatcherType(b, Full, func() MatcherGroup {
return NewMphMatcherGroup()
})
})
}
func BenchmarkDomainMatcher(b *testing.B) {
b.Run("SimpleMatcherGroup------", func(b *testing.B) {
benchmarkMatcherType(b, Domain, func() MatcherGroup {
return new(SimpleMatcherGroup)
})
})
b.Run("DomainMatcherGroup------", func(b *testing.B) {
benchmarkMatcherType(b, Domain, func() MatcherGroup {
return NewDomainMatcherGroup()
})
})
b.Run("ACAutomationMatcherGroup", func(b *testing.B) {
benchmarkMatcherType(b, Domain, func() MatcherGroup {
return NewACAutomatonMatcherGroup()
})
})
b.Run("MphMatcherGroup---------", func(b *testing.B) {
benchmarkMatcherType(b, Domain, func() MatcherGroup {
return NewMphMatcherGroup()
})
})
}
func BenchmarkSubstrMatcher(b *testing.B) {
b.Run("SimpleMatcherGroup------", func(b *testing.B) {
benchmarkMatcherType(b, Substr, func() MatcherGroup {
return new(SimpleMatcherGroup)
})
})
b.Run("SubstrMatcherGroup------", func(b *testing.B) {
benchmarkMatcherType(b, Substr, func() MatcherGroup {
return new(SubstrMatcherGroup)
})
})
b.Run("ACAutomationMatcherGroup", func(b *testing.B) {
benchmarkMatcherType(b, Substr, func() MatcherGroup {
return NewACAutomatonMatcherGroup()
})
})
}
// Utility functions for benchmark

// benchmarkMatcherType runs the Match and MatchAny benchmarks for the single
// matcher type t: "Succ" enables the rule that matches the benchmark input,
// "Fail" leaves it out. ctor is called once per sub-benchmark, so every run
// works on a fresh group.
func benchmarkMatcherType(b *testing.B, t Type, ctor func() MatcherGroup) {
	suites := []struct {
		name string
		run  func(*testing.B, MatcherGroup, map[Type]bool)
	}{
		{"Match", benchmarkMatch},
		{"MatchAny", benchmarkMatchAny},
	}
	outcomes := []struct {
		name string
		hit  bool
	}{
		{"Succ", true},
		{"Fail", false},
	}
	for _, suite := range suites {
		b.Run(suite.name, func(b *testing.B) {
			for _, outcome := range outcomes {
				b.Run(outcome.name, func(b *testing.B) {
					suite.run(b, ctor(), map[Type]bool{t: outcome.hit})
				})
			}
		})
	}
}
// benchmarkMatch fills g according to enabledTypes, then times g.Match on the
// fixed input "0.example.com". The timer is reset after setup so only the
// lookups are measured.
func benchmarkMatch(b *testing.B, g MatcherGroup, enabledTypes map[Type]bool) {
	prepareMatchers(g, enabledTypes)
	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = g.Match("0.example.com")
	}
}
// benchmarkMatchAny fills g according to enabledTypes, then times g.MatchAny
// on the fixed input "0.example.com". The timer is reset after setup so only
// the lookups are measured.
func benchmarkMatchAny(b *testing.B, g MatcherGroup, enabledTypes map[Type]bool) {
	prepareMatchers(g, enabledTypes)
	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_ = g.MatchAny("0.example.com")
	}
}
// prepareMatchers populates g for a benchmark run.
//
// For every matcher type present in enabledTypes a fixed number of
// non-matching "<i>.example.com" patterns is added; when the map value is
// true, one extra pattern (value 0) that matches "0.example.com" is added as
// well. Groups that expose a Build method are built at the end.
//
// NOTE(review): the error returned by AddMatcherToGroup is ignored here, as in
// the original; presumably so groups that reject a matcher type can still be
// benchmarked. Confirm before tightening.
func prepareMatchers(g MatcherGroup, enabledTypes map[Type]bool) {
	for kind, wantHit := range enabledTypes {
		switch kind {
		case Domain:
			if wantHit {
				AddMatcherToGroup(g, DomainMatcher("example.com"), 0)
			}
			for n := uint32(1); n < 1024; n++ {
				pattern := strconv.Itoa(int(n)) + ".example.com"
				AddMatcherToGroup(g, DomainMatcher(pattern), n)
			}
		case Full:
			if wantHit {
				AddMatcherToGroup(g, FullMatcher("0.example.com"), 0)
			}
			for n := uint32(1); n < 64; n++ {
				pattern := strconv.Itoa(int(n)) + ".example.com"
				AddMatcherToGroup(g, FullMatcher(pattern), n)
			}
		case Substr:
			if wantHit {
				AddMatcherToGroup(g, SubstrMatcher("example.com"), 0)
			}
			for n := uint32(1); n < 4; n++ {
				pattern := strconv.Itoa(int(n)) + ".example.com"
				AddMatcherToGroup(g, SubstrMatcher(pattern), n)
			}
		case Regex:
			// Dotless domain matcher automatically inserted in DNS app when "localhost" DNS is used.
			dotless, err := Regex.New("^[^.]*$")
			common.Must(err)
			AddMatcherToGroup(g, dotless, 0)
		}
	}
	builder, ok := g.(buildable)
	if !ok {
		return
	}
	common.Must(builder.Build())
}
// buildable is satisfied by matcher groups that have a Build step.
// prepareMatchers calls Build on such groups once all matchers are added.
type buildable interface {
Build() error
}

View File

@@ -0,0 +1,96 @@
package strmatcher
// LinearIndexMatcher is an implementation of IndexMatcher.
// It keeps one matcher group per matcher kind; each group is created lazily
// by Add and stays nil until a matcher of that kind is added.
type LinearIndexMatcher struct {
count uint32 // number of matchers added so far; also the index given to the most recent one
full *FullMatcherGroup // receives FullMatcher values
domain *DomainMatcherGroup // receives DomainMatcher values
substr *SubstrMatcherGroup // receives SubstrMatcher values
regex *SimpleMatcherGroup // receives every other matcher type (Add's default branch)
}
// NewLinearIndexMatcher returns an empty LinearIndexMatcher; its matcher
// groups are allocated on demand by Add.
func NewLinearIndexMatcher() *LinearIndexMatcher {
	return &LinearIndexMatcher{}
}
// Add implements IndexMatcher.Add.
// The matcher is stored in the group that corresponds to its concrete type
// (anything that is not a Full, Domain or Substr matcher goes to the generic
// group) and the index assigned to it is returned. Indices start at 1 and
// grow by one per call.
func (g *LinearIndexMatcher) Add(matcher Matcher) uint32 {
	assigned := g.count + 1
	g.count = assigned
	switch m := matcher.(type) {
	case FullMatcher:
		if g.full == nil {
			g.full = NewFullMatcherGroup()
		}
		g.full.AddFullMatcher(m, assigned)
	case DomainMatcher:
		if g.domain == nil {
			g.domain = NewDomainMatcherGroup()
		}
		g.domain.AddDomainMatcher(m, assigned)
	case SubstrMatcher:
		if g.substr == nil {
			g.substr = new(SubstrMatcherGroup)
		}
		g.substr.AddSubstrMatcher(m, assigned)
	default:
		if g.regex == nil {
			g.regex = new(SimpleMatcherGroup)
		}
		g.regex.AddMatcher(m, assigned)
	}
	return assigned
}
// Build implements IndexMatcher.Build.
// It does nothing and always returns nil.
func (*LinearIndexMatcher) Build() error {
return nil
}
// Match implements IndexMatcher.Match.
// The groups are queried in the order full, domain, substr, generic; the
// non-empty results are combined by CompositeMatches.
func (g *LinearIndexMatcher) Match(input string) []uint32 {
	// Allocate capacity to prevent matches escaping to heap
	collected := make([][]uint32, 0, 5)
	if g.full != nil {
		hits := g.full.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	if g.domain != nil {
		hits := g.domain.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	if g.substr != nil {
		hits := g.substr.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	if g.regex != nil {
		hits := g.regex.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	return CompositeMatches(collected)
}
// MatchAny implements IndexMatcher.MatchAny.
// It reports whether any group matches input, trying full, domain, substr and
// the generic group in that order and stopping at the first hit.
func (g *LinearIndexMatcher) MatchAny(input string) bool {
	switch {
	case g.full != nil && g.full.MatchAny(input):
		return true
	case g.domain != nil && g.domain.MatchAny(input):
		return true
	case g.substr != nil && g.substr.MatchAny(input):
		return true
	case g.regex != nil && g.regex.MatchAny(input):
		return true
	}
	return false
}
// Size implements IndexMatcher.Size.
// It returns the number of matchers passed to Add so far.
func (g *LinearIndexMatcher) Size() uint32 {
return g.count
}

View File

@@ -0,0 +1,95 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestLinearIndexMatcher checks the indices, and their order, returned by
// LinearIndexMatcher.Match for a mixed rule set. Rules receive 1-based indices
// in insertion order.
//
// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
func TestLinearIndexMatcher(t *testing.T) {
	rules := []struct {
		Type   Type
		Domain string
	}{
		{Type: Regex, Domain: "apis\\.us$"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Domain, Domain: "com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Full, Domain: "fonts.googleapis.com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Domain, Domain: "example.com"},
	}
	cases := []struct {
		Input  string
		Output []uint32
	}{
		{Input: "www.baidu.com", Output: []uint32{5, 9, 4}},
		{Input: "fonts.googleapis.com", Output: []uint32{8, 3, 7, 4, 2, 6}},
		{Input: "example.googleapis.com", Output: []uint32{3, 7, 4, 2, 6}},
		{Input: "testapis.us", Output: []uint32{2, 6, 1}},
		{Input: "example.com", Output: []uint32{10, 4}},
	}
	matcherGroup := NewLinearIndexMatcher()
	for _, rule := range rules {
		matcher, err := rule.Type.New(rule.Domain)
		common.Must(err)
		matcherGroup.Add(matcher)
	}
	// A failed build would make every assertion below meaningless, so stop right here.
	common.Must(matcherGroup.Build())
	for _, test := range cases {
		if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
}

View File

@@ -0,0 +1,100 @@
package strmatcher
import "runtime"
// A MphIndexMatcher is divided into three parts:
// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table;
// 2. `substr` patterns are matched by ac automaton;
// 3. `regex` patterns are matched with the regex library.
//
// Each part is created lazily by Add and stays nil until a matcher of the
// corresponding kind is added.
type MphIndexMatcher struct {
count uint32 // number of Add calls so far; also the index given to the most recent matcher
mph *MphMatcherGroup // receives FullMatcher and DomainMatcher values
ac *ACAutomatonMatcherGroup // receives SubstrMatcher values
regex *SimpleMatcherGroup // receives *RegexMatcher values
}
// NewMphIndexMatcher returns an empty MphIndexMatcher; its matcher groups are
// allocated on demand by Add.
func NewMphIndexMatcher() *MphIndexMatcher {
	return &MphIndexMatcher{}
}
// Add implements IndexMatcher.Add.
// Full and Domain matchers go to the MPH group, Substr matchers to the AC
// automaton and *RegexMatcher values to the simple group. The index assigned
// to the matcher (starting at 1) is returned.
//
// A matcher of any other type is not stored anywhere, yet still consumes an index.
func (g *MphIndexMatcher) Add(matcher Matcher) uint32 {
	assigned := g.count + 1
	g.count = assigned
	switch m := matcher.(type) {
	case FullMatcher:
		if g.mph == nil {
			g.mph = NewMphMatcherGroup()
		}
		g.mph.AddFullMatcher(m, assigned)
	case DomainMatcher:
		if g.mph == nil {
			g.mph = NewMphMatcherGroup()
		}
		g.mph.AddDomainMatcher(m, assigned)
	case SubstrMatcher:
		if g.ac == nil {
			g.ac = NewACAutomatonMatcherGroup()
		}
		g.ac.AddSubstrMatcher(m, assigned)
	case *RegexMatcher:
		if g.regex == nil {
			g.regex = &SimpleMatcherGroup{}
		}
		g.regex.AddMatcher(m, assigned)
	}
	return assigned
}
// Build implements IndexMatcher.Build.
// It builds the MPH group and then the AC automaton, whichever of them exist,
// and returns the first error either build reports. The explicit runtime.GC
// calls around the steps are kept from the original, which marks them
// "peak mem".
func (g *MphIndexMatcher) Build() error {
	if g.mph != nil {
		runtime.GC() // peak mem
		if err := g.mph.Build(); err != nil {
			return err
		}
	}
	runtime.GC() // peak mem
	if g.ac != nil {
		if err := g.ac.Build(); err != nil {
			return err
		}
		runtime.GC() // peak mem
	}
	return nil
}
// Match implements IndexMatcher.Match.
// The MPH group, the AC automaton and the regex group are queried in that
// order; the non-empty results are combined by CompositeMatches.
func (g *MphIndexMatcher) Match(input string) []uint32 {
	collected := make([][]uint32, 0, 5)
	if g.mph != nil {
		hits := g.mph.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	if g.ac != nil {
		hits := g.ac.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	if g.regex != nil {
		hits := g.regex.Match(input)
		if len(hits) > 0 {
			collected = append(collected, hits)
		}
	}
	return CompositeMatches(collected)
}
// MatchAny implements IndexMatcher.MatchAny.
// It reports whether any group matches input, trying the MPH group, the AC
// automaton and the regex group in that order and stopping at the first hit.
func (g *MphIndexMatcher) MatchAny(input string) bool {
	switch {
	case g.mph != nil && g.mph.MatchAny(input):
		return true
	case g.ac != nil && g.ac.MatchAny(input):
		return true
	case g.regex != nil && g.regex.MatchAny(input):
		return true
	}
	return false
}
// Size implements IndexMatcher.Size.
// It returns the number of Add calls made so far.
func (g *MphIndexMatcher) Size() uint32 {
return g.count
}

View File

@@ -0,0 +1,94 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestMphIndexMatcher checks the indices, and their order, returned by
// MphIndexMatcher.Match for a mixed rule set. Rules receive 1-based indices in
// insertion order.
func TestMphIndexMatcher(t *testing.T) {
	rules := []struct {
		Type   Type
		Domain string
	}{
		{Type: Regex, Domain: "apis\\.us$"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Domain, Domain: "com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Full, Domain: "fonts.googleapis.com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Domain, Domain: "example.com"},
	}
	cases := []struct {
		Input  string
		Output []uint32
	}{
		{Input: "www.baidu.com", Output: []uint32{5, 9, 4}},
		{Input: "fonts.googleapis.com", Output: []uint32{8, 3, 7, 4, 2, 6}},
		{Input: "example.googleapis.com", Output: []uint32{3, 7, 4, 2, 6}},
		{Input: "testapis.us", Output: []uint32{2, 6, 1}},
		{Input: "example.com", Output: []uint32{10, 4}},
	}
	matcherGroup := NewMphIndexMatcher()
	for _, rule := range rules {
		matcher, err := rule.Type.New(rule.Domain)
		common.Must(err)
		matcherGroup.Add(matcher)
	}
	// A failed build would make every assertion below meaningless, so stop right here.
	common.Must(matcherGroup.Build())
	for _, test := range cases {
		if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
}

View File

@@ -0,0 +1,282 @@
package strmatcher
import (
"container/list"
)
// Sizing constants for the automaton tables.
const (
acValidCharCount = 39 // aA-zZ (26), 0-9 (10), - (1), . (1), invalid(1)
acMatchTypeCount = 3 // Full, Domain and Substr
)
// acEdge tells whether a transition stored in acNode.next is an original trie
// edge or a fail edge.
type acEdge byte
// acFailEdge is the zero value, so a transition that was never set as a trie
// edge reads as a fail edge.
const (
acTrieEdge acEdge = 1
acFailEdge acEdge = 0
)
// acNode is a single state of the automaton.
type acNode struct {
next [acValidCharCount]uint32 // EdgeIdx -> Next NodeIdx (Next trie node or fail node)
edge [acValidCharCount]acEdge // EdgeIdx -> Trie Edge / Fail Edge
fail uint32 // NodeIdx of *next matched* Substr Pattern on its fail path
match uint32 // MatchIdx of matchers registered on this node, 0 indicates no match
} // Sizeof acNode: (4+1)*acValidCharCount + <padding> + 4 + 4
// acValue holds, per matcher type, the values registered on one match entry.
type acValue [acMatchTypeCount][]uint32 // MatcherType -> Registered Matcher Values
// ACAutomatonMatcherGroup is an implementation of MatcherGroup.
// It uses an AC automaton to provide support for Full, Domain and Substr matchers. Trie nodes are char based.
//
// NOTICE: ACAutomatonMatcherGroup currently uses a restricted charset (LDH Subset);
// callers must make sure that all patterns and inputs passed to it stay within this charset.
type ACAutomatonMatcherGroup struct {
nodes []acNode // NodeIdx -> acNode
values []acValue // MatchIdx -> acValue
}
// NewACAutomatonMatcherGroup returns a group that holds only the root node and
// the sentinel match entry. Index 0 therefore means "root" for nodes and
// "no match" for match entries.
func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup {
	group := &ACAutomatonMatcherGroup{}
	group.addNode()       // NodeIdx 0: root node
	group.addMatchEntry() // MatchIdx 0: sentinel match entry
	return group
}
// AddFullMatcher implements MatcherGroupForFull.AddFullMatcher.
// The pattern is inserted from the root node and value is registered under the matcher's type.
func (ac *ACAutomatonMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
ac.addPattern(0, matcher.Pattern(), matcher.Type(), value)
}
// AddDomainMatcher implements MatcherGroupForDomain.AddDomainMatcher.
// The value is registered twice: on the node that ends the pattern itself
// (the input equals the domain) and on that node's "." child (the input is a
// sub-domain).
func (ac *ACAutomatonMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
	exactNode := ac.addPattern(0, matcher.Pattern(), matcher.Type(), value)
	ac.addPattern(exactNode, ".", matcher.Type(), value)
}
// AddSubstrMatcher implements MatcherGroupForSubstr.AddSubstrMatcher.
// The pattern is inserted from the root node and value is registered under the matcher's type.
func (ac *ACAutomatonMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, value uint32) {
ac.addPattern(0, matcher.Pattern(), matcher.Type(), value)
}
// addPattern walks pattern from its last byte to its first, starting at node
// nodeIdx and creating trie nodes where none exist. It registers value under
// matcherType on the node reached at the end and returns that node's index.
func (ac *ACAutomatonMatcherGroup) addPattern(nodeIdx uint32, pattern string, matcherType Type, value uint32) uint32 {
	cur := nodeIdx
	for pos := len(pattern); pos > 0; pos-- {
		edgeIdx := acCharset[pattern[pos-1]]
		child := ac.nodes[cur].next[edgeIdx]
		if child == 0 { // Add new Trie Edge
			// addNode may reallocate ac.nodes, so nodes are always addressed by index here.
			child = ac.addNode()
			ac.nodes[cur].next[edgeIdx] = child
			ac.nodes[cur].edge[edgeIdx] = acTrieEdge
		}
		cur = child
	}
	matchIdx := ac.nodes[cur].match
	if matchIdx == 0 { // Add new match entry
		matchIdx = ac.addMatchEntry()
		ac.nodes[cur].match = matchIdx
	}
	ac.values[matchIdx][matcherType] = append(ac.values[matchIdx][matcherType], value)
	return cur
}
// addNode appends a zero-valued node and returns its NodeIdx.
func (ac *ACAutomatonMatcherGroup) addNode() uint32 {
	idx := uint32(len(ac.nodes))
	ac.nodes = append(ac.nodes, acNode{})
	return idx
}
// addMatchEntry appends an empty match entry and returns its MatchIdx.
func (ac *ACAutomatonMatcherGroup) addMatchEntry() uint32 {
	idx := uint32(len(ac.values))
	ac.values = append(ac.values, acValue{})
	return idx
}
// Build turns the trie into an automaton: it walks the trie breadth-first,
// computes each node's fail node, stores the compressed fail link used for
// Substr matches in acNode.fail, and fills every missing transition with the
// corresponding transition of the fail node (marked as acFailEdge).
// It always returns nil.
//
// NOTE(review): after Build the fail transitions live in acNode.next as well,
// so a second Build call (or adding patterns afterwards) would presumably
// treat them as trie edges. Confirm Build is only called once, after all patterns.
func (ac *ACAutomatonMatcherGroup) Build() error {
// fail[NodeIdx] is the uncompressed fail node; root's children keep the zero value (root).
fail := make([]uint32, len(ac.nodes))
queue := list.New()
// Seed the BFS with the direct children of the root.
for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ {
if nextIdx := ac.nodes[0].next[edgeIdx]; nextIdx != 0 {
queue.PushBack(nextIdx)
}
}
for {
front := queue.Front()
if front == nil {
break
}
queue.Remove(front)
nodeIdx := front.Value.(uint32)
node := &ac.nodes[nodeIdx] // Current node
failNode := &ac.nodes[fail[nodeIdx]] // Fail node of current node
for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ {
nodeIdx := node.next[edgeIdx] // Next node through trie edge
failIdx := failNode.next[edgeIdx] // Next node through fail edge
if nodeIdx != 0 {
queue.PushBack(nodeIdx)
fail[nodeIdx] = failIdx
if match := ac.nodes[failIdx].match; match != 0 && len(ac.values[match][Substr]) > 0 { // Fail node is a Substr match node
ac.nodes[nodeIdx].fail = failIdx
} else { // Use path compression to reduce fail path to only contain match nodes
ac.nodes[nodeIdx].fail = ac.nodes[failIdx].fail
}
} else { // Add new fail edge
node.next[edgeIdx] = failIdx
node.edge[edgeIdx] = acFailEdge
}
}
}
return nil
}
// Match implements MatcherGroup.Match.
// The input is consumed from its last byte to its first. Substr values are
// collected at every node, Domain values only at '.' boundaries reached
// purely through trie edges, and Domain/Full values at the end of the input
// when no fail edge was taken. The collected groups are combined by
// CompositeMatchesReverse.
func (ac *ACAutomatonMatcherGroup) Match(input string) []uint32 {
	suffixMatches := make([][]uint32, 0, 5)
	substrMatches := make([][]uint32, 0, 5)
	fullMatch := true    // fullMatch indicates no fail edge traversed so far.
	node := &ac.nodes[0] // start from root node.
	// 1. the match string is all through trie edge. FULL MATCH or DOMAIN
	// 2. the match string is through a fail edge. NOT FULL MATCH
	// 2.1 Through a fail edge, but there exists a valid node. SUBSTR
	for i := len(input) - 1; i >= 0; i-- {
		edge := acCharset[input[i]]
		fullMatch = fullMatch && (node.edge[edge] == acTrieEdge)
		node = &ac.nodes[node.next[edge]] // Advance to next node
		// When entering a new node, traverse the fail path to find all possible Substr patterns:
		// 1. The fail path is compressed to only contains match nodes and root node (for terminate condition).
		// 2. node.fail != 0 is added here for better performance (as shown by benchmark), possibly it helps branch prediction.
		if node.fail != 0 {
			// Index and node pointer are advanced together. Updating both in one
			// tuple assignment would read the old index for the pointer and visit
			// every node but the last one twice.
			for failIdx := node.fail; failIdx != 0; {
				failNode := &ac.nodes[failIdx]
				substrMatches = append(substrMatches, ac.values[failNode.match][Substr])
				failIdx = failNode.fail
			}
		}
		// When entering a new node, check whether this node is a match.
		// For Substr matchers:
		//   1. Matched in any situation, whether a failNode edge is traversed or not.
		// For Domain matchers:
		//   1. Should not traverse any fail edge (fullMatch).
		//   2. Only check on dot separator (input[i] == '.').
		if node.match != 0 {
			values := ac.values[node.match]
			if len(values[Substr]) > 0 {
				substrMatches = append(substrMatches, values[Substr])
			}
			if fullMatch && input[i] == '.' && len(values[Domain]) > 0 {
				suffixMatches = append(suffixMatches, values[Domain])
			}
		}
	}
	// At the end of input, check if the whole string matches a pattern.
	// For Domain matchers:
	//   1. Exact match on Domain Matcher works like Full Match. e.g. foo.com is a full match for domain:foo.com.
	// For Full matchers:
	//   1. Only when no fail edge is traversed (fullMatch).
	//   2. Takes the highest priority (added at last).
	if fullMatch && node.match != 0 {
		values := ac.values[node.match]
		if len(values[Domain]) > 0 {
			suffixMatches = append(suffixMatches, values[Domain])
		}
		if len(values[Full]) > 0 {
			suffixMatches = append(suffixMatches, values[Full])
		}
	}
	if len(substrMatches) == 0 {
		return CompositeMatchesReverse(suffixMatches)
	}
	return CompositeMatchesReverse(append(substrMatches, suffixMatches...))
}
// MatchAny implements MatcherGroup.MatchAny.
// It walks the input from its last byte to its first and returns true as soon
// as a Substr match (on the node or on its fail path) or a Domain match at a
// '.' boundary is found. At the end of the input, a match registered on the
// final node counts only if no fail edge was taken.
func (ac *ACAutomatonMatcherGroup) MatchAny(input string) bool {
	onTriePath := true // no fail edge traversed so far
	cur := &ac.nodes[0]
	for pos := len(input); pos > 0; pos-- {
		ch := input[pos-1]
		edgeIdx := acCharset[ch]
		if cur.edge[edgeIdx] != acTrieEdge {
			onTriePath = false
		}
		cur = &ac.nodes[cur.next[edgeIdx]]
		if cur.fail != 0 { // There is a match on this node's fail path
			return true
		}
		if cur.match == 0 { // Nothing registered on this node
			continue
		}
		registered := ac.values[cur.match]
		if len(registered[Substr]) > 0 { // Substr match succeeds unconditionally
			return true
		}
		// Domain match only succeeds with dot separator on trie path
		if onTriePath && ch == '.' && len(registered[Domain]) > 0 {
			return true
		}
	}
	// At the end of input, Domain and Full match will succeed if no fail edge is traversed
	return onTriePath && cur.match != 0
}
// acCharset maps an input byte to its edge index in the automaton.
//
// Letter-Digit-Hyphen (LDH) subset (https://tools.ietf.org/html/rfc952):
//   - Letters A to Z (no distinction is made between uppercase and lowercase)
//   - Digits 0 to 9
//   - Hyphens(-) and Periods(.)
//
// Every byte outside this subset maps to index 0.
//
// If for future the strmatcher are used for other scenarios than domain,
// we could add a new Charset interface to represent variable charsets.
var acCharset = [256]int{
	'A': 1, 'a': 1,
	'B': 2, 'b': 2,
	'C': 3, 'c': 3,
	'D': 4, 'd': 4,
	'E': 5, 'e': 5,
	'F': 6, 'f': 6,
	'G': 7, 'g': 7,
	'H': 8, 'h': 8,
	'I': 9, 'i': 9,
	'J': 10, 'j': 10,
	'K': 11, 'k': 11,
	'L': 12, 'l': 12,
	'M': 13, 'm': 13,
	'N': 14, 'n': 14,
	'O': 15, 'o': 15,
	'P': 16, 'p': 16,
	'Q': 17, 'q': 17,
	'R': 18, 'r': 18,
	'S': 19, 's': 19,
	'T': 20, 't': 20,
	'U': 21, 'u': 21,
	'V': 22, 'v': 22,
	'W': 23, 'w': 23,
	'X': 24, 'x': 24,
	'Y': 25, 'y': 25,
	'Z': 26, 'z': 26,
	'-': 27,
	'.': 28,
	'0': 29, '1': 30, '2': 31, '3': 32, '4': 33,
	'5': 34, '6': 35, '7': 36, '8': 37, '9': 38,
}

View File

@@ -0,0 +1,365 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestACAutomatonMatcherGroup exercises ACAutomatonMatcherGroup.MatchAny:
// first with one pattern per group, then with several small mixed pattern sets.
func TestACAutomatonMatcherGroup(t *testing.T) {
	cases1 := []struct {
		pattern string
		mType   Type
		input   string
		output  bool
	}{
		{pattern: "example.com", mType: Domain, input: "www.example.com", output: true},
		{pattern: "example.com", mType: Domain, input: "example.com", output: true},
		{pattern: "example.com", mType: Domain, input: "www.e3ample.com", output: false},
		{pattern: "example.com", mType: Domain, input: "xample.com", output: false},
		{pattern: "example.com", mType: Domain, input: "xexample.com", output: false},
		{pattern: "example.com", mType: Full, input: "example.com", output: true},
		{pattern: "example.com", mType: Full, input: "xexample.com", output: false},
	}
	for _, test := range cases1 {
		ac := NewACAutomatonMatcherGroup()
		matcher, err := test.mType.New(test.pattern)
		common.Must(err)
		common.Must(AddMatcherToGroup(ac, matcher, 0))
		common.Must(ac.Build()) // do not assert against a group that failed to build
		if m := ac.MatchAny(test.input); m != test.output {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
	{
		cases2Input := []struct {
			pattern string
			mType   Type
		}{
			{pattern: "163.com", mType: Domain},
			{pattern: "m.126.com", mType: Full},
			{pattern: "3.com", mType: Full},
			{pattern: "google.com", mType: Substr},
			{pattern: "vgoogle.com", mType: Substr},
		}
		ac := NewACAutomatonMatcherGroup()
		for _, test := range cases2Input {
			matcher, err := test.mType.New(test.pattern)
			common.Must(err)
			common.Must(AddMatcherToGroup(ac, matcher, 0))
		}
		common.Must(ac.Build())
		cases2Output := []struct {
			pattern string
			res     bool
		}{
			{pattern: "126.com", res: false},
			{pattern: "m.163.com", res: true},
			{pattern: "mm163.com", res: false},
			{pattern: "m.126.com", res: true},
			{pattern: "163.com", res: true},
			{pattern: "63.com", res: false},
			{pattern: "oogle.com", res: false},
			{pattern: "vvgoogle.com", res: true},
		}
		for _, test := range cases2Output {
			if m := ac.MatchAny(test.pattern); m != test.res {
				t.Error("unexpected output: ", m, " for test case ", test)
			}
		}
	}
	{
		cases3Input := []struct {
			pattern string
			mType   Type
		}{
			{pattern: "video.google.com", mType: Domain},
			{pattern: "gle.com", mType: Domain},
		}
		ac := NewACAutomatonMatcherGroup()
		for _, test := range cases3Input {
			matcher, err := test.mType.New(test.pattern)
			common.Must(err)
			common.Must(AddMatcherToGroup(ac, matcher, 0))
		}
		common.Must(ac.Build())
		cases3Output := []struct {
			pattern string
			res     bool
		}{
			{pattern: "google.com", res: false},
		}
		for _, test := range cases3Output {
			if m := ac.MatchAny(test.pattern); m != test.res {
				t.Error("unexpected output: ", m, " for test case ", test)
			}
		}
	}
	{
		cases4Input := []struct {
			pattern string
			mType   Type
		}{
			{pattern: "apis", mType: Substr},
			{pattern: "googleapis.com", mType: Domain},
		}
		ac := NewACAutomatonMatcherGroup()
		for _, test := range cases4Input {
			matcher, err := test.mType.New(test.pattern)
			common.Must(err)
			common.Must(AddMatcherToGroup(ac, matcher, 0))
		}
		common.Must(ac.Build())
		cases4Output := []struct {
			pattern string
			res     bool
		}{
			{pattern: "gapis.com", res: true},
		}
		for _, test := range cases4Output {
			if m := ac.MatchAny(test.pattern); m != test.res {
				t.Error("unexpected output: ", m, " for test case ", test)
			}
		}
	}
}
// TestACAutomatonMatcherGroupSubstr checks the values, and their order,
// returned by ACAutomatonMatcherGroup.Match for Substr patterns, including a
// pattern that is registered twice. Each pattern's value is its position in
// the patterns slice.
func TestACAutomatonMatcherGroupSubstr(t *testing.T) {
	patterns := []struct {
		pattern string
		mType   Type
	}{
		{pattern: "apis", mType: Substr},
		{pattern: "google", mType: Substr},
		{pattern: "apis", mType: Substr},
	}
	cases := []struct {
		input  string
		output []uint32
	}{
		{input: "google.com", output: []uint32{1}},
		{input: "apis.com", output: []uint32{0, 2}},
		{input: "googleapis.com", output: []uint32{1, 0, 2}},
		{input: "fonts.googleapis.com", output: []uint32{1, 0, 2}},
		{input: "apis.googleapis.com", output: []uint32{0, 2, 1, 0, 2}},
	}
	matcherGroup := NewACAutomatonMatcherGroup()
	for id, entry := range patterns {
		matcher, err := entry.mType.New(entry.pattern)
		common.Must(err)
		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id)))
	}
	// A failed build would make every assertion below meaningless, so stop right here.
	common.Must(matcherGroup.Build())
	for _, test := range cases {
		if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) {
			t.Error("unexpected output: ", r, " for test case ", test)
		}
	}
}
// TestACAutomatonMatcherGroupAsIndexMatcher runs the index-matcher scenario
// directly against ACAutomatonMatcherGroup. The Regex rule of that scenario is
// left out, so values start at 2 to keep the numbering of the remaining rules.
//
// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
func TestACAutomatonMatcherGroupAsIndexMatcher(t *testing.T) {
	rules := []struct {
		Type   Type
		Domain string
	}{
		// Regex not supported by ACAutomationMatcherGroup
		// {Type: Regex, Domain: "apis\\.us$"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Domain, Domain: "com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Substr, Domain: "apis"},
		{Type: Domain, Domain: "googleapis.com"},
		{Type: Full, Domain: "fonts.googleapis.com"},
		{Type: Full, Domain: "www.baidu.com"},
		{Type: Domain, Domain: "example.com"},
	}
	cases := []struct {
		Input  string
		Output []uint32
	}{
		{Input: "www.baidu.com", Output: []uint32{5, 9, 4}},
		{Input: "fonts.googleapis.com", Output: []uint32{8, 3, 7, 4, 2, 6}},
		{Input: "example.googleapis.com", Output: []uint32{3, 7, 4, 2, 6}},
		{Input: "testapis.us", Output: []uint32{2, 6 /*, 1*/}},
		{Input: "example.com", Output: []uint32{10, 4}},
	}
	matcherGroup := NewACAutomatonMatcherGroup()
	for i, rule := range rules {
		matcher, err := rule.Type.New(rule.Domain)
		common.Must(err)
		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+2)))
	}
	// A failed build would make every assertion below meaningless, so stop right here.
	common.Must(matcherGroup.Build())
	for _, test := range cases {
		if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
}

View File

@@ -0,0 +1,109 @@
package strmatcher
// trieNode is one node of the domain-label trie used by DomainMatcherGroup.
type trieNode struct {
values []uint32 // values of the matchers whose pattern ends at this node
children map[string]*trieNode // domain label -> child node; nil until a child is added
}
// DomainMatcherGroup is an implementation of MatcherGroup.
// It uses trie to optimize both memory consumption and lookup speed. Trie node is domain label based.
// Labels are inserted and looked up from the right-most label to the left-most one.
type DomainMatcherGroup struct {
root *trieNode // trie root; its children are keyed by the right-most label
}
// NewDomainMatcherGroup returns a DomainMatcherGroup with an empty root node.
func NewDomainMatcherGroup() *DomainMatcherGroup {
	group := new(DomainMatcherGroup)
	group.root = &trieNode{}
	return group
}
// AddDomainMatcher implements MatcherGroupForDomain.AddDomainMatcher.
// The pattern is split on '.' and its labels are inserted into the trie from
// right to left; value is appended to the node of the left-most label.
func (g *DomainMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
	cur := g.root
	rest := matcher.Pattern()
	for len(rest) > 0 {
		// Find the dot in front of the right-most label; -1 means rest is a single label.
		dot := len(rest) - 1
		for dot >= 0 && rest[dot] != '.' {
			dot--
		}
		label := rest[dot+1:]
		if dot < 0 {
			rest = ""
		} else {
			rest = rest[:dot]
		}
		if cur.children == nil {
			cur.children = make(map[string]*trieNode)
		}
		child, ok := cur.children[label]
		if !ok || child == nil {
			child = new(trieNode)
			cur.children[label] = child
		}
		cur = child
	}
	cur.values = append(cur.values, value)
}
// Match implements MatcherGroup.Match.
// It follows the labels of input from right to left through the trie and
// collects the values of every node passed on the way; the collected groups
// are combined by CompositeMatchesReverse.
func (g *DomainMatcherGroup) Match(input string) []uint32 {
	found := make([][]uint32, 0, 5)
	cur := g.root
	rest := input
	for len(rest) > 0 {
		// Find the dot in front of the right-most label; -1 means rest is the last label.
		dot := len(rest) - 1
		for dot >= 0 && rest[dot] != '.' {
			dot--
		}
		cur = cur.children[rest[dot+1:]]
		if dot < 0 {
			rest = ""
		} else {
			rest = rest[:dot]
		}
		if cur == nil { // No more match if no trie edge transition
			break
		}
		if len(cur.values) > 0 { // Found matched matchers
			found = append(found, cur.values)
		}
		if cur.children == nil { // No more match if leaf node reached
			break
		}
	}
	return CompositeMatchesReverse(found)
}
// MatchAny implements MatcherGroup.MatchAny.
// It follows the labels of input from right to left through the trie and
// reports true at the first node that carries values.
func (g *DomainMatcherGroup) MatchAny(input string) bool {
	cur := g.root
	rest := input
	for len(rest) > 0 {
		// Find the dot in front of the right-most label; -1 means rest is the last label.
		dot := len(rest) - 1
		for dot >= 0 && rest[dot] != '.' {
			dot--
		}
		cur = cur.children[rest[dot+1:]]
		if dot < 0 {
			rest = ""
		} else {
			rest = rest[:dot]
		}
		switch {
		case cur == nil: // no trie edge for this label
			return false
		case len(cur.values) > 0: // a registered domain ends here
			return true
		case cur.children == nil: // leaf reached without a match
			return false
		}
	}
	return false
}

View File

@@ -4,19 +4,43 @@ import (
"reflect"
"testing"
. "github.com/xtls/xray-core/common/strmatcher"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
func TestDomainMatcherGroup(t *testing.T) {
g := new(DomainMatcherGroup)
g.Add("example.com", 1)
g.Add("google.com", 2)
g.Add("x.a.com", 3)
g.Add("a.b.com", 4)
g.Add("c.a.b.com", 5)
g.Add("x.y.com", 4)
g.Add("x.y.com", 6)
patterns := []struct {
Pattern string
Value uint32
}{
{
Pattern: "example.com",
Value: 1,
},
{
Pattern: "google.com",
Value: 2,
},
{
Pattern: "x.a.com",
Value: 3,
},
{
Pattern: "a.b.com",
Value: 4,
},
{
Pattern: "c.a.b.com",
Value: 5,
},
{
Pattern: "x.y.com",
Value: 4,
},
{
Pattern: "x.y.com",
Value: 6,
},
}
testCases := []struct {
Domain string
Result []uint32
@@ -58,7 +82,10 @@ func TestDomainMatcherGroup(t *testing.T) {
Result: []uint32{4, 6},
},
}
g := NewDomainMatcherGroup()
for _, pattern := range patterns {
AddMatcherToGroup(g, DomainMatcher(pattern.Pattern), pattern.Value)
}
for _, testCase := range testCases {
r := g.Match(testCase.Domain)
if !reflect.DeepEqual(r, testCase.Result) {
@@ -68,7 +95,7 @@ func TestDomainMatcherGroup(t *testing.T) {
}
func TestEmptyDomainMatcherGroup(t *testing.T) {
g := new(DomainMatcherGroup)
g := NewDomainMatcherGroup()
r := g.Match("example.com")
if len(r) != 0 {
t.Error("Expect [], but ", r)

View File

@@ -0,0 +1,30 @@
package strmatcher
// FullMatcherGroup is an implementation of MatcherGroup.
// It uses a hash table to facilitate exact match lookup.
type FullMatcherGroup struct {
matchers map[string][]uint32 // pattern -> values of the matchers registered for it, in insertion order
}
// NewFullMatcherGroup returns a FullMatcherGroup with an initialized, empty table.
func NewFullMatcherGroup() *FullMatcherGroup {
	group := new(FullMatcherGroup)
	group.matchers = make(map[string][]uint32)
	return group
}
// AddFullMatcher implements MatcherGroupForFull.AddFullMatcher.
// value is appended to the list kept for the matcher's pattern.
func (g *FullMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
	key := matcher.Pattern()
	registered := g.matchers[key]
	g.matchers[key] = append(registered, value)
}
// Match implements MatcherGroup.Match.
// It returns the values registered for a pattern equal to input, or nil when there is none.
// The returned slice is the group's internal one, not a copy.
func (g *FullMatcherGroup) Match(input string) []uint32 {
return g.matchers[input]
}
// MatchAny implements MatcherGroup.MatchAny.
// It reports whether a pattern equal to input has been registered.
func (g *FullMatcherGroup) MatchAny(input string) bool {
	if _, ok := g.matchers[input]; ok {
		return true
	}
	return false
}

View File

@@ -4,17 +4,35 @@ import (
"reflect"
"testing"
. "github.com/xtls/xray-core/common/strmatcher"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
func TestFullMatcherGroup(t *testing.T) {
g := new(FullMatcherGroup)
g.Add("example.com", 1)
g.Add("google.com", 2)
g.Add("x.a.com", 3)
g.Add("x.y.com", 4)
g.Add("x.y.com", 6)
patterns := []struct {
Pattern string
Value uint32
}{
{
Pattern: "example.com",
Value: 1,
},
{
Pattern: "google.com",
Value: 2,
},
{
Pattern: "x.a.com",
Value: 3,
},
{
Pattern: "x.y.com",
Value: 4,
},
{
Pattern: "x.y.com",
Value: 6,
},
}
testCases := []struct {
Domain string
Result []uint32
@@ -32,7 +50,10 @@ func TestFullMatcherGroup(t *testing.T) {
Result: []uint32{4, 6},
},
}
g := NewFullMatcherGroup()
for _, pattern := range patterns {
AddMatcherToGroup(g, FullMatcher(pattern.Pattern), pattern.Value)
}
for _, testCase := range testCases {
r := g.Match(testCase.Domain)
if !reflect.DeepEqual(r, testCase.Result) {
@@ -42,7 +63,7 @@ func TestFullMatcherGroup(t *testing.T) {
}
func TestEmptyFullMatcherGroup(t *testing.T) {
g := new(FullMatcherGroup)
g := NewFullMatcherGroup()
r := g.Match("example.com")
if len(r) != 0 {
t.Error("Expect [], but ", r)

View File

@@ -0,0 +1,198 @@
package strmatcher
import (
"math/bits"
"runtime"
"sort"
"strings"
"unsafe"
)
// PrimeRK is the prime base used in Rabin-Karp algorithm.
const PrimeRK = 16777619

// RollingHash extends hash, the hash of an already processed suffix, with the
// bytes of input taken from the last byte to the first. Each step computes
// hash*PrimeRK + byte in wrapping uint32 arithmetic. An empty input returns
// hash unchanged.
func RollingHash(hash uint32, input string) uint32 {
	for remaining := len(input); remaining > 0; remaining-- {
		hash = hash*PrimeRK + uint32(input[remaining-1])
	}
	return hash
}
// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
// as aeshash if aes instruction is available).
// With different seed, each MemHash<seed> performs as distinct hash functions.
// The result of strhash is truncated to 32 bits.
func MemHash(seed uint32, input string) uint32 {
	sum := strhash(unsafe.Pointer(&input), uintptr(seed)) // nosemgrep
	return uint32(sum)
}
const (
mphMatchTypeCount = 2 // Full and Domain
)
// mphRuleInfo is the build-time record kept for one pattern.
type mphRuleInfo struct {
rollingHash uint32 // RollingHash of the pattern, computed when the pattern is first added
matchers [mphMatchTypeCount][]uint32 // MatcherType -> values registered for the pattern
}
// MphMatcherGroup is an implementation of MatcherGroup.
// It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
type MphMatcherGroup struct {
rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup
values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence)
level0 []uint32 // RollingHash & Mask -> seed for Memhash
level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0)
level1 []uint32 // Memhash<seed> & Mask -> stored index for rules
level1Mask uint32 // Mask for restricting Memhash<seed> to 0 ~ len(level1)
ruleInfos *map[string]mphRuleInfo // pattern -> build-time info; only used while adding patterns, Build sets it to nil
}
// NewMphMatcherGroup returns a group with the reserved rule slot 0 in place
// and an empty build-time rule table. The hash levels and masks stay at their
// zero values.
func NewMphMatcherGroup() *MphMatcherGroup {
	group := new(MphMatcherGroup)
	group.rules = []string{""}      // RuleIdx 0 is reserved for failed lookups
	group.values = [][]uint32{nil} // value slot paired with the reserved rule
	// Only used for building, destroyed after build complete
	group.ruleInfos = &map[string]mphRuleInfo{}
	return group
}
// AddFullMatcher implements MatcherGroupForFull.
// The pattern is lower-cased and registered as a whole, with no suffix.
func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
	lowered := strings.ToLower(matcher.Pattern())
	g.addPattern(0, "", lowered, matcher.Type(), value)
}
// AddDomainMatcher implements MatcherGroupForDomain.
// The lower-cased pattern is registered twice: once as the bare pattern (input equals the
// domain) and once as "." + pattern (input is a sub-domain). The second registration reuses
// the rolling hash of the first as its suffix hash instead of re-hashing the pattern.
func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
	pattern := strings.ToLower(matcher.Pattern())
	hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
	g.addPattern(hash, pattern, ".", matcher.Type(), value)     // For partial domain match
}
// addPattern records value for the rule whose text is pattern+suffixPattern and returns that
// rule's rolling hash. suffixHash must be the rolling hash of suffixPattern, so that only
// pattern itself has to be hashed. A rule seen for the first time is appended to rules and
// gets an empty slot in values; a known rule only gains another value for matcherType.
func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
	key := pattern + suffixPattern
	registry := *g.ruleInfos
	entry, known := registry[key]
	if !known {
		entry.rollingHash = RollingHash(suffixHash, pattern)
		g.rules = append(g.rules, key)
		g.values = append(g.values, nil)
	}
	entry.matchers[matcherType] = append(entry.matchers[matcherType], value)
	registry[key] = entry
	return entry.rollingHash
}
// Build builds a minimal perfect hash table for insert rules.
// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
//
// Rules are grouped into buckets by rolling hash; for each bucket (largest first) a seed is
// searched such that MemHash<seed> places all of the bucket's rules into free level1 slots.
// The seed is stored in level0 and the rule indices in level1.
//
// Build dereferences ruleInfos and then sets it to nil, so it is meant to run exactly once,
// after all patterns were added. It always returns nil.
func (g *MphMatcherGroup) Build() error {
	ruleCount := len(*g.ruleInfos)
	g.level0 = make([]uint32, nextPow2(ruleCount/4))
	g.level0Mask = uint32(len(g.level0) - 1)
	g.level1 = make([]uint32, nextPow2(ruleCount))
	g.level1Mask = uint32(len(g.level1) - 1)
	// Create buckets based on all rule's rolling hash
	buckets := make([][]uint32, len(g.level0))
	for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
		ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
		bucketIdx := ruleInfo.rollingHash & g.level0Mask
		buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
		// Values of Full matchers are placed before values of Domain matchers for the same pattern.
		g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
	}
	g.ruleInfos = nil // Set ruleInfos nil to release memory
	runtime.GC()      // peak mem
	// Sort buckets in descending order with respect to each bucket's size
	bucketIdxs := make([]int, len(buckets))
	for bucketIdx := range buckets {
		bucketIdxs[bucketIdx] = bucketIdx
	}
	sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
	// Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
	occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
	hashedBucket := make([]uint32, 0, 4)    // Second-level hashes for each rule in a specific bucket
	for _, bucketIdx := range bucketIdxs {
		bucket := buckets[bucketIdx]
		hashedBucket = hashedBucket[:0]
		seed := uint32(0)
		// Retry with increasing seeds until every rule of the bucket has its own slot.
		for len(hashedBucket) != len(bucket) {
			for _, ruleIdx := range bucket {
				memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
				if occupied[memHash] { // Collision occurred with this seed
					for _, hash := range hashedBucket { // Revert all values in this hashed bucket
						occupied[hash] = false
						g.level1[hash] = 0
					}
					hashedBucket = hashedBucket[:0]
					seed++ // Try next seed
					break
				}
				occupied[memHash] = true
				g.level1[memHash] = ruleIdx // The final value in the hash table
				hashedBucket = append(hashedBucket, memHash)
			}
		}
		g.level0[bucketIdx] = seed // Displacement value for this bucket
	}
	return nil
}
// Lookup resolves input through the two-level hash table and returns its rule index, or 0
// when input is not a registered rule. rollingHash must be the rolling hash of input; it
// selects the seed in level0, and MemHash with that seed selects the candidate slot in level1.
// The candidate is accepted only if its stored pattern equals input.
func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
	seed := g.level0[rollingHash&g.level0Mask]
	slot := MemHash(seed, input) & g.level1Mask
	candidate := g.level1[slot]
	if g.rules[candidate] != input {
		return 0
	}
	return candidate
}
// Match implements MatcherGroup.Match.
// The input is scanned once from its end while the rolling hash is extended byte by byte.
// At every '.' the suffix starting at that dot is looked up, and finally the whole input.
// Hits are gathered from the shortest suffix to the full input and then combined in reverse,
// so values of the longest match come first.
func (g *MphMatcherGroup) Match(input string) []uint32 {
	found := make([][]uint32, 0, 5)
	var rolling uint32
	for pos := len(input) - 1; pos >= 0; pos-- {
		ch := input[pos]
		rolling = rolling*PrimeRK + uint32(ch)
		if ch != '.' {
			continue
		}
		if ruleIdx := g.Lookup(rolling, input[pos:]); ruleIdx != 0 {
			found = append(found, g.values[ruleIdx])
		}
	}
	if ruleIdx := g.Lookup(rolling, input); ruleIdx != 0 {
		found = append(found, g.values[ruleIdx])
	}
	return CompositeMatchesReverse(found)
}
// MatchAny implements MatcherGroup.MatchAny.
// It performs the same right-to-left scan as Match but stops at the first suffix (starting
// at a '.') that is a registered rule; if none is, the whole input is looked up.
func (g *MphMatcherGroup) MatchAny(input string) bool {
	var rolling uint32
	for pos := len(input) - 1; pos >= 0; pos-- {
		ch := input[pos]
		rolling = rolling*PrimeRK + uint32(ch)
		if ch == '.' && g.Lookup(rolling, input[pos:]) != 0 {
			return true
		}
	}
	return g.Lookup(rolling, input) != 0
}
// nextPow2 returns 1 for v <= 1, otherwise the smallest power of two strictly greater
// than v (an exact power of two is therefore doubled, e.g. 4 -> 8).
func nextPow2(v int) int {
	if v <= 1 {
		return 1
	}
	// bits.Len is the position of the highest set bit, so 1<<Len is the next power of two above v.
	return int(uint(1) << bits.Len(uint(v)))
}
// strhash is a body-less declaration bound to runtime.strhash through go:linkname.
// MemHash calls it with a pointer to a string and a seed.
// NOTE(review): this depends on a runtime-internal symbol — confirm it is still linkable
// with the Go toolchain versions the project supports.
//
//go:noescape
//go:linkname strhash runtime.strhash
func strhash(p unsafe.Pointer, h uintptr) uintptr

View File

@@ -5,94 +5,10 @@ import (
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/strmatcher"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
func TestMatcherGroup(t *testing.T) {
rules := []struct {
Type Type
Domain string
}{
{
Type: Regex,
Domain: "apis\\.us$",
},
{
Type: Substr,
Domain: "apis",
},
{
Type: Domain,
Domain: "googleapis.com",
},
{
Type: Domain,
Domain: "com",
},
{
Type: Full,
Domain: "www.baidu.com",
},
{
Type: Substr,
Domain: "apis",
},
{
Type: Domain,
Domain: "googleapis.com",
},
{
Type: Full,
Domain: "fonts.googleapis.com",
},
{
Type: Full,
Domain: "www.baidu.com",
},
{
Type: Domain,
Domain: "example.com",
},
}
cases := []struct {
Input string
Output []uint32
}{
{
Input: "www.baidu.com",
Output: []uint32{5, 9, 4},
},
{
Input: "fonts.googleapis.com",
Output: []uint32{8, 3, 7, 4, 2, 6},
},
{
Input: "example.googleapis.com",
Output: []uint32{3, 7, 4, 2, 6},
},
{
Input: "testapis.us",
Output: []uint32{1, 2, 6},
},
{
Input: "example.com",
Output: []uint32{10, 4},
},
}
matcherGroup := &MatcherGroup{}
for _, rule := range rules {
matcher, err := rule.Type.New(rule.Domain)
common.Must(err)
matcherGroup.Add(matcher)
}
for _, test := range cases {
if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
t.Error("unexpected output: ", m, " for test case ", test)
}
}
}
func TestACAutomaton(t *testing.T) {
func TestMphMatcherGroup(t *testing.T) {
cases1 := []struct {
pattern string
mType Type
@@ -100,53 +16,55 @@ func TestACAutomaton(t *testing.T) {
output bool
}{
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Domain,
input: "www.xtls.github.io",
input: "www.example.com",
output: true,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Domain,
input: "xtls.github.io",
input: "example.com",
output: true,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Domain,
input: "www.xtis.github.io",
input: "www.e3ample.com",
output: false,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Domain,
input: "tls.github.io",
input: "xample.com",
output: false,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Domain,
input: "xxtls.github.io",
input: "xexample.com",
output: false,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Full,
input: "xtls.github.io",
input: "example.com",
output: true,
},
{
pattern: "xtls.github.io",
pattern: "example.com",
mType: Full,
input: "xxtls.github.io",
input: "xexample.com",
output: false,
},
}
for _, test := range cases1 {
ac := NewACAutomaton()
ac.Add(test.pattern, test.mType)
ac.Build()
if m := ac.Match(test.input); m != test.output {
mph := NewMphMatcherGroup()
matcher, err := test.mType.New(test.pattern)
common.Must(err)
common.Must(AddMatcherToGroup(mph, matcher, 0))
mph.Build()
if m := mph.MatchAny(test.input); m != test.output {
t.Error("unexpected output: ", m, " for test case ", test)
}
}
@@ -167,20 +85,14 @@ func TestACAutomaton(t *testing.T) {
pattern: "3.com",
mType: Full,
},
{
pattern: "google.com",
mType: Substr,
},
{
pattern: "vgoogle.com",
mType: Substr,
},
}
ac := NewACAutomaton()
mph := NewMphMatcherGroup()
for _, test := range cases2Input {
ac.Add(test.pattern, test.mType)
matcher, err := test.mType.New(test.pattern)
common.Must(err)
common.Must(AddMatcherToGroup(mph, matcher, 0))
}
ac.Build()
mph.Build()
cases2Output := []struct {
pattern string
res bool
@@ -215,15 +127,11 @@ func TestACAutomaton(t *testing.T) {
},
{
pattern: "vvgoogle.com",
res: true,
},
{
pattern: "½",
res: false,
},
}
for _, test := range cases2Output {
if m := ac.Match(test.pattern); m != test.res {
if m := mph.MatchAny(test.pattern); m != test.res {
t.Error("unexpected output: ", m, " for test case ", test)
}
}
@@ -242,11 +150,13 @@ func TestACAutomaton(t *testing.T) {
mType: Domain,
},
}
ac := NewACAutomaton()
mph := NewMphMatcherGroup()
for _, test := range cases3Input {
ac.Add(test.pattern, test.mType)
matcher, err := test.mType.New(test.pattern)
common.Must(err)
common.Must(AddMatcherToGroup(mph, matcher, 0))
}
ac.Build()
mph.Build()
cases3Output := []struct {
pattern string
res bool
@@ -257,9 +167,112 @@ func TestACAutomaton(t *testing.T) {
},
}
for _, test := range cases3Output {
if m := ac.Match(test.pattern); m != test.res {
if m := mph.MatchAny(test.pattern); m != test.res {
t.Error("unexpected output: ", m, " for test case ", test)
}
}
}
}
// TestMphMatcherGroupAsIndexMatcher checks the order of values returned by MphMatcherGroup.Match.
// Each rule is registered with value i+3 (its position in rules plus 3), which keeps the
// expected outputs aligned with the rule list that also contained a Regex and a Substr rule.
// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
func TestMphMatcherGroupAsIndexMatcher(t *testing.T) {
	rules := []struct {
		Type   Type
		Domain string
	}{
		// Regex not supported by MphMatcherGroup
		// {
		// 	Type:   Regex,
		// 	Domain: "apis\\.us$",
		// },
		// Substr not supported by MphMatcherGroup
		// {
		// 	Type:   Substr,
		// 	Domain: "apis",
		// },
		{
			Type:   Domain,
			Domain: "googleapis.com",
		},
		{
			Type:   Domain,
			Domain: "com",
		},
		{
			Type:   Full,
			Domain: "www.baidu.com",
		},
		// Substr not supported by MphMatcherGroup, We add another matcher to preserve index
		{
			Type:   Domain,        // Substr,
			Domain: "example.com", // "apis",
		},
		{
			Type:   Domain,
			Domain: "googleapis.com",
		},
		{
			Type:   Full,
			Domain: "fonts.googleapis.com",
		},
		{
			Type:   Full,
			Domain: "www.baidu.com",
		},
		{ // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority.
			Type:   Full,
			Domain: "example.com",
		},
		{
			Type:   Domain,
			Domain: "example.com",
		},
	}
	cases := []struct {
		Input  string
		Output []uint32
	}{
		{
			Input:  "www.baidu.com",
			Output: []uint32{5, 9, 4},
		},
		{
			Input:  "fonts.googleapis.com",
			Output: []uint32{8, 3, 7, 4 /*2, 6*/},
		},
		{
			Input:  "example.googleapis.com",
			Output: []uint32{3, 7, 4 /*2, 6*/},
		},
		{
			Input: "testapis.us",
			// Output: []uint32{ /*2, 6*/ /*1,*/ },
			Output: nil,
		},
		{
			// Full value 10 precedes Domain values 6 and 11 of the same pattern; "com" (4) comes last.
			Input:  "example.com",
			Output: []uint32{10, 6, 11, 4},
		},
	}
	matcherGroup := NewMphMatcherGroup()
	for i, rule := range rules {
		matcher, err := rule.Type.New(rule.Domain)
		common.Must(err)
		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3)))
	}
	matcherGroup.Build()
	for _, test := range cases {
		if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
}
// TestEmptyMphMatcherGroup verifies that a group built without any pattern
// can be queried and reports no match.
func TestEmptyMphMatcherGroup(t *testing.T) {
	g := NewMphMatcherGroup()
	g.Build()
	r := g.Match("example.com")
	if len(r) != 0 {
		t.Error("Expect [], but ", r)
	}
}

View File

@@ -0,0 +1,41 @@
package strmatcher
// matcherEntry pairs a Matcher with the value reported when that matcher matches.
type matcherEntry struct {
	matcher Matcher
	value   uint32
}
// SimpleMatcherGroup is an implementation of MatcherGroup.
// It simply stores all matchers in an array and sequentially matches them.
// The zero value is an empty group that is ready to use.
type SimpleMatcherGroup struct {
	matchers []matcherEntry // Entries in insertion order
}
// AddMatcher implements MatcherGroupForAll.AddMatcher.
// The matcher is appended after all previously added ones, bound to value.
func (g *SimpleMatcherGroup) AddMatcher(matcher Matcher, value uint32) {
	entry := matcherEntry{matcher: matcher, value: value}
	g.matchers = append(g.matchers, entry)
}
// Match implements MatcherGroup.Match.
// Every stored matcher is tried in insertion order and the values of those that match are
// returned in that order. The result is an empty, non-nil slice when nothing matches.
func (g *SimpleMatcherGroup) Match(input string) []uint32 {
	hits := make([]uint32, 0)
	for i := range g.matchers {
		if entry := g.matchers[i]; entry.matcher.Match(input) {
			hits = append(hits, entry.value)
		}
	}
	return hits
}
// MatchAny implements MatcherGroup.MatchAny.
// Matchers are tried in insertion order; the scan stops at the first one that matches.
func (g *SimpleMatcherGroup) MatchAny(input string) bool {
	for i := range g.matchers {
		if g.matchers[i].matcher.Match(input) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,69 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestSimpleMatcherGroup registers a Domain, a Full and a Regex matcher for the same pattern
// (bound to values 0, 1 and 2, their positions in patterns) and checks that Match returns the
// values of the matching matchers in insertion order, and an empty slice when none matches.
func TestSimpleMatcherGroup(t *testing.T) {
	patterns := []struct {
		pattern string
		mType   Type
	}{
		{
			pattern: "example.com",
			mType:   Domain,
		},
		{
			pattern: "example.com",
			mType:   Full,
		},
		{
			pattern: "example.com",
			mType:   Regex,
		},
	}
	cases := []struct {
		input  string
		output []uint32
	}{
		{
			input:  "www.example.com",
			output: []uint32{0, 2},
		},
		{
			input:  "example.com",
			output: []uint32{0, 1, 2},
		},
		{
			input:  "www.e3ample.com",
			output: []uint32{},
		},
		{
			input:  "xample.com",
			output: []uint32{},
		},
		{
			input:  "xexample.com",
			output: []uint32{2},
		},
		{
			// Only the regex matches: its unescaped '.' accepts the 'x'.
			input:  "examplexcom",
			output: []uint32{2},
		},
	}
	matcherGroup := &SimpleMatcherGroup{}
	for id, entry := range patterns {
		matcher, err := entry.mType.New(entry.pattern)
		common.Must(err)
		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id)))
	}
	for _, test := range cases {
		if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) {
			t.Error("unexpected output: ", r, " for test case ", test)
		}
	}
}

View File

@@ -0,0 +1,61 @@
package strmatcher
import (
"sort"
"strings"
)
// SubstrMatcherGroup is an implementation of MatcherGroup.
// It is simply implemented to comply with the priority specification of Substr matchers.
// patterns and values are parallel slices: values[i] is the value bound to patterns[i].
type SubstrMatcherGroup struct {
	patterns []string
	values   []uint32
}
// AddSubstrMatcher implements MatcherGroupForSubstr.AddSubstrMatcher.
// The matcher's pattern and its value are appended at the same index of the two parallel slices.
func (g *SubstrMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, value uint32) {
	g.patterns = append(g.patterns, matcher.Pattern())
	g.values = append(g.values, value)
}
// Match implements MatcherGroup.Match.
// It returns one value per non-overlapping occurrence of every pattern in input, ordered by:
//  1. Position of the occurrence: a pattern matched at a smaller position takes precedence.
//  2. Insertion order: for occurrences at the same position, the earlier added pattern takes precedence.
//
// nil is returned when no pattern occurs in input.
//
// Each occurrence is encoded as a 64-bit key (position in the upper 32 bits, pattern index
// in the lower 32 bits), so neither a long input nor a group holding more than 65536
// patterns can corrupt the ordering or select the wrong value.
// An empty pattern is reported once, at position 0: it occurs everywhere, and searching
// backwards for it would never terminate.
func (g *SubstrMatcherGroup) Match(input string) []uint32 {
	var hits []uint64
	for i, pattern := range g.patterns {
		idx := uint64(uint32(i))
		if pattern == "" {
			hits = append(hits, idx) // position 0
			continue
		}
		for j := strings.LastIndex(input, pattern); j != -1; j = strings.LastIndex(input[:j], pattern) {
			hits = append(hits, uint64(j)<<32|idx)
		}
	}
	// sort.Slice will trigger allocation no matter what input is. See https://github.com/golang/go/issues/17332
	// The common small cases are therefore ordered by hand.
	switch len(hits) {
	case 0:
		return nil
	case 1:
		// No need to sort
	case 2:
		// Do a simple swap if unsorted
		if hits[0] > hits[1] {
			hits[0], hits[1] = hits[1], hits[0]
		}
	default:
		// Ascending key order is exactly (position, pattern index) dictionary order.
		sort.Slice(hits, func(a, b int) bool { return hits[a] < hits[b] })
	}
	result := make([]uint32, len(hits))
	for i, hit := range hits {
		result[i] = g.values[uint32(hit)] // Lower 32 bits hold the pattern index
	}
	return result
}
// MatchAny implements MatcherGroup.MatchAny.
// It reports whether at least one stored pattern occurs somewhere in input.
func (g *SubstrMatcherGroup) MatchAny(input string) bool {
	for i := range g.patterns {
		if strings.Contains(input, g.patterns[i]) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,65 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestSubstrMatcherGroup checks the ordering of SubstrMatcherGroup.Match: occurrences at a
// smaller position come first, and for the same position the earlier added pattern comes
// first. Each pattern is bound to its position in patterns, and a pattern occurring several
// times in the input is reported once per occurrence.
func TestSubstrMatcherGroup(t *testing.T) {
	patterns := []struct {
		pattern string
		mType   Type
	}{
		{
			pattern: "apis",
			mType:   Substr,
		},
		{
			pattern: "google",
			mType:   Substr,
		},
		{
			pattern: "apis",
			mType:   Substr,
		},
	}
	cases := []struct {
		input  string
		output []uint32
	}{
		{
			input:  "google.com",
			output: []uint32{1},
		},
		{
			input:  "apis.com",
			output: []uint32{0, 2},
		},
		{
			input:  "googleapis.com",
			output: []uint32{1, 0, 2},
		},
		{
			input:  "fonts.googleapis.com",
			output: []uint32{1, 0, 2},
		},
		{
			// "apis" occurs twice, so values 0 and 2 are reported for both occurrences.
			input:  "apis.googleapis.com",
			output: []uint32{0, 2, 1, 0, 2},
		},
	}
	matcherGroup := &SubstrMatcherGroup{}
	for id, entry := range patterns {
		matcher, err := entry.mType.New(entry.pattern)
		common.Must(err)
		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id)))
	}
	for _, test := range cases {
		if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) {
			t.Error("unexpected output: ", r, " for test case ", test)
		}
	}
}

View File

@@ -0,0 +1,290 @@
package strmatcher
import (
"errors"
"regexp"
"strings"
"unicode/utf8"
"golang.org/x/net/idna"
)
// FullMatcher is a Matcher accepting only inputs that equal its pattern exactly.
type FullMatcher string

// Type reports Full.
func (FullMatcher) Type() Type { return Full }

// Pattern returns the pattern the matcher was created with.
func (m FullMatcher) Pattern() string { return string(m) }

// String renders the matcher as "full:" followed by its pattern.
func (m FullMatcher) String() string { return "full:" + string(m) }

// Match reports whether s is byte-for-byte equal to the pattern.
func (m FullMatcher) Match(s string) bool { return s == m.Pattern() }
// DomainMatcher is a Matcher accepting its pattern itself and any sub-domain of it.
type DomainMatcher string

// Type reports Domain.
func (DomainMatcher) Type() Type { return Domain }

// Pattern returns the pattern the matcher was created with.
func (m DomainMatcher) Pattern() string { return string(m) }

// String renders the matcher as "domain:" followed by its pattern.
func (m DomainMatcher) String() string { return "domain:" + string(m) }

// Match reports whether s ends with the pattern and the pattern starts at a label
// boundary, i.e. it is the whole of s or is directly preceded by a '.'.
func (m DomainMatcher) Match(s string) bool {
	domain := string(m)
	if !strings.HasSuffix(s, domain) {
		return false
	}
	boundary := len(s) - len(domain)
	return boundary == 0 || s[boundary-1] == '.'
}
// SubstrMatcher is a Matcher accepting every input that contains its pattern.
type SubstrMatcher string

// Type reports Substr.
func (SubstrMatcher) Type() Type { return Substr }

// Pattern returns the pattern the matcher was created with.
func (m SubstrMatcher) Pattern() string { return string(m) }

// String renders the matcher as "keyword:" followed by its pattern.
func (m SubstrMatcher) String() string { return "keyword:" + string(m) }

// Match reports whether the pattern occurs anywhere in s.
func (m SubstrMatcher) Match(s string) bool { return strings.Contains(s, string(m)) }
// RegexMatcher is a Matcher backed by a compiled regular expression.
type RegexMatcher struct {
	pattern *regexp.Regexp
}

// Type reports Regex.
func (*RegexMatcher) Type() Type { return Regex }

// Pattern returns the source text of the compiled expression.
func (m *RegexMatcher) Pattern() string { return m.pattern.String() }

// String renders the matcher as "regexp:" followed by the expression source.
func (m *RegexMatcher) String() string { return "regexp:" + m.pattern.String() }

// Match reports whether the expression matches somewhere in s.
func (m *RegexMatcher) Match(s string) bool { return m.pattern.MatchString(s) }
// New creates a new Matcher of type t for the given pattern.
// Full and Substr patterns are taken verbatim, Domain patterns are normalized through
// ToDomain, and Regex patterns are compiled as given (so matching is case-sensitive unless
// the expression says otherwise). An error is returned when normalization or compilation
// fails, or when t is not a known type.
func (t Type) New(pattern string) (Matcher, error) {
	switch t {
	case Full:
		return FullMatcher(pattern), nil
	case Substr:
		return SubstrMatcher(pattern), nil
	case Domain:
		domain, err := ToDomain(pattern)
		if err != nil {
			return nil, err
		}
		return DomainMatcher(domain), nil
	case Regex:
		compiled, err := regexp.Compile(pattern)
		if err != nil {
			return nil, err
		}
		return &RegexMatcher{pattern: compiled}, nil
	}
	return nil, errors.New("unknown matcher type")
}
// NewDomainPattern creates a new Matcher of type t for a pattern that is meant to be a domain.
// Unlike New, the Full, Substr and Domain patterns all go through ToDomain first, so an
// invalid domain pattern is rejected with an error. Regex patterns are only compiled,
// because the regex charset is not in the LDH subset.
func (t Type) NewDomainPattern(pattern string) (Matcher, error) {
	switch t {
	case Full, Substr, Domain:
		domain, err := ToDomain(pattern)
		if err != nil {
			return nil, err
		}
		switch t {
		case Full:
			return FullMatcher(domain), nil
		case Substr:
			return SubstrMatcher(domain), nil
		default:
			return DomainMatcher(domain), nil
		}
	case Regex:
		compiled, err := regexp.Compile(pattern)
		if err != nil {
			return nil, err
		}
		return &RegexMatcher{pattern: compiled}, nil
	}
	return nil, errors.New("unknown matcher type")
}
// ToDomain converts input pattern to a domain string, and return error if such a conversion cannot be made.
//  1. Conforms to Letter-Digit-Hyphen (LDH) subset (https://tools.ietf.org/html/rfc952):
//     * Letters A to Z (no distinction between uppercase and lowercase, we convert to lowers)
//     * Digits 0 to 9
//     * Hyphens(-) and Periods(.)
//  2. If any non-ASCII characters, domain are converted from Internationalized domain name to Punycode.
//
// A pattern that is already lower-case ASCII is returned unchanged.
func ToDomain(pattern string) (string, error) {
	for {
		isASCII, hasUpper := true, false
		for i := 0; i < len(pattern); i++ {
			c := pattern[i]
			if c >= utf8.RuneSelf {
				// Non-ASCII byte: stop validating here, the pattern is converted below
				// and then validated again from the start.
				isASCII = false
				break
			}
			switch {
			case 'A' <= c && c <= 'Z':
				hasUpper = true
			case 'a' <= c && c <= 'z':
			case '0' <= c && c <= '9':
			case c == '-':
			case c == '.':
			default:
				return "", errors.New("pattern string does not conform to Letter-Digit-Hyphen (LDH) subset")
			}
		}
		if !isASCII {
			var err error
			pattern, err = idna.Punycode.ToASCII(pattern)
			if err != nil {
				return "", err
			}
			continue // Re-run the LDH check on the converted pattern
		}
		if hasUpper {
			pattern = strings.ToLower(pattern)
		}
		break
	}
	return pattern, nil
}
// MatcherGroupForAll is an interface indicating a MatcherGroup could accept all types of matchers.
type MatcherGroupForAll interface {
	// AddMatcher adds a matcher of any type, bound to value.
	AddMatcher(matcher Matcher, value uint32)
}

// MatcherGroupForFull is an interface indicating a MatcherGroup could accept FullMatchers.
type MatcherGroupForFull interface {
	// AddFullMatcher adds a FullMatcher, bound to value.
	AddFullMatcher(matcher FullMatcher, value uint32)
}

// MatcherGroupForDomain is an interface indicating a MatcherGroup could accept DomainMatchers.
type MatcherGroupForDomain interface {
	// AddDomainMatcher adds a DomainMatcher, bound to value.
	AddDomainMatcher(matcher DomainMatcher, value uint32)
}

// MatcherGroupForSubstr is an interface indicating a MatcherGroup could accept SubstrMatchers.
type MatcherGroupForSubstr interface {
	// AddSubstrMatcher adds a SubstrMatcher, bound to value.
	AddSubstrMatcher(matcher SubstrMatcher, value uint32)
}

// MatcherGroupForRegex is an interface indicating a MatcherGroup could accept RegexMatchers.
type MatcherGroupForRegex interface {
	// AddRegexMatcher adds a RegexMatcher, bound to value.
	AddRegexMatcher(matcher *RegexMatcher, value uint32)
}
// AddMatcherToGroup is a helper function to try to add a Matcher to any kind of MatcherGroup.
// The group's capabilities are probed in this order: IndexMatcher (value is ignored, the
// group assigns its own index), MatcherGroupForAll, and finally the type-specific
// MatcherGroupFor* interface that corresponds to the matcher's concrete type.
// It returns error if the MatcherGroup does not accept the provided Matcher's type.
// This function is provided to help writing code to test a MatcherGroup.
func AddMatcherToGroup(g MatcherGroup, matcher Matcher, value uint32) error {
	if indexed, ok := g.(IndexMatcher); ok {
		indexed.Add(matcher)
		return nil
	}
	if generic, ok := g.(MatcherGroupForAll); ok {
		generic.AddMatcher(matcher, value)
		return nil
	}
	switch m := matcher.(type) {
	case FullMatcher:
		if group, ok := g.(MatcherGroupForFull); ok {
			group.AddFullMatcher(m, value)
			return nil
		}
	case DomainMatcher:
		if group, ok := g.(MatcherGroupForDomain); ok {
			group.AddDomainMatcher(m, value)
			return nil
		}
	case SubstrMatcher:
		if group, ok := g.(MatcherGroupForSubstr); ok {
			group.AddSubstrMatcher(m, value)
			return nil
		}
	case *RegexMatcher:
		if group, ok := g.(MatcherGroupForRegex); ok {
			group.AddRegexMatcher(m, value)
			return nil
		}
	}
	return errors.New("cannot add matcher to matcher group")
}
// CompositeMatches flattens the matches slice to produce a single matched indices slice,
// keeping both the order of the inner slices and the order of the values inside them.
// It is designed to avoid new memory allocation as possible:
//   - no inner slice: nil is returned;
//   - one inner slice: that slice itself is returned (not a copy);
//   - otherwise: one result slice is allocated, sized exactly for all values.
func CompositeMatches(matches [][]uint32) []uint32 {
	switch len(matches) {
	case 0:
		return nil
	case 1:
		return matches[0]
	default:
		total := 0
		for _, m := range matches {
			total += len(m)
		}
		// Exact capacity: a single allocation regardless of how many values are merged.
		result := make([]uint32, 0, total)
		for _, m := range matches {
			result = append(result, m...)
		}
		return result
	}
}
// CompositeMatchesReverse flattens the matches slice to produce a single matched indices slice.
// It is designed that:
//  1. All matchers are concatenated in reverse order, so the matcher that matches further ranks higher.
//  2. Indices in the same matcher keeps their original order.
//  3. Avoid new memory allocation as possible: nil is returned for no inner slice, the inner
//     slice itself for exactly one, and otherwise a single result sized exactly for all values.
func CompositeMatchesReverse(matches [][]uint32) []uint32 {
	switch len(matches) {
	case 0:
		return nil
	case 1:
		return matches[0]
	default:
		total := 0
		for _, m := range matches {
			total += len(m)
		}
		// Exact capacity: a single allocation regardless of how many values are merged.
		result := make([]uint32, 0, total)
		for i := len(matches) - 1; i >= 0; i-- {
			result = append(result, matches[i]...)
		}
		return result
	}
}

View File

@@ -0,0 +1,149 @@
package strmatcher_test
import (
"reflect"
"testing"
"unsafe"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/geodata/strmatcher"
)
// TestMatcher checks the Match method of single matchers created through Type.New for the
// Domain, Full and Regex types.
func TestMatcher(t *testing.T) {
	cases := []struct {
		pattern string
		mType   Type
		input   string
		output  bool
	}{
		{
			pattern: "example.com",
			mType:   Domain,
			input:   "www.example.com",
			output:  true,
		},
		{
			pattern: "example.com",
			mType:   Domain,
			input:   "example.com",
			output:  true,
		},
		{
			pattern: "example.com",
			mType:   Domain,
			input:   "www.e3ample.com",
			output:  false,
		},
		{
			pattern: "example.com",
			mType:   Domain,
			input:   "xample.com",
			output:  false,
		},
		{
			// The pattern is a suffix of the input but does not start at a label boundary.
			pattern: "example.com",
			mType:   Domain,
			input:   "xexample.com",
			output:  false,
		},
		{
			pattern: "example.com",
			mType:   Full,
			input:   "example.com",
			output:  true,
		},
		{
			pattern: "example.com",
			mType:   Full,
			input:   "xexample.com",
			output:  false,
		},
		{
			// The unescaped '.' of the regex accepts the 'x'.
			pattern: "example.com",
			mType:   Regex,
			input:   "examplexcom",
			output:  true,
		},
	}
	for _, test := range cases {
		matcher, err := test.mType.New(test.pattern)
		common.Must(err)
		if m := matcher.Match(test.input); m != test.output {
			t.Error("unexpected output: ", m, " for test case ", test)
		}
	}
}
// TestToDomain checks ToDomain on plain ASCII, mixed-case, internationalized and invalid inputs.
func TestToDomain(t *testing.T) {
	{ // Test normal ASCII domain, which should not trigger new string data allocation
		input := "example.com"
		domain, err := ToDomain(input)
		if err != nil {
			t.Error("unexpected error: ", err)
		}
		if domain != input {
			t.Error("unexpected output: ", domain, " for test case ", input)
		}
		// Compare the data pointers of both strings: the output must share the input's storage.
		if (*reflect.StringHeader)(unsafe.Pointer(&input)).Data != (*reflect.StringHeader)(unsafe.Pointer(&domain)).Data {
			t.Error("different string data of output: ", domain, " and test case ", input)
		}
	}
	{ // Test ASCII domain containing upper case letter, which should be converted to lower case
		input := "eXAMPLE.cOm"
		domain, err := ToDomain(input)
		if err != nil {
			t.Error("unexpected error: ", err)
		}
		if domain != "example.com" {
			t.Error("unexpected output: ", domain, " for test case ", input)
		}
	}
	{ // Test internationalized domain, which should be translated to ASCII punycode
		input := "example.公益"
		domain, err := ToDomain(input)
		if err != nil {
			t.Error("unexpected error: ", err)
		}
		if domain != "example.xn--55qw42g" {
			t.Error("unexpected output: ", domain, " for test case ", input)
		}
	}
	{ // Test internationalized domain containing upper case letter
		input := "eXAMPLE.公益"
		domain, err := ToDomain(input)
		if err != nil {
			t.Error("unexpected error: ", err)
		}
		if domain != "example.xn--55qw42g" {
			t.Error("unexpected output: ", domain, " for test case ", input)
		}
	}
	{ // Test domain name of invalid character, which should return with error
		input := "{"
		_, err := ToDomain(input)
		if err == nil {
			t.Error("unexpected non error for test case ", input)
		}
	}
	{ // Test domain name containing a space, which should return with error
		input := "Mijia Cloud"
		_, err := ToDomain(input)
		if err == nil {
			t.Error("unexpected non error for test case ", input)
		}
	}
	{ // Test domain name containing an underscore, which should return with error
		input := "Mijia_Cloud.com"
		_, err := ToDomain(input)
		if err == nil {
			t.Error("unexpected non error for test case ", input)
		}
	}
	{ // Test internationalized domain containing invalid character
		input := "Mijia Cloud.公司"
		_, err := ToDomain(input)
		if err == nil {
			t.Error("unexpected non error for test case ", input)
		}
	}
}

View File

@@ -0,0 +1,101 @@
package strmatcher
// Type is the type of the matcher.
type Type byte

const (
	// Full is the type of matcher that the input string must exactly equal to the pattern.
	Full Type = 0
	// Domain is the type of matcher that the input string must be a sub-domain or itself of the pattern.
	Domain Type = 1
	// Substr is the type of matcher that the input string must contain the pattern as a sub-string.
	Substr Type = 2
	// Regex is the type of matcher that the input string must match the regular-expression pattern.
	Regex Type = 3
)
// Matcher is the interface to determine whether a string matches a pattern.
//   - This is a basic matcher to represent a certain kind of match semantic (full, substr, domain or regex).
type Matcher interface {
	// Type returns the matcher's type.
	Type() Type
	// Pattern returns the matcher's raw string representation.
	Pattern() string
	// String returns a string representation of the matcher containing its type and pattern.
	String() string
	// Match returns true if the given string matches a predefined pattern.
	//   * This method is seldom used for performance reason
	//     and is generally taken over by their corresponding MatcherGroup.
	Match(input string) bool
}
// MatcherGroup is an advanced type of matcher to accept a bunch of basic Matchers (of certain type, not all matcher types).
// For example:
//   - FullMatcherGroup accepts FullMatcher and uses a hash table to facilitate lookup.
//   - DomainMatcherGroup accepts DomainMatcher and uses a trie to optimize both memory consumption and lookup speed.
//
// Matchers are added through the MatcherGroupFor* interfaces a concrete group implements.
type MatcherGroup interface {
	// Match returns all matched matchers with their corresponding values.
	Match(input string) []uint32
	// MatchAny returns true as soon as one matching matcher is found.
	MatchAny(input string) bool
}
// IndexMatcher is a general type of matcher that accepts all kinds of basic matchers.
// It should:
//   - Accept all Matcher types with no exception.
//   - Optimize string matching with a combination of MatcherGroups.
//   - Obey certain priority order specification when returning matched Matchers.
type IndexMatcher interface {
	// Size returns number of matchers added to IndexMatcher.
	Size() uint32
	// Add adds a new Matcher to IndexMatcher, and returns its index. The index will never be 0.
	Add(matcher Matcher) uint32
	// Build builds the IndexMatcher to be ready for matching.
	Build() error
	// Match returns the indices of all matchers that match the input.
	//   * Empty array is returned if no such matcher exists.
	//   * The order of returned matchers should follow priority specification.
	// Priority specification:
	//   1. Priority between matcher types: full > domain > substr > regex.
	//   2. Priority of same-priority matchers matching at same position: the early added takes precedence.
	//   3. Priority of domain matchers matching at different levels: the further matched domain takes precedence.
	//   4. Priority of substr matchers matching at different positions: the further matched substr takes precedence.
	Match(input string) []uint32
	// MatchAny returns true as soon as one matching matcher is found.
	MatchAny(input string) bool
}
// ValueMatcher is a general type of matcher that accepts all kinds of basic matchers.
// Unlike IndexMatcher, the caller chooses the value that is reported for each matcher.
// It should:
//   - Accept all Matcher types with no exception.
//   - Optimize string matching with a combination of MatcherGroups.
//   - Obey certain priority order specification when returning matched values.
type ValueMatcher interface {
	// Add adds a new Matcher to ValueMatcher, binding it to the provided value.
	Add(matcher Matcher, value uint32)
	// Build builds the ValueMatcher to be ready for matching.
	Build() error
	// Match returns the values of all matchers that match the input.
	//   * Empty array is returned if no such matcher exists.
	//   * The order of returned values should follow priority specification.
	//   * Same value may appear multiple times if multiple matched matchers were added with that value.
	// Priority specification:
	//   1. Priority between matcher types: full > domain > substr > regex.
	//   2. Priority of same-priority matchers matching at same position: the early added takes precedence.
	//   3. Priority of domain matchers matching at different levels: the further matched domain takes precedence.
	//   4. Priority of substr matchers matching at different positions: the further matched substr takes precedence.
	Match(input string) []uint32
	// MatchAny returns true as soon as one matching matcher is found.
	MatchAny(input string) bool
}

View File

@@ -0,0 +1,85 @@
package strmatcher
// LinearValueMatcher is an implementation of ValueMatcher.
// It keeps one matcher group per matcher type; each group is created lazily by Add, so a
// nil field means no matcher of that type was added.
type LinearValueMatcher struct {
	full   *FullMatcherGroup
	domain *DomainMatcherGroup
	substr *SubstrMatcherGroup
	regex  *SimpleMatcherGroup // Also receives every matcher that is not Full, Domain or Substr
}
// NewLinearValueMatcher returns an empty LinearValueMatcher; its groups are created on demand by Add.
func NewLinearValueMatcher() *LinearValueMatcher {
	return &LinearValueMatcher{}
}
// Add implements ValueMatcher.Add.
// The matcher is handed to the group responsible for its concrete type, creating that
// group on first use. Anything that is not a Full, Domain or Substr matcher ends up in
// the sequential group.
func (g *LinearValueMatcher) Add(matcher Matcher, value uint32) {
	switch m := matcher.(type) {
	case FullMatcher:
		if g.full == nil {
			g.full = NewFullMatcherGroup()
		}
		g.full.AddFullMatcher(m, value)
	case DomainMatcher:
		if g.domain == nil {
			g.domain = NewDomainMatcherGroup()
		}
		g.domain.AddDomainMatcher(m, value)
	case SubstrMatcher:
		if g.substr == nil {
			g.substr = &SubstrMatcherGroup{}
		}
		g.substr.AddSubstrMatcher(m, value)
	default:
		if g.regex == nil {
			g.regex = &SimpleMatcherGroup{}
		}
		g.regex.AddMatcher(matcher, value)
	}
}
// Build implements ValueMatcher.Build.
// It does nothing and always returns nil.
func (*LinearValueMatcher) Build() error {
	return nil
}
// Match implements ValueMatcher.Match.
// The existing groups are queried in priority order (full, domain, substr, regex); the
// non-empty results are concatenated in that order by CompositeMatches.
func (g *LinearValueMatcher) Match(input string) []uint32 {
	// Allocate capacity to prevent matches escaping to heap
	parts := make([][]uint32, 0, 5)
	if g.full != nil {
		found := g.full.Match(input)
		if len(found) > 0 {
			parts = append(parts, found)
		}
	}
	if g.domain != nil {
		found := g.domain.Match(input)
		if len(found) > 0 {
			parts = append(parts, found)
		}
	}
	if g.substr != nil {
		found := g.substr.Match(input)
		if len(found) > 0 {
			parts = append(parts, found)
		}
	}
	if g.regex != nil {
		found := g.regex.Match(input)
		if len(found) > 0 {
			parts = append(parts, found)
		}
	}
	return CompositeMatches(parts)
}
// MatchAny implements ValueMatcher.MatchAny.
//
// The groups are probed in the order full, domain, substr, regex. The first
// group that reports a match ends the search; nil groups are skipped.
func (g *LinearValueMatcher) MatchAny(input string) bool {
	switch {
	case g.full != nil && g.full.MatchAny(input):
		return true
	case g.domain != nil && g.domain.MatchAny(input):
		return true
	case g.substr != nil && g.substr.MatchAny(input):
		return true
	case g.regex != nil && g.regex.MatchAny(input):
		return true
	}
	return false
}

View File

@@ -0,0 +1,89 @@
package strmatcher
import "runtime"
// A MphValueMatcher is divided into three parts:
// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table;
// 2. `substr` patterns are matched by ac automaton;
// 3. `regex` patterns are matched with the regex library.
//
// Each part is nil until Add registers the first matcher that belongs to it.
type MphValueMatcher struct {
mph *MphMatcherGroup // receives FullMatcher and DomainMatcher entries
ac *ACAutomatonMatcherGroup // receives SubstrMatcher entries
regex *SimpleMatcherGroup // receives *RegexMatcher entries
}
// NewMphValueMatcher returns an empty MphValueMatcher. Its matcher groups are
// nil until Add populates them.
func NewMphValueMatcher() *MphValueMatcher {
	return &MphValueMatcher{}
}
// Add implements ValueMatcher.Add.
//
// The matcher is stored, together with value, in the group that corresponds
// to its concrete type; groups are allocated on first use:
//   - FullMatcher and DomainMatcher go to the MphMatcherGroup;
//   - SubstrMatcher goes to the ACAutomatonMatcherGroup;
//   - every other Matcher implementation (including *RegexMatcher) goes to
//     the SimpleMatcherGroup, which is the same fallback LinearValueMatcher.Add
//     uses. Without this fallback such a matcher would be dropped silently
//     and its rule would never match.
//
// A nil matcher is ignored.
func (g *MphValueMatcher) Add(matcher Matcher, value uint32) {
	switch matcher := matcher.(type) {
	case nil:
		// Nothing to register: a nil matcher cannot match any input.
	case FullMatcher:
		if g.mph == nil {
			g.mph = NewMphMatcherGroup()
		}
		g.mph.AddFullMatcher(matcher, value)
	case DomainMatcher:
		if g.mph == nil {
			g.mph = NewMphMatcherGroup()
		}
		g.mph.AddDomainMatcher(matcher, value)
	case SubstrMatcher:
		if g.ac == nil {
			g.ac = NewACAutomatonMatcherGroup()
		}
		g.ac.AddSubstrMatcher(matcher, value)
	default:
		// *RegexMatcher and any other Matcher implementation are matched
		// one by one through the generic group.
		if g.regex == nil {
			g.regex = &SimpleMatcherGroup{}
		}
		g.regex.AddMatcher(matcher, value)
	}
}
// Build implements ValueMatcher.Build.
//
// It builds the mph group and then the ac group, when present, and forces a
// garbage collection around each step to keep peak memory down. It always
// returns nil.
// NOTE(review): whatever the group Build calls return, if anything, is not
// inspected here - confirm against their signatures.
func (g *MphValueMatcher) Build() error {
	if mph := g.mph; mph != nil {
		runtime.GC() // peak mem
		mph.Build()
	}
	runtime.GC() // peak mem
	if ac := g.ac; ac != nil {
		ac.Build()
		runtime.GC() // peak mem
	}
	return nil
}
// Match implements ValueMatcher.Match.
//
// The groups are queried in the order mph, ac, regex. Every non-empty result
// is collected in that order and the collection is handed to CompositeMatches
// to produce the returned slice.
func (g *MphValueMatcher) Match(input string) []uint32 {
	collected := make([][]uint32, 0, 5)

	if g.mph != nil {
		hits := g.mph.Match(input)
		if len(hits) != 0 {
			collected = append(collected, hits)
		}
	}
	if g.ac != nil {
		hits := g.ac.Match(input)
		if len(hits) != 0 {
			collected = append(collected, hits)
		}
	}
	if g.regex != nil {
		hits := g.regex.Match(input)
		if len(hits) != 0 {
			collected = append(collected, hits)
		}
	}

	return CompositeMatches(collected)
}
// MatchAny implements ValueMatcher.MatchAny.
//
// The groups are probed in the order mph, ac, regex. The first group that
// reports a match ends the search; nil groups are skipped.
func (g *MphValueMatcher) MatchAny(input string) bool {
	switch {
	case g.mph != nil && g.mph.MatchAny(input):
		return true
	case g.ac != nil && g.ac.MatchAny(input):
		return true
	case g.regex != nil && g.regex.MatchAny(input):
		return true
	}
	return false
}

View File

@@ -24,8 +24,6 @@ const (
XUDPBaseKey = "xray.xudp.basekey"
TunFdKey = "xray.tun.fd"
MphCachePath = "xray.mph.cache"
)
type EnvFlag struct {

View File

@@ -1,247 +0,0 @@
package strmatcher
import (
"container/list"
)
// validCharCount is the number of distinct slots char2Index maps bytes to
// (indices 0 through 52). It is the number of outgoing edges of every node.
const validCharCount = 53
// MatchType records what, if anything, ends at a node of the automaton.
type MatchType struct {
Type Type // matcher type stored for the node
Exist bool // true once Add has registered a pattern ending at the node
}
// Edge kinds stored in Edge.Type.
const (
TrieEdge bool = true // edge created by Add while inserting a pattern
FailEdge bool = false // initial kind of every edge; also written by Build
)
// Edge is one outgoing transition of a node.
type Edge struct {
Type bool // TrieEdge or FailEdge
NextNode int // index of the target node in ACAutomaton.Trie; 0 is the root
}
// ACAutomaton holds the transition table, the failure links and the per-node
// match markers of the automaton. The three slices are indexed by node number
// and are grown together by Add.
type ACAutomaton struct {
Trie [][validCharCount]Edge // one row of edges per node
Fail []int // failure link of each node, filled in by Build
Exists []MatchType // match marker of each node
Count int // number of the most recently created node
}
// newNode returns the edge row of a fresh node: every one of its
// validCharCount edges is a FailEdge that points at the root (node 0).
func newNode() [validCharCount]Edge {
	var row [validCharCount]Edge
	for slot := 0; slot < validCharCount; slot++ {
		row[slot].Type = FailEdge
		row[slot].NextNode = 0
	}
	return row
}
// char2Index maps a byte to its edge slot (0..validCharCount-1). Letters are
// case-folded: an upper-case letter and its lower-case form share a slot.
//
// The slice is built from indexed elements, so its length is '~'+1 and any
// byte at or above that length is out of range. Bytes inside the range that
// are not listed here keep the zero value and therefore share slot 0 with
// 'A'/'a'.
var char2Index = []int{
'A': 0,
'a': 0,
'B': 1,
'b': 1,
'C': 2,
'c': 2,
'D': 3,
'd': 3,
'E': 4,
'e': 4,
'F': 5,
'f': 5,
'G': 6,
'g': 6,
'H': 7,
'h': 7,
'I': 8,
'i': 8,
'J': 9,
'j': 9,
'K': 10,
'k': 10,
'L': 11,
'l': 11,
'M': 12,
'm': 12,
'N': 13,
'n': 13,
'O': 14,
'o': 14,
'P': 15,
'p': 15,
'Q': 16,
'q': 16,
'R': 17,
'r': 17,
'S': 18,
's': 18,
'T': 19,
't': 19,
'U': 20,
'u': 20,
'V': 21,
'v': 21,
'W': 22,
'w': 22,
'X': 23,
'x': 23,
'Y': 24,
'y': 24,
'Z': 25,
'z': 25,
'!': 26,
'$': 27,
'&': 28,
'\'': 29,
'(': 30,
')': 31,
'*': 32,
'+': 33,
',': 34,
';': 35,
'=': 36,
':': 37,
'%': 38,
'-': 39,
'.': 40,
'_': 41,
'~': 42,
'0': 43,
'1': 44,
'2': 45,
'3': 46,
'4': 47,
'5': 48,
'6': 49,
'7': 50,
'8': 51,
'9': 52,
}
// NewACAutomaton returns an automaton that contains only the root node
// (node 0): one fresh edge row, a failure link of 0 and a marker that records
// no pattern.
func NewACAutomaton() *ACAutomaton {
	return &ACAutomaton{
		Trie:   [][validCharCount]Edge{newNode()},
		Fail:   []int{0},
		Exists: []MatchType{{Type: Full, Exist: false}},
	}
}
// Add inserts the pattern domain with matcher type t into the automaton.
//
// The pattern is inserted back to front (last byte first), mirroring the
// direction in which Match walks its input. The node reached after the whole
// pattern is marked with t. For a Domain pattern that node is re-marked as
// Full and one extra node, reached through '.', is marked Domain.
//
// A pattern containing a byte that lies outside char2Index is ignored.
// Indexing the table with such a byte used to panic, and Match rejects any
// input as soon as it meets such a byte, so the pattern could never match.
func (ac *ACAutomaton) Add(domain string, t Type) {
	// Validate up front so that a rejected pattern leaves no partial path.
	for i := 0; i < len(domain); i++ {
		if int(domain[i]) >= len(char2Index) {
			return
		}
	}

	node := 0
	for i := len(domain) - 1; i >= 0; i-- {
		node = ac.childOrCreate(node, char2Index[domain[i]])
	}
	ac.Exists[node] = MatchType{Type: t, Exist: true}

	if t == Domain {
		// The bare domain itself is an exact match ...
		ac.Exists[node] = MatchType{Type: Full, Exist: true}
		// ... and anything that continues with a '.' label separator is a
		// sub-domain match.
		node = ac.childOrCreate(node, char2Index['.'])
		ac.Exists[node] = MatchType{Type: t, Exist: true}
	}
}

// childOrCreate returns the node that edge idx of node leads to, first
// creating that node and a TrieEdge to it when the edge still points at the
// root.
func (ac *ACAutomaton) childOrCreate(node, idx int) int {
	if ac.Trie[node][idx].NextNode == 0 {
		ac.Count++
		if len(ac.Trie) < ac.Count+1 {
			ac.Trie = append(ac.Trie, newNode())
			ac.Fail = append(ac.Fail, 0)
			ac.Exists = append(ac.Exists, MatchType{Type: Full, Exist: false})
		}
		ac.Trie[node][idx] = Edge{Type: TrieEdge, NextNode: ac.Count}
	}
	return ac.Trie[node][idx].NextNode
}
// Build computes the failure links and fills in the fallback transitions.
//
// Nodes are processed in breadth-first order starting from the children of
// the root. For each edge of the node being processed:
//   - if the edge leads somewhere, the target's failure link becomes the
//     node reached by the same slot from this node's own failure node, and
//     the edge is queued;
//   - otherwise the edge is replaced by a FailEdge to that same node.
func (ac *ACAutomaton) Build() {
	pending := list.New()
	for slot := 0; slot < validCharCount; slot++ {
		if rootEdge := ac.Trie[0][slot]; rootEdge.NextNode != 0 {
			pending.PushBack(rootEdge)
		}
	}

	for pending.Len() > 0 {
		head := pending.Front()
		cur := head.Value.(Edge).NextNode
		pending.Remove(head)

		for slot := 0; slot < validCharCount; slot++ {
			viaFail := ac.Trie[ac.Fail[cur]][slot].NextNode
			if child := ac.Trie[cur][slot].NextNode; child != 0 {
				ac.Fail[child] = viaFail
				pending.PushBack(ac.Trie[cur][slot])
				continue
			}
			ac.Trie[cur][slot] = Edge{Type: FailEdge, NextNode: viaFail}
		}
	}
}
// Match reports whether s matches any pattern in the automaton.
//
// The input is consumed back to front. onTrie stays true for as long as every
// transition taken so far was a TrieEdge:
//   - reaching a node marked Substr is a match regardless of onTrie;
//   - reaching a node marked Domain is a match only while onTrie holds;
//   - after the whole input is consumed, it is a match when onTrie still
//     holds and the final node carries a pattern.
//
// A byte outside char2Index makes the input a non-match immediately.
func (ac *ACAutomaton) Match(s string) bool {
	cur := 0
	onTrie := true
	for pos := len(s) - 1; pos >= 0; pos-- {
		code := int(s[pos])
		if code >= len(char2Index) {
			return false
		}
		edge := ac.Trie[cur][char2Index[code]]
		onTrie = onTrie && edge.Type
		cur = edge.NextNode

		marker := ac.Exists[cur].Type
		if marker == Substr {
			return true
		}
		if marker == Domain && onTrie {
			return true
		}
	}
	return onTrie && ac.Exists[cur].Exist
}

View File

@@ -1,62 +0,0 @@
package strmatcher_test
import (
"strconv"
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/strmatcher"
)
// BenchmarkACAutomaton measures ACAutomaton.Match against an automaton loaded
// with 1024 Domain patterns ("1.xray.com" .. "1024.xray.com"), using an input
// that is not among them.
func BenchmarkACAutomaton(b *testing.B) {
	automaton := NewACAutomaton()
	for n := 1; n <= 1024; n++ {
		automaton.Add(strconv.Itoa(n)+".xray.com", Domain)
	}
	automaton.Build()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = automaton.Match("0.xray.com")
	}
}
// BenchmarkDomainMatcherGroup measures DomainMatcherGroup.Match against a
// group loaded with 1024 domains ("1.example.com" .. "1024.example.com"),
// using an input that is not among them.
func BenchmarkDomainMatcherGroup(b *testing.B) {
	group := &DomainMatcherGroup{}
	for n := 1; n <= 1024; n++ {
		group.Add(strconv.Itoa(n)+".example.com", uint32(n))
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = group.Match("0.example.com")
	}
}
// BenchmarkFullMatcherGroup measures FullMatcherGroup.Match against a group
// loaded with 1024 exact patterns ("1.example.com" .. "1024.example.com"),
// using an input that is not among them.
func BenchmarkFullMatcherGroup(b *testing.B) {
	group := &FullMatcherGroup{}
	for n := 1; n <= 1024; n++ {
		group.Add(strconv.Itoa(n)+".example.com", uint32(n))
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = group.Match("0.example.com")
	}
}
// BenchmarkMarchGroup measures MatcherGroup.Match against a group loaded with
// 1024 Domain matchers ("1.example.com" .. "1024.example.com"), using an
// input that is not among them.
func BenchmarkMarchGroup(b *testing.B) {
	group := &MatcherGroup{}
	for n := 1; n <= 1024; n++ {
		matcher, err := Domain.New(strconv.Itoa(n) + ".example.com")
		common.Must(err)
		group.Add(matcher)
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = group.Match("0.example.com")
	}
}

View File

@@ -1,98 +0,0 @@
package strmatcher
import "strings"
// breakDomain splits domain into its dot-separated labels, in left-to-right
// order. An empty domain yields a single empty label.
func breakDomain(domain string) []string {
return strings.Split(domain, ".")
}
// node is one label of the domain tree built by DomainMatcherGroup.
type node struct {
values []uint32 // values of the domains that end at this label, in insertion order
sub map[string]*node // child labels, keyed by label text; nil until a child is added
}
// DomainMatcherGroup is an IndexMatcher for a large set of Domain matchers.
// Domains are stored label by label, starting from the rightmost label.
// Visible for testing only.
type DomainMatcherGroup struct {
root *node // root of the label tree; nil until the first Add
}
// Add registers domain with the given value.
//
// The domain is split into labels and inserted from the rightmost label to
// the leftmost one, creating tree nodes as needed. The value is appended to
// the node of the leftmost label, so the same domain may carry several values.
func (g *DomainMatcherGroup) Add(domain string, value uint32) {
	if g.root == nil {
		g.root = &node{}
	}

	cur := g.root
	labels := breakDomain(domain)
	for remaining := len(labels); remaining > 0; remaining-- {
		label := labels[remaining-1]
		if cur.sub == nil {
			cur.sub = map[string]*node{}
		}
		child := cur.sub[label]
		if child == nil {
			child = &node{}
			cur.sub[label] = child
		}
		cur = child
	}
	cur.values = append(cur.values, value)
}
// addMatcher registers the pattern of m with the given value by delegating
// to Add.
func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) {
g.Add(string(m), value)
}
// Match returns the values of every registered domain that domain equals or
// is a sub-domain of.
//
// The input is walked label by label from the right, without allocating the
// labels. Every visited node that carries values contributes them. Values of
// longer (more specific) matches come first in the result. When exactly one
// node matched, its value slice is returned as is, without copying. nil is
// returned for an empty input, an empty group or no match.
func (g *DomainMatcherGroup) Match(domain string) []uint32 {
	if domain == "" {
		return nil
	}
	cur := g.root
	if cur == nil {
		return nil
	}

	matches := [][]uint32{}
	// end is the exclusive end of the label being examined; -1 means the
	// leftmost label has already been consumed.
	for end := len(domain); end != -1 && cur.sub != nil; {
		dot := strings.LastIndexByte(domain[:end], '.')
		child := cur.sub[domain[dot+1:end]]
		if child == nil {
			break
		}
		cur, end = child, dot
		if len(cur.values) > 0 {
			matches = append(matches, cur.values)
		}
	}

	if len(matches) == 0 {
		return nil
	}
	if len(matches) == 1 {
		return matches[0]
	}
	result := []uint32{}
	// Walk backwards: the sub-domain that matches further ranks higher.
	for i := len(matches) - 1; i >= 0; i-- {
		result = append(result, matches[i]...)
	}
	return result
}

View File

@@ -1,25 +0,0 @@
package strmatcher
// FullMatcherGroup stores exact-match patterns in a map from pattern to the
// values registered for it.
type FullMatcherGroup struct {
matchers map[string][]uint32 // nil until the first Add
}
// Add registers domain with the given value. The map is created on first use
// and values registered for the same domain accumulate in insertion order.
func (g *FullMatcherGroup) Add(domain string, value uint32) {
	if g.matchers == nil {
		g.matchers = map[string][]uint32{}
	}
	existing := g.matchers[domain]
	g.matchers[domain] = append(existing, value)
}
// addMatcher registers the pattern of m with the given value by delegating
// to Add.
func (g *FullMatcherGroup) addMatcher(m fullMatcher, value uint32) {
g.Add(string(m), value)
}
// Match returns the values registered for exactly str, or nil when the group
// is empty or str is unknown. The returned slice is the stored one, not a
// copy.
func (g *FullMatcherGroup) Match(str string) []uint32 {
	if g.matchers != nil {
		return g.matchers[str]
	}
	return nil
}

View File

@@ -1,56 +0,0 @@
package strmatcher
import (
"regexp"
"strings"
)
// fullMatcher matches inputs that are byte-for-byte equal to the pattern.
type fullMatcher string

// Match reports whether s equals the pattern.
func (m fullMatcher) Match(s string) bool {
	pattern := string(m)
	return s == pattern
}

// String returns the pattern prefixed with "full:".
func (m fullMatcher) String() string {
	pattern := string(m)
	return "full:" + pattern
}
// substrMatcher matches inputs that contain the pattern as a substring.
type substrMatcher string

// Match reports whether the pattern occurs anywhere in s.
func (m substrMatcher) Match(s string) bool {
	return strings.Index(s, string(m)) >= 0
}

// String returns the pattern prefixed with "keyword:".
func (m substrMatcher) String() string {
	pattern := string(m)
	return "keyword:" + pattern
}
// domainMatcher matches inputs that are the pattern itself or end with
// "." followed by the pattern.
type domainMatcher string

// Match reports whether s is the pattern or a sub-domain of it: s must end
// with the pattern and the byte before that suffix, if any, must be '.'.
func (m domainMatcher) Match(s string) bool {
	pattern := string(m)
	if !strings.HasSuffix(s, pattern) {
		return false
	}
	prefixLen := len(s) - len(pattern)
	if prefixLen == 0 {
		return true
	}
	return s[prefixLen-1] == '.'
}

// String returns the pattern prefixed with "domain:".
func (m domainMatcher) String() string {
	pattern := string(m)
	return "domain:" + pattern
}
// RegexMatcher matches inputs against a regular expression.
type RegexMatcher struct {
	Pattern string         // source text of the expression
	reg     *regexp.Regexp // compiled form; compiled from Pattern on first Match when nil
}

// Match reports whether s matches the expression. When no compiled form is
// present yet, Pattern is compiled with regexp.MustCompile (which panics on an
// invalid pattern) and the result is kept for later calls.
func (m *RegexMatcher) Match(s string) bool {
	compiled := m.reg
	if compiled == nil {
		compiled = regexp.MustCompile(m.Pattern)
		m.reg = compiled
	}
	return compiled.MatchString(s)
}

// String returns the pattern prefixed with "regexp:".
func (m *RegexMatcher) String() string {
	return "regexp:" + m.Pattern
}

View File

@@ -1,73 +0,0 @@
package strmatcher_test
import (
"testing"
"github.com/xtls/xray-core/common"
. "github.com/xtls/xray-core/common/strmatcher"
)
// TestMatcher builds a single matcher of each listed type for the pattern
// "example.com" and checks its verdict on a range of inputs.
func TestMatcher(t *testing.T) {
	cases := []struct {
		pattern string
		mType   Type
		input   string
		output  bool
	}{
		// Domain: the pattern itself and its sub-domains match.
		{pattern: "example.com", mType: Domain, input: "www.example.com", output: true},
		{pattern: "example.com", mType: Domain, input: "example.com", output: true},
		{pattern: "example.com", mType: Domain, input: "www.fxample.com", output: false},
		{pattern: "example.com", mType: Domain, input: "xample.com", output: false},
		{pattern: "example.com", mType: Domain, input: "xexample.com", output: false},
		// Full: only the exact pattern matches.
		{pattern: "example.com", mType: Full, input: "example.com", output: true},
		{pattern: "example.com", mType: Full, input: "xexample.com", output: false},
		// Regex: '.' in the pattern matches any character.
		{pattern: "example.com", mType: Regex, input: "examplexcom", output: true},
	}

	for idx := range cases {
		test := cases[idx]
		matcher, err := test.mType.New(test.pattern)
		common.Must(err)
		m := matcher.Match(test.input)
		if m == test.output {
			continue
		}
		t.Error("unexpected output: ", m, " for test case ", test)
	}
}

View File

@@ -1,308 +0,0 @@
package strmatcher
import (
"math/bits"
"regexp"
"sort"
"strings"
"unsafe"
)
// PrimeRK is the prime base used in Rabin-Karp algorithm.
const PrimeRK = 16777619

// RollingHash returns the polynomial hash of s in base PrimeRK, with
// wrap-around uint32 arithmetic. The bytes are folded in from the last one to
// the first, so the hash of a string can be extended by one more leading byte
// c as h*PrimeRK + c.
func RollingHash(s string) uint32 {
	var hash uint32
	for remaining := len(s); remaining > 0; remaining-- {
		hash = hash*PrimeRK + uint32(s[remaining-1])
	}
	return hash
}
// A MphMatcherGroup is divided into three parts:
// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table;
// 2. `substr` patterns are matched by ac automaton;
// 3. `regex` patterns are matched with the regex library.
//
// Patterns are collected in RuleMap first; Build turns them into the
// Rules/Level0/Level1 tables and clears RuleMap.
type MphMatcherGroup struct {
Ac *ACAutomaton // substr patterns; nil until the first one is added
OtherMatchers []MatcherEntry // regex matchers, tried one by one
Rules []string // hash-table keys, filled by Build
Level0 []uint32 // per-bucket seeds found by Build
Level0Mask int // len(Level0) - 1
Level1 []uint32 // slot -> index into Rules
Level1Mask int // len(Level1) - 1
Count uint32 // set to 1 by NewMphMatcherGroup; returned by AddPattern and Size
RuleMap *map[string]uint32 // pattern -> rolling hash, pending until Build; nil afterwards
}
// AddFullOrDomainPattern records pattern, together with its rolling hash, in
// the pending rule map. For a Domain pattern the key "."+pattern is recorded
// as well, so that sub-domains can be found. Types other than Full and Domain
// are ignored.
func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) {
	hash := RollingHash(pattern)
	if t != Domain && t != Full {
		return
	}
	pending := *g.RuleMap
	if t == Domain {
		// Hash of "."+pattern, obtained by extending the pattern's hash.
		pending["."+pattern] = hash*PrimeRK + uint32('.')
	}
	pending[pattern] = hash
}
// NewMphMatcherGroup returns an empty group: Count is 1, RuleMap points at an
// empty map and every other field has its zero value.
func NewMphMatcherGroup() *MphMatcherGroup {
	group := new(MphMatcherGroup)
	group.Count = 1
	group.RuleMap = &map[string]uint32{}
	return group
}
// AddPattern adds a pattern to MphMatcherGroup.
//
// Full and Domain patterns are lower-cased and recorded for the hash table,
// Substr patterns are inserted into the ac automaton (created on first use)
// and Regex patterns are compiled and appended to OtherMatchers. It returns
// g.Count, or 0 and the compile error for an invalid regular expression. It
// panics on an unknown type.
func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
	switch t {
	case Full, Domain:
		g.AddFullOrDomainPattern(strings.ToLower(pattern), t)

	case Substr:
		if g.Ac == nil {
			g.Ac = NewACAutomaton()
		}
		g.Ac.Add(pattern, t)

	case Regex:
		compiled, err := regexp.Compile(pattern)
		if err != nil {
			return 0, err
		}
		entry := MatcherEntry{
			M:  &RegexMatcher{Pattern: pattern, reg: compiled},
			Id: g.Count,
		}
		g.OtherMatchers = append(g.OtherMatchers, entry)

	default:
		panic("Unknown type")
	}
	return g.Count, nil
}
// Build builds a minimal perfect hash table and ac automaton from insert rules
//
// The pending rules in RuleMap are distributed over len(Level0) buckets by
// their rolling hash. For each bucket, largest first, a seed is searched such
// that the seeded string hash places every rule of the bucket in a free
// Level1 slot. The seed is stored in Level0 and the slots store the rule's
// index in Rules. RuleMap is set to nil, so patterns can no longer be added
// through it after Build.
func (g *MphMatcherGroup) Build() {
if g.Ac != nil {
g.Ac.Build()
}
keyLen := len(*g.RuleMap)
if keyLen == 0 {
// Insert a placeholder rule so the tables are never empty and Lookup
// always has a rule to compare against.
keyLen = 1
(*g.RuleMap)["empty___"] = RollingHash("empty___")
}
g.Level0 = make([]uint32, nextPow2(keyLen/4))
g.Level0Mask = len(g.Level0) - 1
g.Level1 = make([]uint32, nextPow2(keyLen))
g.Level1Mask = len(g.Level1) - 1
// sparseBuckets[n] lists the indices (into g.Rules) of the rules whose
// rolling hash selects bucket n.
sparseBuckets := make([][]int, len(g.Level0))
var ruleIdx int
for rule, hash := range *g.RuleMap {
n := int(hash) & g.Level0Mask
g.Rules = append(g.Rules, rule)
sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
ruleIdx++
}
g.RuleMap = nil
// Keep only the non-empty buckets and process the fullest ones first.
var buckets []indexBucket
for n, vals := range sparseBuckets {
if len(vals) > 0 {
buckets = append(buckets, indexBucket{n, vals})
}
}
sort.Sort(bySize(buckets))
// occ marks the Level1 slots that are already taken.
occ := make([]bool, len(g.Level1))
// tmpOcc remembers the slots claimed by the current seed attempt so they
// can be released if the attempt fails.
var tmpOcc []int
for _, bucket := range buckets {
seed := uint32(0)
for {
findSeed := true
tmpOcc = tmpOcc[:0]
for _, i := range bucket.vals {
n := int(strhashFallback(unsafe.Pointer(&g.Rules[i]), uintptr(seed))) & g.Level1Mask
if occ[n] {
// Collision: undo this attempt and try the next seed.
for _, n := range tmpOcc {
occ[n] = false
}
seed++
findSeed = false
break
}
occ[n] = true
tmpOcc = append(tmpOcc, n)
g.Level1[n] = uint32(i)
}
if findSeed {
g.Level0[bucket.n] = seed
break
}
}
}
}
// nextPow2 returns 1 for v <= 1 and otherwise the smallest power of two that
// is strictly greater than v (so 4 maps to 8, not to 4).
func nextPow2(v int) int {
	if v <= 1 {
		return 1
	}
	// bits.Len gives the position just above the highest set bit of v.
	shift := uint(bits.Len(uint(v)))
	return int(uint(1) << shift)
}
// Lookup reports whether s is one of the rules in the hash table. h must be
// the rolling hash of s: it selects the Level0 seed, the seeded string hash
// of s then selects a Level1 slot, and the rule stored there is compared with
// s. The tables must have been populated (by Build or by loading a
// serialized group).
func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
	seed := g.Level0[int(h)&g.Level0Mask]
	slot := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.Level1Mask
	ruleIdx := int(g.Level1[slot])
	return g.Rules[ruleIdx] == s
}
// Match implements IndexMatcher.Match.
//
// It returns a one-element slice for the first check that succeeds, or nil:
//  1. every suffix of pattern that starts at a '.' is looked up in the hash
//     table, shortest suffix first (yields 1);
//  2. the whole pattern is looked up in the hash table (yields 1);
//  3. the ac automaton, when present, is consulted (yields 1);
//  4. the other matchers are tried in order (yields the matcher's Id).
func (g *MphMatcherGroup) Match(pattern string) []uint32 {
	single := func(id uint32) []uint32 {
		return append([]uint32{}, id)
	}

	var hash uint32
	for pos := len(pattern) - 1; pos >= 0; pos-- {
		// Extend the rolling hash by one more leading byte.
		hash = hash*PrimeRK + uint32(pattern[pos])
		if pattern[pos] == '.' && g.Lookup(hash, pattern[pos:]) {
			return single(1)
		}
	}

	switch {
	case g.Lookup(hash, pattern):
		return single(1)
	case g.Ac != nil && g.Ac.Match(pattern):
		return single(1)
	}

	for _, entry := range g.OtherMatchers {
		if entry.M.Match(pattern) {
			return single(entry.Id)
		}
	}
	return nil
}
// indexBucket is one non-empty first-level bucket of the hash table.
type indexBucket struct {
	n    int   // bucket number
	vals []int // indices of the rules that fall into the bucket
}

// bySize orders buckets from the most populated to the least populated.
type bySize []indexBucket

// Len returns the number of buckets.
func (s bySize) Len() int {
	return len(s)
}

// Less reports whether bucket i holds more rules than bucket j.
func (s bySize) Less(i, j int) bool {
	return len(s[j].vals) < len(s[i].vals)
}

// Swap exchanges buckets i and j.
func (s bySize) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}
// stringStruct is the data-pointer/length pair that strhashFallback reads a
// string through.
type stringStruct struct {
	str unsafe.Pointer // first byte of the string data
	len int            // length in bytes
}

// strhashFallback hashes the string that a points to, seeded with h. The
// pointer is reinterpreted as a *stringStruct and the bytes are hashed with
// memhashFallback.
func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
	header := (*stringStruct)(a)
	size := uintptr(header.len)
	return memhashFallback(header.str, h, size)
}
const (
// Constants for multiplication: four random odd 64-bit numbers.
m1 = 16877499708836156737
m2 = 2820277070424839065
m3 = 9497967016996688599
m4 = 15839092249703872147
)
// hashkey holds the multipliers memhashFallback applies to the length and the
// seed. They are fixed at 1 here.
var hashkey = [4]uintptr{1, 1, 1, 1}
// memhashFallback hashes the s bytes starting at p, mixed with seed, and
// returns the result as a uintptr.
//
// Inputs longer than 32 bytes are consumed in 32-byte blocks over four lanes;
// the remaining (at most 31) bytes are then handled by one of the short-input
// cases, which is reached again through the tail label.
// NOTE(review): the name suggests this mirrors the Go runtime's fallback
// memory hash - confirm before relying on that.
func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
h := uint64(seed + s*hashkey[0])
tail:
switch {
case s == 0:
case s < 4:
// 1-3 bytes: mix the first, the middle and the last byte.
h ^= uint64(*(*byte)(p))
h ^= uint64(*(*byte)(add(p, s>>1))) << 8
h ^= uint64(*(*byte)(add(p, s-1))) << 16
h = rotl31(h*m1) * m2
case s <= 8:
// 4-8 bytes: the first four and the last four bytes (they may overlap).
h ^= uint64(readUnaligned32(p))
h ^= uint64(readUnaligned32(add(p, s-4))) << 32
h = rotl31(h*m1) * m2
case s <= 16:
// 9-16 bytes: the first eight and the last eight bytes (they may overlap).
h ^= readUnaligned64(p)
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-8))
h = rotl31(h*m1) * m2
case s <= 32:
// 17-32 bytes: the first sixteen and the last sixteen bytes (they may overlap).
h ^= readUnaligned64(p)
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, 8))
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-16))
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-8))
h = rotl31(h*m1) * m2
default:
// More than 32 bytes: consume whole 32-byte blocks, eight bytes per lane.
v1 := h
v2 := uint64(seed * hashkey[1])
v3 := uint64(seed * hashkey[2])
v4 := uint64(seed * hashkey[3])
for s >= 32 {
v1 ^= readUnaligned64(p)
v1 = rotl31(v1*m1) * m2
p = add(p, 8)
v2 ^= readUnaligned64(p)
v2 = rotl31(v2*m2) * m3
p = add(p, 8)
v3 ^= readUnaligned64(p)
v3 = rotl31(v3*m3) * m4
p = add(p, 8)
v4 ^= readUnaligned64(p)
v4 = rotl31(v4*m4) * m1
p = add(p, 8)
s -= 32
}
h = v1 ^ v2 ^ v3 ^ v4
// s is now below 32: hash the remainder with a short-input case.
goto tail
}
// Final mixing of the accumulated state.
h ^= h >> 29
h *= m3
h ^= h >> 32
return uintptr(h)
}
// add returns the pointer p advanced by x bytes. The conversion to uintptr,
// the addition and the conversion back are kept in one expression.
func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
return unsafe.Pointer(uintptr(p) + x)
}
// readUnaligned32 reads the four bytes at p as a little-endian uint32 (the
// byte at p is the least significant one), one byte at a time.
func readUnaligned32(p unsafe.Pointer) uint32 {
	src := (*[4]byte)(p)
	var v uint32
	for i := len(src) - 1; i >= 0; i-- {
		v = v<<8 | uint32(src[i])
	}
	return v
}
// rotl31 rotates x left by 31 bits.
func rotl31(x uint64) uint64 {
	return bits.RotateLeft64(x, 31)
}
// readUnaligned64 reads the eight bytes at p as a little-endian uint64 (the
// byte at p is the least significant one), one byte at a time.
func readUnaligned64(p unsafe.Pointer) uint64 {
	src := (*[8]byte)(p)
	var v uint64
	for i := len(src) - 1; i >= 0; i-- {
		v = v<<8 | uint64(src[i])
	}
	return v
}
// Size returns the Count field of the group.
func (g *MphMatcherGroup) Size() uint32 {
return g.Count
}

View File

@@ -1,47 +0,0 @@
package strmatcher
import (
"bytes"
"encoding/gob"
"io"
)
// init registers the concrete matcher types with gob so that values stored
// in interface-typed fields can be encoded and decoded.
func init() {
	for _, sample := range []interface{}{
		&RegexMatcher{},
		fullMatcher(""),
		substrMatcher(""),
		domainMatcher(""),
	} {
		gob.Register(sample)
	}
}
// Serialize gob-encodes the group to w. Every field is written except
// RuleMap, the pending rule map, which is left out of the encoded copy.
func (g *MphMatcherGroup) Serialize(w io.Writer) error {
	snapshot := *g
	snapshot.RuleMap = nil
	encoder := gob.NewEncoder(w)
	return encoder.Encode(snapshot)
}
// NewMphMatcherGroupFromBuffer decodes a gob-encoded group from data. The
// result starts as a fresh group from NewMphMatcherGroup and then takes the
// automaton, the other matchers, the hash tables with their masks and Count
// from the decoded data. The decode error, if any, is returned as is.
func NewMphMatcherGroupFromBuffer(data []byte) (*MphMatcherGroup, error) {
	var decoded MphMatcherGroup
	decoder := gob.NewDecoder(bytes.NewReader(data))
	if err := decoder.Decode(&decoded); err != nil {
		return nil, err
	}

	group := NewMphMatcherGroup()
	group.Ac, group.OtherMatchers = decoded.Ac, decoded.OtherMatchers
	group.Rules = decoded.Rules
	group.Level0, group.Level0Mask = decoded.Level0, decoded.Level0Mask
	group.Level1, group.Level1Mask = decoded.Level1, decoded.Level1Mask
	group.Count = decoded.Count
	return group, nil
}

View File

@@ -1,141 +0,0 @@
package strmatcher
import (
"errors"
"regexp"
)
// Matcher is the interface that determines whether a string matches a pattern.
type Matcher interface {
// Match returns true if the given string matches a predefined pattern.
Match(string) bool
// String returns a textual form of the matcher. The implementations in
// this package return a kind prefix followed by the pattern.
String() string
}
// Type is the type of the matcher.
type Type byte
const (
// Full is the type of matcher where the input string must be exactly equal to the pattern.
Full Type = iota
// Substr is the type of matcher where the input string must contain the pattern as a sub-string.
Substr
// Domain is the type of matcher where the input string must be the pattern itself or a sub-domain of it.
Domain
// Regex is the type of matcher where the input string must match the regular-expression pattern.
Regex
)
// New creates a new Matcher based on the given pattern.
//
// Full, Substr and Domain patterns are used as given. A Regex pattern is
// compiled and the compile error, if any, is returned. Regex matching is
// case-sensitive. An unknown type yields an error.
func (t Type) New(pattern string) (Matcher, error) {
	switch t {
	case Regex:
		compiled, err := regexp.Compile(pattern)
		if err != nil {
			return nil, err
		}
		matcher := &RegexMatcher{Pattern: pattern, reg: compiled}
		return matcher, nil
	case Domain:
		return domainMatcher(pattern), nil
	case Substr:
		return substrMatcher(pattern), nil
	case Full:
		return fullMatcher(pattern), nil
	}
	return nil, errors.New("unk type")
}
// IndexMatcher is the interface for matching with a group of matchers.
type IndexMatcher interface {
// Match returns the indices of the matchers that match the input. The result is empty if no such matcher exists.
Match(input string) []uint32
// Size returns the number of matchers in the group.
Size() uint32
}
// MatcherEntry pairs a Matcher with the index it was registered under.
type MatcherEntry struct {
M Matcher
Id uint32
}
// MatcherGroup is an implementation of IndexMatcher.
// Empty initialization works.
type MatcherGroup struct {
count uint32 // number of matchers added so far; also the last index handed out
fullMatcher FullMatcherGroup // exact-match patterns
domainMatcher DomainMatcherGroup // domain patterns
otherMatchers []MatcherEntry // every other matcher, tried one by one
}
// Add adds a new Matcher into the MatcherGroup, and returns its index. The
// index will never be 0: indices are handed out sequentially starting at 1.
// Full and domain matchers go to their dedicated groups; every other matcher
// is appended to the list that is tried one by one.
func (g *MatcherGroup) Add(m Matcher) uint32 {
	g.count++
	index := g.count

	switch typed := m.(type) {
	case fullMatcher:
		g.fullMatcher.addMatcher(typed, index)
	case domainMatcher:
		g.domainMatcher.addMatcher(typed, index)
	default:
		entry := MatcherEntry{M: m, Id: index}
		g.otherMatchers = append(g.otherMatchers, entry)
	}
	return index
}
// Match implements IndexMatcher.Match.
//
// The result lists the indices of the matching full matchers first, then
// those of the matching domain matchers, then those of the other matchers in
// the order they were added. It is empty, but not nil, when nothing matches.
func (g *MatcherGroup) Match(pattern string) []uint32 {
	hits := append([]uint32{}, g.fullMatcher.Match(pattern)...)
	hits = append(hits, g.domainMatcher.Match(pattern)...)
	for i := range g.otherMatchers {
		entry := g.otherMatchers[i]
		if !entry.M.Match(pattern) {
			continue
		}
		hits = append(hits, entry.Id)
	}
	return hits
}
// Size returns the number of matchers in the MatcherGroup.
func (g *MatcherGroup) Size() uint32 {
return g.count
}
// IndexMatcherGroup chains several IndexMatchers. Match consults them in
// slice order and Size is the sum of their sizes.
type IndexMatcherGroup struct {
Matchers []IndexMatcher
}
// Match returns the result of the first matcher in the group that matches
// input, or nil when none does. Each returned index is raised by the combined
// Size of the matchers that come before the matching one. When that offset is
// zero the matcher's own slice is returned; otherwise a shifted copy is.
func (g *IndexMatcherGroup) Match(input string) []uint32 {
	var base uint32
	for _, sub := range g.Matchers {
		hits := sub.Match(input)
		if len(hits) == 0 {
			base += sub.Size()
			continue
		}
		if base == 0 {
			return hits
		}
		shifted := make([]uint32, len(hits))
		for i := range hits {
			shifted[i] = hits[i] + base
		}
		return shifted
	}
	return nil
}
// Size returns the sum of the sizes of all matchers in the group.
func (g *IndexMatcherGroup) Size() uint32 {
	total := uint32(0)
	for i := range g.Matchers {
		total += g.Matchers[i].Size()
	}
	return total
}