mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
Xray-core: Refactor geodata (#5814)
https://github.com/XTLS/Xray-core/issues/4422#issuecomment-3533007890 Breaking changes https://github.com/XTLS/Xray-core/pull/5569 Reverts https://github.com/XTLS/Xray-core/pull/5505 Closes https://github.com/XTLS/Xray-core/pull/643
This commit is contained in:
996
common/geodata/ip_matcher.go
Normal file
996
common/geodata/ip_matcher.go
Normal file
@@ -0,0 +1,996 @@
|
||||
package geodata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
type IPMatcher interface {
|
||||
// TODO: (PERF) all net.IP -> netipx.Addr
|
||||
|
||||
// Invalid IP always return false.
|
||||
Match(ip net.IP) bool
|
||||
|
||||
// Returns true if *any* IP is valid and match.
|
||||
AnyMatch(ips []net.IP) bool
|
||||
|
||||
// Returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false.
|
||||
Matches(ips []net.IP) bool
|
||||
|
||||
// Filters IPs. Invalid IPs are silently dropped and not included in either result.
|
||||
FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)
|
||||
|
||||
ToggleReverse()
|
||||
|
||||
SetReverse(reverse bool)
|
||||
}
|
||||
|
||||
type IPSet struct {
|
||||
ipv4, ipv6 *netipx.IPSet
|
||||
max4, max6 uint8
|
||||
}
|
||||
|
||||
type HeuristicIPMatcher struct {
|
||||
ipset *IPSet
|
||||
reverse bool
|
||||
}
|
||||
|
||||
type ipBucket struct {
|
||||
rep netip.Addr
|
||||
ips []net.IP
|
||||
}
|
||||
|
||||
// Match implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) Match(ip net.IP) bool {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m.matchAddr(ipx)
|
||||
}
|
||||
|
||||
func (m *HeuristicIPMatcher) matchAddr(ipx netip.Addr) bool {
|
||||
if ipx.Is4() {
|
||||
if m.ipset.max4 == 0xff {
|
||||
return false
|
||||
}
|
||||
return m.ipset.ipv4.Contains(ipx) != m.reverse
|
||||
}
|
||||
if ipx.Is6() {
|
||||
if m.ipset.max6 == 0xff {
|
||||
return false
|
||||
}
|
||||
return m.ipset.ipv6.Contains(ipx) != m.reverse
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyMatch implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) AnyMatch(ips []net.IP) bool {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
return m.Match(ips[0])
|
||||
}
|
||||
|
||||
heur4 := m.ipset.max4 <= 24
|
||||
heur6 := m.ipset.max6 <= 64
|
||||
if !heur4 && !heur6 {
|
||||
for _, ip := range ips {
|
||||
if ipx, ok := netipx.FromStdIP(ip); ok {
|
||||
if m.matchAddr(ipx) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
buckets := make(map[[9]byte]struct{}, n)
|
||||
for _, ip := range ips {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6)
|
||||
if heur {
|
||||
if _, exists := buckets[key]; exists {
|
||||
continue
|
||||
}
|
||||
}
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
return true
|
||||
}
|
||||
if heur {
|
||||
buckets[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Matches implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) Matches(ips []net.IP) bool {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
return m.Match(ips[0])
|
||||
}
|
||||
|
||||
heur4 := m.ipset.max4 <= 24
|
||||
heur6 := m.ipset.max6 <= 64
|
||||
if !heur4 && !heur6 {
|
||||
for _, ip := range ips {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !m.matchAddr(ipx) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
buckets := make(map[[9]byte]netip.Addr, n)
|
||||
precise := make([]netip.Addr, 0, n)
|
||||
|
||||
for _, ip := range ips {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
|
||||
if _, exists := buckets[key]; !exists {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
buckets[key] = ipx
|
||||
}
|
||||
} else {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
precise = append(precise, ipx)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipx := range buckets {
|
||||
if !m.matchAddr(ipx) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, ipx := range precise {
|
||||
if !m.matchAddr(ipx) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
key[0] = 4
|
||||
key[1] = ip4[0]
|
||||
key[2] = ip4[1]
|
||||
key[3] = ip4[2] // /24
|
||||
return key, true
|
||||
}
|
||||
if ip16 := ip.To16(); ip16 != nil {
|
||||
key[0] = 6
|
||||
key[1] = ip16[0]
|
||||
key[2] = ip16[1]
|
||||
key[3] = ip16[2]
|
||||
key[4] = ip16[3]
|
||||
key[5] = ip16[4]
|
||||
key[6] = ip16[5]
|
||||
key[7] = ip16[6]
|
||||
key[8] = ip16[7] // /64
|
||||
return key, true
|
||||
}
|
||||
return key, false // illegal
|
||||
}
|
||||
|
||||
// FilterIPs implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return []net.IP{}, []net.IP{}
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
ipx, ok := netipx.FromStdIP(ips[0])
|
||||
if !ok {
|
||||
return []net.IP{}, []net.IP{}
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
return ips, []net.IP{}
|
||||
}
|
||||
return []net.IP{}, ips
|
||||
}
|
||||
|
||||
heur4 := m.ipset.max4 <= 24
|
||||
heur6 := m.ipset.max6 <= 64
|
||||
if !heur4 && !heur6 {
|
||||
matched = make([]net.IP, 0, n)
|
||||
unmatched = make([]net.IP, 0, n)
|
||||
for _, ip := range ips {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
matched = append(matched, ip)
|
||||
} else {
|
||||
unmatched = append(unmatched, ip)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
buckets := make(map[[9]byte]*ipBucket, n)
|
||||
precise := make([]net.IP, 0, n)
|
||||
|
||||
for _, ip := range ips {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
|
||||
if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
|
||||
precise = append(precise, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
b, exists := buckets[key]
|
||||
if !exists {
|
||||
// build bucket
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
b = &ipBucket{
|
||||
rep: ipx,
|
||||
ips: make([]net.IP, 0, 4), // for dns answer
|
||||
}
|
||||
buckets[key] = b
|
||||
}
|
||||
b.ips = append(b.ips, ip)
|
||||
}
|
||||
|
||||
matched = make([]net.IP, 0, n)
|
||||
unmatched = make([]net.IP, 0, n)
|
||||
for _, b := range buckets {
|
||||
if m.matchAddr(b.rep) {
|
||||
matched = append(matched, b.ips...)
|
||||
} else {
|
||||
unmatched = append(unmatched, b.ips...)
|
||||
}
|
||||
}
|
||||
for _, ip := range precise {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
matched = append(matched, ip)
|
||||
} else {
|
||||
unmatched = append(unmatched, ip)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToggleReverse implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) ToggleReverse() {
|
||||
m.reverse = !m.reverse
|
||||
}
|
||||
|
||||
// SetReverse implements IPMatcher.
|
||||
func (m *HeuristicIPMatcher) SetReverse(reverse bool) {
|
||||
m.reverse = reverse
|
||||
}
|
||||
|
||||
type GeneralMultiIPMatcher struct {
|
||||
matchers []IPMatcher
|
||||
}
|
||||
|
||||
// Match implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) Match(ip net.IP) bool {
|
||||
for _, m := range mm.matchers {
|
||||
if m.Match(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyMatch implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) AnyMatch(ips []net.IP) bool {
|
||||
for _, m := range mm.matchers {
|
||||
if m.AnyMatch(ips) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Matches implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) Matches(ips []net.IP) bool {
|
||||
for _, m := range mm.matchers {
|
||||
if m.Matches(ips) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterIPs implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
||||
matched = make([]net.IP, 0, len(ips))
|
||||
unmatched = ips
|
||||
for _, m := range mm.matchers {
|
||||
if len(unmatched) == 0 {
|
||||
break
|
||||
}
|
||||
var mtch []net.IP
|
||||
mtch, unmatched = m.FilterIPs(unmatched)
|
||||
if len(mtch) > 0 {
|
||||
matched = append(matched, mtch...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToggleReverse implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) ToggleReverse() {
|
||||
for _, m := range mm.matchers {
|
||||
m.ToggleReverse()
|
||||
}
|
||||
}
|
||||
|
||||
// SetReverse implements IPMatcher.
|
||||
func (mm *GeneralMultiIPMatcher) SetReverse(reverse bool) {
|
||||
for _, m := range mm.matchers {
|
||||
m.SetReverse(reverse)
|
||||
}
|
||||
}
|
||||
|
||||
type HeuristicMultiIPMatcher struct {
|
||||
matchers []*HeuristicIPMatcher
|
||||
}
|
||||
|
||||
// Match implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) Match(ip net.IP) bool {
|
||||
ipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, m := range mm.matchers {
|
||||
if m.matchAddr(ipx) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyMatch implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) AnyMatch(ips []net.IP) bool {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
return mm.Match(ips[0])
|
||||
}
|
||||
|
||||
buckets := make(map[[9]byte]struct{}, n)
|
||||
for _, ip := range ips {
|
||||
var ipx netip.Addr
|
||||
state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
|
||||
for _, m := range mm.matchers {
|
||||
heur4 := m.ipset.max4 <= 24
|
||||
heur6 := m.ipset.max6 <= 64
|
||||
|
||||
if state == 0 && (heur4 || heur6) {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if _, exists := buckets[key]; exists {
|
||||
state = key[0]
|
||||
} else {
|
||||
buckets[key] = struct{}{}
|
||||
state = 1
|
||||
}
|
||||
}
|
||||
if (heur4 && state == 4) || (heur6 && state == 6) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !ipx.IsValid() {
|
||||
nipx, ok := netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
ipx = nipx
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Matches implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) Matches(ips []net.IP) bool {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
return mm.Match(ips[0])
|
||||
}
|
||||
|
||||
var views ipViews
|
||||
for _, m := range mm.matchers {
|
||||
if !views.ensureForMatcher(m, ips) {
|
||||
return false
|
||||
}
|
||||
|
||||
matched := true
|
||||
if m.ipset.max4 <= 24 {
|
||||
for _, ipx := range views.buckets4 {
|
||||
if !m.matchAddr(ipx) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, ipx := range views.precise4 {
|
||||
if !m.matchAddr(ipx) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.ipset.max6 <= 64 {
|
||||
for _, ipx := range views.buckets6 {
|
||||
if !m.matchAddr(ipx) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, ipx := range views.precise6 {
|
||||
if !m.matchAddr(ipx) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ipViews struct {
|
||||
buckets4, buckets6 map[[9]byte]netip.Addr
|
||||
precise4, precise6 []netip.Addr
|
||||
}
|
||||
|
||||
func (v *ipViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) bool {
|
||||
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
||||
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
||||
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
||||
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
||||
|
||||
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
||||
return true
|
||||
}
|
||||
|
||||
if needHeur4 {
|
||||
v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
|
||||
}
|
||||
if needHeur6 {
|
||||
v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
|
||||
}
|
||||
if needPrec4 {
|
||||
v.precise4 = make([]netip.Addr, 0, len(ips))
|
||||
}
|
||||
if needPrec6 {
|
||||
v.precise6 = make([]netip.Addr, 0, len(ips))
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch key[0] {
|
||||
case 4:
|
||||
var ipx netip.Addr
|
||||
if needHeur4 {
|
||||
if _, exists := v.buckets4[key]; !exists {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
v.buckets4[key] = ipx
|
||||
}
|
||||
}
|
||||
if needPrec4 {
|
||||
if !ipx.IsValid() {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v.precise4 = append(v.precise4, ipx)
|
||||
}
|
||||
case 6:
|
||||
var ipx netip.Addr
|
||||
if needHeur6 {
|
||||
if _, exists := v.buckets6[key]; !exists {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
v.buckets6[key] = ipx
|
||||
}
|
||||
}
|
||||
if needPrec6 {
|
||||
if !ipx.IsValid() {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v.precise6 = append(v.precise6, ipx)
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// FilterIPs implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
||||
n := len(ips)
|
||||
if n == 0 {
|
||||
return []net.IP{}, []net.IP{}
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
ipx, ok := netipx.FromStdIP(ips[0])
|
||||
if !ok {
|
||||
return []net.IP{}, []net.IP{}
|
||||
}
|
||||
for _, m := range mm.matchers {
|
||||
if m.matchAddr(ipx) {
|
||||
return ips, []net.IP{}
|
||||
}
|
||||
}
|
||||
return []net.IP{}, ips
|
||||
}
|
||||
|
||||
var views ipBucketViews
|
||||
|
||||
matched = make([]net.IP, 0, n)
|
||||
for _, m := range mm.matchers {
|
||||
views.ensureForMatcher(m, ips)
|
||||
|
||||
if m.ipset.max4 <= 24 {
|
||||
for key, b := range views.buckets4 {
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
if m.matchAddr(b.rep) {
|
||||
views.buckets4[key] = nil
|
||||
matched = append(matched, b.ips...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for ipx, ip := range views.precise4 {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
views.precise4[ipx] = nil
|
||||
matched = append(matched, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.ipset.max6 <= 64 {
|
||||
for key, b := range views.buckets6 {
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
if m.matchAddr(b.rep) {
|
||||
views.buckets6[key] = nil
|
||||
matched = append(matched, b.ips...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for ipx, ip := range views.precise6 {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if m.matchAddr(ipx) {
|
||||
views.precise6[ipx] = nil
|
||||
matched = append(matched, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unmatched = make([]net.IP, 0, n-len(matched))
|
||||
if views.buckets4 != nil {
|
||||
for _, b := range views.buckets4 {
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
unmatched = append(unmatched, b.ips...)
|
||||
}
|
||||
}
|
||||
if views.precise4 != nil {
|
||||
for _, ip := range views.precise4 {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
unmatched = append(unmatched, ip)
|
||||
}
|
||||
}
|
||||
if views.buckets6 != nil {
|
||||
for _, b := range views.buckets6 {
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
unmatched = append(unmatched, b.ips...)
|
||||
}
|
||||
}
|
||||
if views.precise6 != nil {
|
||||
for _, ip := range views.precise6 {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
unmatched = append(unmatched, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type ipBucketViews struct {
|
||||
buckets4, buckets6 map[[9]byte]*ipBucket
|
||||
precise4, precise6 map[netip.Addr]net.IP
|
||||
}
|
||||
|
||||
func (v *ipBucketViews) ensureForMatcher(m *HeuristicIPMatcher, ips []net.IP) {
|
||||
needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
||||
needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
||||
needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
||||
needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
||||
|
||||
if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
||||
return
|
||||
}
|
||||
|
||||
if needHeur4 {
|
||||
v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
|
||||
}
|
||||
if needHeur6 {
|
||||
v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
|
||||
}
|
||||
if needPrec4 {
|
||||
v.precise4 = make(map[netip.Addr]net.IP, len(ips))
|
||||
}
|
||||
if needPrec6 {
|
||||
v.precise6 = make(map[netip.Addr]net.IP, len(ips))
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
key, ok := prefixKeyFromIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
|
||||
switch key[0] {
|
||||
case 4:
|
||||
var ipx netip.Addr
|
||||
if needHeur4 {
|
||||
b, exists := v.buckets4[key]
|
||||
if !exists {
|
||||
// build bucket
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
b = &ipBucket{
|
||||
rep: ipx,
|
||||
ips: make([]net.IP, 0, 4), // for dns answer
|
||||
}
|
||||
v.buckets4[key] = b
|
||||
}
|
||||
b.ips = append(b.ips, ip)
|
||||
}
|
||||
if needPrec4 {
|
||||
if !ipx.IsValid() {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
}
|
||||
v.precise4[ipx] = ip
|
||||
}
|
||||
case 6:
|
||||
var ipx netip.Addr
|
||||
if needHeur6 {
|
||||
b, exists := v.buckets6[key]
|
||||
if !exists {
|
||||
// build bucket
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
b = &ipBucket{
|
||||
rep: ipx,
|
||||
ips: make([]net.IP, 0, 4), // for dns answer
|
||||
}
|
||||
v.buckets6[key] = b
|
||||
}
|
||||
b.ips = append(b.ips, ip)
|
||||
}
|
||||
if needPrec6 {
|
||||
if !ipx.IsValid() {
|
||||
ipx, ok = netipx.FromStdIP(ip)
|
||||
if !ok {
|
||||
continue // illegal ip, ignore
|
||||
}
|
||||
}
|
||||
v.precise6[ipx] = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToggleReverse implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) ToggleReverse() {
|
||||
for _, m := range mm.matchers {
|
||||
m.ToggleReverse()
|
||||
}
|
||||
}
|
||||
|
||||
// SetReverse implements IPMatcher.
|
||||
func (mm *HeuristicMultiIPMatcher) SetReverse(reverse bool) {
|
||||
for _, m := range mm.matchers {
|
||||
m.SetReverse(reverse)
|
||||
}
|
||||
}
|
||||
|
||||
type IPSetFactory struct {
|
||||
sync.Mutex
|
||||
shared map[string]*IPSet // TODO: cleanup
|
||||
}
|
||||
|
||||
func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, error) {
|
||||
key := buildGeoIPRulesKey(rules)
|
||||
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
if ipset := f.shared[key]; ipset != nil {
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
ipset, err := f.createFrom(func(add func(*CIDR)) error {
|
||||
for _, r := range rules {
|
||||
cidrs, err := loadIP(r.File, r.Code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, c := range cidrs {
|
||||
add(c)
|
||||
cidrs[i] = nil // peak mem
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
f.shared[key] = ipset
|
||||
}
|
||||
return ipset, err
|
||||
}
|
||||
|
||||
func buildGeoIPRulesKey(rules []*GeoIPRule) string {
|
||||
rules = slices.Clone(rules)
|
||||
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
ri, rj := rules[i], rules[j]
|
||||
if ri.File != rj.File {
|
||||
return ri.File < rj.File
|
||||
}
|
||||
return ri.Code < rj.Code
|
||||
})
|
||||
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(rules) * 20) // geoip.dat:xx,
|
||||
var last *GeoIPRule
|
||||
for i, r := range rules {
|
||||
if i == 0 || (r.File != last.File || r.Code != last.Code) {
|
||||
last = r
|
||||
sb.WriteString(r.File)
|
||||
sb.WriteString(":")
|
||||
sb.WriteString(r.Code)
|
||||
sb.WriteString(",")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (f *IPSetFactory) CreateFromCIDRs(cidrs []*CIDR) (*IPSet, error) {
|
||||
return f.createFrom(func(add func(*CIDR)) error {
|
||||
for _, c := range cidrs {
|
||||
add(c)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (f *IPSetFactory) createFrom(yield func(func(*CIDR)) error) (*IPSet, error) {
|
||||
var ipv4Builder, ipv6Builder netipx.IPSetBuilder
|
||||
|
||||
err := yield(func(c *CIDR) {
|
||||
ipBytes := c.GetIp()
|
||||
prefixLen := int(c.GetPrefix())
|
||||
|
||||
addr, ok := netip.AddrFromSlice(ipBytes)
|
||||
if !ok {
|
||||
errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLen)
|
||||
if !prefix.IsValid() {
|
||||
errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
|
||||
return
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
ipv4Builder.AddPrefix(prefix)
|
||||
} else if addr.Is6() {
|
||||
ipv6Builder.AddPrefix(prefix)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// peak mem
|
||||
runtime.GC()
|
||||
defer runtime.GC()
|
||||
|
||||
ipv4, err := ipv4Builder.IPSet()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to build IPv4 set").Base(err)
|
||||
}
|
||||
ipv6, err := ipv6Builder.IPSet()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to build IPv6 set").Base(err)
|
||||
}
|
||||
|
||||
var max4, max6 int
|
||||
|
||||
for _, p := range ipv4.Prefixes() {
|
||||
if b := p.Bits(); b > max4 {
|
||||
max4 = b
|
||||
}
|
||||
}
|
||||
for _, p := range ipv6.Prefixes() {
|
||||
if b := p.Bits(); b > max6 {
|
||||
max6 = b
|
||||
}
|
||||
}
|
||||
|
||||
if max4 == 0 {
|
||||
max4 = 0xff
|
||||
}
|
||||
if max6 == 0 {
|
||||
max6 = 0xff
|
||||
}
|
||||
|
||||
return &IPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil
|
||||
}
|
||||
|
||||
func buildOptimizedIPMatcher(f *IPSetFactory, rules []*IPRule) (IPMatcher, error) {
|
||||
n := len(rules)
|
||||
custom := make([]*CIDR, 0, n)
|
||||
pos := make([]*GeoIPRule, 0, n)
|
||||
neg := make([]*GeoIPRule, 0, n)
|
||||
|
||||
for _, r := range rules {
|
||||
switch v := r.Value.(type) {
|
||||
case *IPRule_Custom:
|
||||
custom = append(custom, v.Custom)
|
||||
case *IPRule_Geoip:
|
||||
if !v.Geoip.ReverseMatch {
|
||||
pos = append(pos, v.Geoip)
|
||||
} else {
|
||||
neg = append(neg, v.Geoip)
|
||||
}
|
||||
default:
|
||||
panic("unknown ip rule type")
|
||||
}
|
||||
}
|
||||
|
||||
subs := make([]*HeuristicIPMatcher, 0, 3)
|
||||
|
||||
if len(custom) > 0 {
|
||||
ipset, err := f.CreateFromCIDRs(custom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: false})
|
||||
}
|
||||
|
||||
if len(pos) > 0 {
|
||||
ipset, err := f.GetOrCreateFromGeoIPRules(pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: false})
|
||||
}
|
||||
|
||||
if len(neg) > 0 {
|
||||
ipset, err := f.GetOrCreateFromGeoIPRules(neg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subs = append(subs, &HeuristicIPMatcher{ipset: ipset, reverse: true})
|
||||
}
|
||||
|
||||
switch len(subs) {
|
||||
case 0:
|
||||
return nil, errors.New("no valid ip matcher")
|
||||
case 1:
|
||||
return subs[0], nil
|
||||
default:
|
||||
return &HeuristicMultiIPMatcher{matchers: subs}, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user