diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 9a41f379..cfb360c0 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -469,11 +469,7 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { if len(c.EchConfigList) > 0 || len(c.EchServerKeys) > 0 { err := ApplyECH(c, config) if err != nil { - if c.EchForceQuery == "full" { - errors.LogError(context.Background(), err) - } else { - errors.LogInfo(context.Background(), err) - } + errors.LogError(context.Background(), err) } } diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index 57cd7866..45928226 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -81,6 +81,7 @@ message Config { string ech_config_list = 19; + // Deprecated string ech_force_query = 20; SocketConfig ech_socket_settings = 21; diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 8cfb1251..1a691f97 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -17,7 +17,6 @@ import ( utls "github.com/refraction-networking/utls" "github.com/xtls/xray-core/common/crypto" - dns2 "github.com/xtls/xray-core/features/dns" "golang.org/x/net/http2" "github.com/miekg/dns" @@ -49,20 +48,10 @@ func ApplyECH(c *Config, config *tls.Config) error { // for client if len(c.EchConfigList) != 0 { - ECHForceQuery := c.EchForceQuery - switch ECHForceQuery { - case "none", "half", "full": - case "": - ECHForceQuery = "full" // default to full - default: - panic("Invalid ECHForceQuery: " + c.EchForceQuery) - } defer func() { // if failed to get ECHConfig, use an invalid one to make connection fail - if err != nil || len(ECHConfig) == 0 { - if ECHForceQuery == "full" { - ECHConfig = []byte{1, 1, 4, 5, 1, 4} - } + if len(ECHConfig) == 0 { + ECHConfig = []byte{1, 1, 4, 5, 1, 4} } config.EncryptedClientHelloConfigList = ECHConfig }() @@ -83,7 +72,7 @@ func ApplyECH(c *Config, config *tls.Config) error { if nameToQuery == "" { return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query") } - ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery, c.EchSocketSettings) + ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchSocketSettings) if err != nil { return errors.New("Failed to query ECH DNS record for domain: ", nameToQuery, " at server: ", DNSServer).Base(err) } @@ -107,7 +96,6 @@ type ECHConfigCache struct { type echConfigRecord struct { config []byte expire time.Time - err error } var ( @@ -125,39 +113,34 @@ func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string { // Update updates the ECH config for given domain and server. // this method is concurrent safe, only one update request will be sent, others get the cache. // if isLockedUpdate is true, it will not try to acquire the lock. -func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) { +func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, sockopt *internet.SocketConfig) ([]byte, error) { if !isLockedUpdate { c.UpdateLock.Lock() defer c.UpdateLock.Unlock() } // Double check cache after acquiring lock configRecord := c.configRecord.Load() - if configRecord.expire.After(time.Now()) && configRecord.err == nil { + if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) - return configRecord.config, configRecord.err + return configRecord.config, nil } // Query ECH config from DNS server errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) echConfig, ttl, err := dnsQuery(server, domain, sockopt) - // if in "full", directly return - if err != nil && forceQuery == "full" { + if err != nil { return nil, err } - if ttl == 0 { - ttl = dns2.DefaultTTL - } configRecord = &echConfigRecord{ config: echConfig, expire: time.Now().Add(time.Duration(ttl) * time.Second), - err: err, } c.configRecord.Store(configRecord) - return configRecord.config, configRecord.err + return configRecord.config, nil } // QueryRecord returns the ECH config for given domain. // If the record is not in cache or expired, it will query the DNS server and update the cache. -func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) { +func QueryRecord(domain string, server string, sockopt *internet.SocketConfig) ([]byte, error) { GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt) echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey) if !ok { @@ -166,25 +149,25 @@ func QueryRecord(domain string, server string, forceQuery string, sockopt *inter echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache) } configRecord := echConfigCache.configRecord.Load() - if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") { + if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) - return configRecord.config, configRecord.err + return configRecord.config, nil } // If expire is zero value, it means we are in initial state, wait for the query to finish // otherwise return old value immediately and update in a goroutine // but if the cache is too old, wait for update if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) { - return echConfigCache.Update(domain, server, false, forceQuery, sockopt) + return echConfigCache.Update(domain, server, false, sockopt) } else { // If someone already acquired the lock, it means it is updating, do not start another update goroutine if echConfigCache.UpdateLock.TryLock() { go func() { defer echConfigCache.UpdateLock.Unlock() - echConfigCache.Update(domain, server, true, forceQuery, sockopt) + echConfigCache.Update(domain, server, true, sockopt) }() } - return configRecord.config, configRecord.err + return configRecord.config, nil } } @@ -322,8 +305,7 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b } } } - // empty is valid, means no ECH config found - return nil, dns2.DefaultTTL, nil + return nil, 0, errors.New("no valid ECH config found in DNS response") } var ErrInvalidLen = errors.New("goech: invalid length") diff --git a/transport/internet/tls/ech_test.go b/transport/internet/tls/ech_test.go index bdf87868..69db965b 100644 --- a/transport/internet/tls/ech_test.go +++ b/transport/internet/tls/ech_test.go @@ -3,6 +3,7 @@ package tls import ( "io" "net/http" + "slices" "strings" "sync" "testing" @@ -59,21 +60,11 @@ func TestECHDial(t *testing.T) { func TestECHDialFail(t *testing.T) { config := &Config{ ServerName: "cloudflare.com", - EchConfigList: "udp://127.0.0.1", - EchForceQuery: "half", + EchConfigList: "udp://0.0.0.0", } - config.GetTLSConfig() - // check cache - echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://127.0.0.1", "cloudflare.com", nil)) - if !ok { - t.Error("ECH config cache not found") - } - configRecord := echConfigCache.configRecord.Load() - if configRecord == nil { - t.Error("ECH config record not found in cache") - return - } - if configRecord.err == nil { - t.Error("unexpected nil error in ECH config record") + tlsConfig := config.GetTLSConfig() + ApplyECH(config, tlsConfig) + if !slices.Equal(tlsConfig.EncryptedClientHelloConfigList, []byte{1, 1, 4, 5, 1, 4}) { + t.Error("ECH config should be invalid when query failed", " but got ", tlsConfig.EncryptedClientHelloConfigList) } }