diff --git a/go.mod b/go.mod index 83041dcd..c9dcff8c 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/xtls/xray-core go 1.26 require ( - github.com/apernet/quic-go v0.59.1-0.20260217092621-db4786c77a22 + github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 github.com/cloudflare/circl v1.6.3 github.com/ghodss/yaml v1.0.1-0.20220118164431-d8423dcdf344 github.com/golang/mock v1.7.0-rc.1 diff --git a/go.sum b/go.sum index 81137e64..da6e58fe 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/apernet/quic-go v0.59.1-0.20260217092621-db4786c77a22 h1:00ziBGnLWQEcR9LThDwvxOznJJquJ9bYUdmBFnawLMU= -github.com/apernet/quic-go v0.59.1-0.20260217092621-db4786c77a22/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA= +github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 h1:cbF95uMsQwCwAzH2i8+2lNO2TReoELLuqeeMfyBjFbY= +github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 849f0e38..70259507 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -36,6 +36,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/xicmp" "github.com/xtls/xray-core/transport/internet/httpupgrade" "github.com/xtls/xray-core/transport/internet/hysteria" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/kcp" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/splithttp" @@ -630,6 +631,7 @@ func (c *TLSCertConfig) Build() (*tls.Certificate, error) { type QuicParamsConfig struct { Congestion string `json:"congestion"` Debug bool `json:"debug"` + BbrProfile string `json:"bbrProfile"` BrutalUp Bandwidth `json:"brutalUp"` BrutalDown Bandwidth `json:"brutalDown"` UdpHop UdpHop `json:"udpHop"` @@ -1894,6 +1896,16 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u)) } if c.FinalMask.QuicParams != nil { + profile := strings.ToLower(c.FinalMask.QuicParams.BbrProfile) + switch profile { + case "", string(bbr.ProfileConservative), string(bbr.ProfileStandard), string(bbr.ProfileAggressive): + if profile == "" { + profile = string(bbr.ProfileStandard) + } + default: + return nil, errors.New("unknown bbr profile") + } + up, err := c.FinalMask.QuicParams.BrutalUp.Bps() if err != nil { return nil, err @@ -1965,6 +1977,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { config.QuicParams = &internet.QuicParams{ Congestion: c.FinalMask.QuicParams.Congestion, + BbrProfile: profile, BrutalUp: up, BrutalDown: down, UdpHop: &internet.UdpHop{ diff --git a/transport/internet/config.pb.go b/transport/internet/config.pb.go index 7e1da3f6..e2339fe8 100644 --- a/transport/internet/config.pb.go +++ b/transport/internet/config.pb.go @@ -445,17 +445,18 @@ func (x *UdpHop) GetIntervalMax() int64 { type QuicParams struct { state protoimpl.MessageState `protogen:"open.v1"` Congestion string `protobuf:"bytes,1,opt,name=congestion,proto3" json:"congestion,omitempty"` - BrutalUp uint64 `protobuf:"varint,2,opt,name=brutal_up,json=brutalUp,proto3" json:"brutal_up,omitempty"` - BrutalDown uint64 `protobuf:"varint,3,opt,name=brutal_down,json=brutalDown,proto3" json:"brutal_down,omitempty"` - UdpHop *UdpHop `protobuf:"bytes,4,opt,name=udp_hop,json=udpHop,proto3" json:"udp_hop,omitempty"` - InitStreamReceiveWindow uint64 `protobuf:"varint,5,opt,name=init_stream_receive_window,json=initStreamReceiveWindow,proto3" json:"init_stream_receive_window,omitempty"` - MaxStreamReceiveWindow uint64 `protobuf:"varint,6,opt,name=max_stream_receive_window,json=maxStreamReceiveWindow,proto3" json:"max_stream_receive_window,omitempty"` - InitConnReceiveWindow uint64 `protobuf:"varint,7,opt,name=init_conn_receive_window,json=initConnReceiveWindow,proto3" json:"init_conn_receive_window,omitempty"` - MaxConnReceiveWindow uint64 `protobuf:"varint,8,opt,name=max_conn_receive_window,json=maxConnReceiveWindow,proto3" json:"max_conn_receive_window,omitempty"` - MaxIdleTimeout int64 `protobuf:"varint,9,opt,name=max_idle_timeout,json=maxIdleTimeout,proto3" json:"max_idle_timeout,omitempty"` - KeepAlivePeriod int64 `protobuf:"varint,10,opt,name=keep_alive_period,json=keepAlivePeriod,proto3" json:"keep_alive_period,omitempty"` - DisablePathMtuDiscovery bool `protobuf:"varint,11,opt,name=disable_path_mtu_discovery,json=disablePathMtuDiscovery,proto3" json:"disable_path_mtu_discovery,omitempty"` - MaxIncomingStreams int64 `protobuf:"varint,12,opt,name=max_incoming_streams,json=maxIncomingStreams,proto3" json:"max_incoming_streams,omitempty"` + BbrProfile string `protobuf:"bytes,2,opt,name=bbr_profile,json=bbrProfile,proto3" json:"bbr_profile,omitempty"` + BrutalUp uint64 `protobuf:"varint,3,opt,name=brutal_up,json=brutalUp,proto3" json:"brutal_up,omitempty"` + BrutalDown uint64 `protobuf:"varint,4,opt,name=brutal_down,json=brutalDown,proto3" json:"brutal_down,omitempty"` + UdpHop *UdpHop `protobuf:"bytes,5,opt,name=udp_hop,json=udpHop,proto3" json:"udp_hop,omitempty"` + InitStreamReceiveWindow uint64 `protobuf:"varint,6,opt,name=init_stream_receive_window,json=initStreamReceiveWindow,proto3" json:"init_stream_receive_window,omitempty"` + MaxStreamReceiveWindow uint64 `protobuf:"varint,7,opt,name=max_stream_receive_window,json=maxStreamReceiveWindow,proto3" json:"max_stream_receive_window,omitempty"` + InitConnReceiveWindow uint64 `protobuf:"varint,8,opt,name=init_conn_receive_window,json=initConnReceiveWindow,proto3" json:"init_conn_receive_window,omitempty"` + MaxConnReceiveWindow uint64 `protobuf:"varint,9,opt,name=max_conn_receive_window,json=maxConnReceiveWindow,proto3" json:"max_conn_receive_window,omitempty"` + MaxIdleTimeout int64 `protobuf:"varint,10,opt,name=max_idle_timeout,json=maxIdleTimeout,proto3" json:"max_idle_timeout,omitempty"` + KeepAlivePeriod int64 `protobuf:"varint,11,opt,name=keep_alive_period,json=keepAlivePeriod,proto3" json:"keep_alive_period,omitempty"` + DisablePathMtuDiscovery bool `protobuf:"varint,12,opt,name=disable_path_mtu_discovery,json=disablePathMtuDiscovery,proto3" json:"disable_path_mtu_discovery,omitempty"` + MaxIncomingStreams int64 `protobuf:"varint,13,opt,name=max_incoming_streams,json=maxIncomingStreams,proto3" json:"max_incoming_streams,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -497,6 +498,13 @@ func (x *QuicParams) GetCongestion() string { return "" } +func (x *QuicParams) GetBbrProfile() string { + if x != nil { + return x.BbrProfile + } + return "" +} + func (x *QuicParams) GetBrutalUp() uint64 { if x != nil { return x.BrutalUp @@ -1028,25 +1036,27 @@ const file_transport_internet_config_proto_rawDesc = "" + "\x06UdpHop\x12\x14\n" + "\x05ports\x18\x01 \x03(\rR\x05ports\x12!\n" + "\finterval_min\x18\x02 \x01(\x03R\vintervalMin\x12!\n" + - "\finterval_max\x18\x03 \x01(\x03R\vintervalMax\"\xd1\x04\n" + + "\finterval_max\x18\x03 \x01(\x03R\vintervalMax\"\xf2\x04\n" + "\n" + "QuicParams\x12\x1e\n" + "\n" + "congestion\x18\x01 \x01(\tR\n" + - "congestion\x12\x1b\n" + - "\tbrutal_up\x18\x02 \x01(\x04R\bbrutalUp\x12\x1f\n" + - "\vbrutal_down\x18\x03 \x01(\x04R\n" + + "congestion\x12\x1f\n" + + "\vbbr_profile\x18\x02 \x01(\tR\n" + + "bbrProfile\x12\x1b\n" + + "\tbrutal_up\x18\x03 \x01(\x04R\bbrutalUp\x12\x1f\n" + + "\vbrutal_down\x18\x04 \x01(\x04R\n" + "brutalDown\x128\n" + - "\audp_hop\x18\x04 \x01(\v2\x1f.xray.transport.internet.UdpHopR\x06udpHop\x12;\n" + - "\x1ainit_stream_receive_window\x18\x05 \x01(\x04R\x17initStreamReceiveWindow\x129\n" + - "\x19max_stream_receive_window\x18\x06 \x01(\x04R\x16maxStreamReceiveWindow\x127\n" + - "\x18init_conn_receive_window\x18\a \x01(\x04R\x15initConnReceiveWindow\x125\n" + - "\x17max_conn_receive_window\x18\b \x01(\x04R\x14maxConnReceiveWindow\x12(\n" + - "\x10max_idle_timeout\x18\t \x01(\x03R\x0emaxIdleTimeout\x12*\n" + - "\x11keep_alive_period\x18\n" + - " \x01(\x03R\x0fkeepAlivePeriod\x12;\n" + - "\x1adisable_path_mtu_discovery\x18\v \x01(\bR\x17disablePathMtuDiscovery\x120\n" + - "\x14max_incoming_streams\x18\f \x01(\x03R\x12maxIncomingStreams\"Q\n" + + "\audp_hop\x18\x05 \x01(\v2\x1f.xray.transport.internet.UdpHopR\x06udpHop\x12;\n" + + "\x1ainit_stream_receive_window\x18\x06 \x01(\x04R\x17initStreamReceiveWindow\x129\n" + + "\x19max_stream_receive_window\x18\a \x01(\x04R\x16maxStreamReceiveWindow\x127\n" + + "\x18init_conn_receive_window\x18\b \x01(\x04R\x15initConnReceiveWindow\x125\n" + + "\x17max_conn_receive_window\x18\t \x01(\x04R\x14maxConnReceiveWindow\x12(\n" + + "\x10max_idle_timeout\x18\n" + + " \x01(\x03R\x0emaxIdleTimeout\x12*\n" + + "\x11keep_alive_period\x18\v \x01(\x03R\x0fkeepAlivePeriod\x12;\n" + + "\x1adisable_path_mtu_discovery\x18\f \x01(\bR\x17disablePathMtuDiscovery\x120\n" + + "\x14max_incoming_streams\x18\r \x01(\x03R\x12maxIncomingStreams\"Q\n" + "\vProxyConfig\x12\x10\n" + "\x03tag\x18\x01 \x01(\tR\x03tag\x120\n" + "\x13transportLayerProxy\x18\x02 \x01(\bR\x13transportLayerProxy\"\x93\x01\n" + diff --git a/transport/internet/config.proto b/transport/internet/config.proto index 653cb9aa..ad23f047 100644 --- a/transport/internet/config.proto +++ b/transport/internet/config.proto @@ -72,17 +72,18 @@ message UdpHop { message QuicParams { string congestion = 1; - uint64 brutal_up = 2; - uint64 brutal_down = 3; - UdpHop udp_hop = 4; - uint64 init_stream_receive_window = 5; - uint64 max_stream_receive_window = 6; - uint64 init_conn_receive_window = 7; - uint64 max_conn_receive_window = 8; - int64 max_idle_timeout = 9; - int64 keep_alive_period = 10; - bool disable_path_mtu_discovery = 11; - int64 max_incoming_streams = 12; + string bbr_profile = 2; + uint64 brutal_up = 3; + uint64 brutal_down = 4; + UdpHop udp_hop = 5; + uint64 init_stream_receive_window = 6; + uint64 max_stream_receive_window = 7; + uint64 init_conn_receive_window = 8; + uint64 max_conn_receive_window = 9; + int64 max_idle_timeout = 10; + int64 keep_alive_period = 11; + bool disable_path_mtu_discovery = 12; + int64 max_incoming_streams = 13; } message ProxyConfig { diff --git a/transport/internet/hysteria/congestion/bbr/bbr_sender.go b/transport/internet/hysteria/congestion/bbr/bbr_sender.go index e8787f15..bcbf8133 100644 --- a/transport/internet/hysteria/congestion/bbr/bbr_sender.go +++ b/transport/internet/hysteria/congestion/bbr/bbr_sender.go @@ -6,6 +6,7 @@ import ( "net" "os" "strconv" + "strings" "time" "github.com/apernet/quic-go/congestion" @@ -28,16 +29,13 @@ const ( invalidPacketNumber = -1 initialCongestionWindowPackets = 32 + minCongestionWindowPackets = 4 // Constants based on TCP defaults. // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. // Does not inflate the pacing rate. - defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSize) - // The gain used for the STARTUP, equal to 2/ln(2). defaultHighGain = 2.885 - // The newly derived gain for STARTUP, equal to 4 * ln(2) - derivedHighGain = 2.773 // The newly derived CWND gain for STARTUP, 2. derivedHighCWNDGain = 2.0 @@ -66,7 +64,6 @@ const ( // Flag. defaultStartupFullLossCount = 8 quicBbr2DefaultLossThreshold = 0.02 - maxBbrBurstPackets = 10 ) type bbrMode int @@ -97,6 +94,76 @@ const ( bbrRecoveryStateGrowth ) +type Profile string + +const ( + ProfileConservative Profile = "conservative" + ProfileStandard Profile = "standard" + ProfileAggressive Profile = "aggressive" +) + +type profileConfig struct { + highGain float64 + highCwndGain float64 + congestionWindowGainConstant float64 + numStartupRtts int64 + drainToTarget bool + detectOvershooting bool + bytesLostMultiplier uint8 + enableAckAggregationStartup bool + expireAckAggregationStartup bool + enableOverestimateAvoidance bool + reduceExtraAckedOnBandwidthIncrease bool +} + +func ParseProfile(profile string) (Profile, error) { + switch normalized := strings.ToLower(profile); normalized { + case "", string(ProfileStandard): + return ProfileStandard, nil + case string(ProfileConservative): + return ProfileConservative, nil + case string(ProfileAggressive): + return ProfileAggressive, nil + default: + return "", fmt.Errorf("unsupported BBR profile %q", profile) + } +} + +func configForProfile(profile Profile) profileConfig { + switch profile { + case ProfileConservative: + return profileConfig{ + highGain: 2.25, + highCwndGain: 1.75, + congestionWindowGainConstant: 1.75, + numStartupRtts: 2, + drainToTarget: true, + detectOvershooting: true, + bytesLostMultiplier: 1, + enableOverestimateAvoidance: true, + reduceExtraAckedOnBandwidthIncrease: true, + } + case ProfileAggressive: + return profileConfig{ + highGain: 3.0, + highCwndGain: 2.25, + congestionWindowGainConstant: 2.5, + numStartupRtts: 4, + bytesLostMultiplier: 2, + enableAckAggregationStartup: true, + expireAckAggregationStartup: true, + } + default: + return profileConfig{ + highGain: defaultHighGain, + highCwndGain: derivedHighCWNDGain, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + bytesLostMultiplier: 2, + } + } +} + type bbrSender struct { rttStats congestion.RTTStatsProvider clock Clock @@ -145,6 +212,9 @@ type bbrSender struct { // The smallest value the |congestion_window_| can achieve. minCongestionWindow congestion.ByteCount + // The BBR profile used by the sender. + profile Profile + // The pacing gain applied during the STARTUP phase. highGain float64 @@ -251,12 +321,14 @@ var _ congestion.CongestionControl = &bbrSender{} func NewBbrSender( clock Clock, initialMaxDatagramSize congestion.ByteCount, + profile Profile, ) *bbrSender { return newBbrSender( clock, initialMaxDatagramSize, initialCongestionWindowPackets*initialMaxDatagramSize, congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, + profile, ) } @@ -265,6 +337,7 @@ func newBbrSender( initialMaxDatagramSize, initialCongestionWindow, initialMaxCongestionWindow congestion.ByteCount, + profile Profile, ) *bbrSender { debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) b := &bbrSender{ @@ -277,9 +350,10 @@ func newBbrSender( congestionWindow: initialCongestionWindow, initialCongestionWindow: initialCongestionWindow, maxCongestionWindow: initialMaxCongestionWindow, - minCongestionWindow: defaultMinimumCongestionWindow, + minCongestionWindow: minCongestionWindowForMaxDatagramSize(initialMaxDatagramSize), + profile: ProfileStandard, highGain: defaultHighGain, - highCwndGain: defaultHighGain, + highCwndGain: derivedHighCWNDGain, drainGain: 1.0 / defaultHighGain, pacingGain: 1.0, congestionWindowGain: 1.0, @@ -295,20 +369,63 @@ func newBbrSender( debug: debug, } b.pacer = common.NewPacer(b.bandwidthForPacer) - - /* - if b.tracer != nil { - b.lastState = logging.CongestionStateStartup - b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) - } - */ + b.applyProfile(profile) + if b.debug { + b.debugPrint("Profile: %s", b.profile) + } b.enterStartupMode(b.clock.Now()) - b.setHighCwndGain(derivedHighCWNDGain) return b } +func (b *bbrSender) applyProfile(profile Profile) { + if profile == "" { + profile = ProfileStandard + } + cfg := configForProfile(profile) + b.profile = profile + b.highGain = cfg.highGain + b.highCwndGain = cfg.highCwndGain + b.drainGain = 1.0 / cfg.highGain + b.congestionWindowGainConstant = cfg.congestionWindowGainConstant + b.numStartupRtts = cfg.numStartupRtts + b.drainToTarget = cfg.drainToTarget + b.detectOvershooting = cfg.detectOvershooting + b.bytesLostMultiplierWhileDetectingOvershooting = cfg.bytesLostMultiplier + b.enableAckAggregationDuringStartup = cfg.enableAckAggregationStartup + b.expireAckAggregationInStartup = cfg.expireAckAggregationStartup + if cfg.enableOverestimateAvoidance { + b.sampler.EnableOverestimateAvoidance() + } + b.sampler.SetReduceExtraAckedOnBandwidthIncrease(cfg.reduceExtraAckedOnBandwidthIncrease) +} + +func minCongestionWindowForMaxDatagramSize(maxDatagramSize congestion.ByteCount) congestion.ByteCount { + return minCongestionWindowPackets * maxDatagramSize +} + +func scaleByteWindowForDatagramSize(window, oldMaxDatagramSize, newMaxDatagramSize congestion.ByteCount) congestion.ByteCount { + if oldMaxDatagramSize == newMaxDatagramSize { + return window + } + return congestion.ByteCount(uint64(window) * uint64(newMaxDatagramSize) / uint64(oldMaxDatagramSize)) +} + +func (b *bbrSender) rescalePacketSizedWindows(maxDatagramSize congestion.ByteCount) { + oldMaxDatagramSize := b.maxDatagramSize + b.maxDatagramSize = maxDatagramSize + b.initialCongestionWindow = scaleByteWindowForDatagramSize(b.initialCongestionWindow, oldMaxDatagramSize, maxDatagramSize) + b.maxCongestionWindow = scaleByteWindowForDatagramSize(b.maxCongestionWindow, oldMaxDatagramSize, maxDatagramSize) + b.minCongestionWindow = minCongestionWindowForMaxDatagramSize(maxDatagramSize) + b.cwndToCalculateMinPacingRate = scaleByteWindowForDatagramSize(b.cwndToCalculateMinPacingRate, oldMaxDatagramSize, maxDatagramSize) + b.maxCongestionWindowWithNetworkParametersAdjusted = scaleByteWindowForDatagramSize( + b.maxCongestionWindowWithNetworkParametersAdjusted, + oldMaxDatagramSize, + maxDatagramSize, + ) +} + func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { b.rttStats = provider } @@ -370,14 +487,24 @@ func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { // SetMaxDatagramSize implements the SendAlgorithm interface. func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { + if b.debug { + b.debugPrint("Max Datagram Size: %d", s) + } if s < b.maxDatagramSize { panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) } - cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow - b.maxDatagramSize = s - if cwndIsMinCwnd { + oldMinCongestionWindow := b.minCongestionWindow + oldInitialCongestionWindow := b.initialCongestionWindow + b.rescalePacketSizedWindows(s) + switch b.congestionWindow { + case oldMinCongestionWindow: b.congestionWindow = b.minCongestionWindow + case oldInitialCongestionWindow: + b.congestionWindow = b.initialCongestionWindow + default: + b.congestionWindow = min(b.maxCongestionWindow, max(b.congestionWindow, b.minCongestionWindow)) } + b.recoveryWindow = min(b.maxCongestionWindow, max(b.recoveryWindow, b.minCongestionWindow)) b.pacer.SetMaxDatagramSize(s) } @@ -519,22 +646,6 @@ func (b *bbrSender) PacingRate() Bandwidth { return b.pacingRate } -func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool { - return b.hasNonAppLimitedSample() -} - -func (b *bbrSender) hasNonAppLimitedSample() bool { - return b.hasNoAppLimitedSample -} - -// Sets the pacing gain used in STARTUP. Must be greater than 1. -func (b *bbrSender) setHighGain(highGain float64) { - b.highGain = highGain - if b.mode == bbrModeStartup { - b.pacingGain = highGain - } -} - // Sets the CWND gain used in STARTUP. Must be greater than 1. func (b *bbrSender) setHighCwndGain(highCwndGain float64) { b.highCwndGain = highCwndGain @@ -543,11 +654,6 @@ func (b *bbrSender) setHighCwndGain(highCwndGain float64) { } } -// Sets the gain used in DRAIN. Must be less than 1. -func (b *bbrSender) setDrainGain(drainGain float64) { - b.drainGain = drainGain -} - // Get the current bandwidth estimate. Note that Bandwidth is in bits per second. func (b *bbrSender) bandwidthEstimate() Bandwidth { return b.maxBandwidth.GetBest() diff --git a/transport/internet/hysteria/congestion/bbr/bbr_sender_test.go b/transport/internet/hysteria/congestion/bbr/bbr_sender_test.go new file mode 100644 index 00000000..5aff552f --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/bbr_sender_test.go @@ -0,0 +1,130 @@ +package bbr + +import ( + "testing" + + "github.com/apernet/quic-go/congestion" + "github.com/stretchr/testify/require" +) + +func TestSetMaxDatagramSizeRescalesPacketSizedWindows(t *testing.T) { + const oldMaxDatagramSize = congestion.ByteCount(1000) + const newMaxDatagramSize = congestion.ByteCount(1400) + const initialCongestionWindowPackets = congestion.ByteCount(20) + const maxCongestionWindowPackets = congestion.ByteCount(80) + + b := newBbrSender( + DefaultClock{}, + oldMaxDatagramSize, + initialCongestionWindowPackets*oldMaxDatagramSize, + maxCongestionWindowPackets*oldMaxDatagramSize, + ProfileStandard, + ) + b.congestionWindow = b.initialCongestionWindow + + b.SetMaxDatagramSize(newMaxDatagramSize) + + require.Equal(t, initialCongestionWindowPackets*newMaxDatagramSize, b.initialCongestionWindow) + require.Equal(t, maxCongestionWindowPackets*newMaxDatagramSize, b.maxCongestionWindow) + require.Equal(t, minCongestionWindowPackets*newMaxDatagramSize, b.minCongestionWindow) + require.Equal(t, initialCongestionWindowPackets*newMaxDatagramSize, b.congestionWindow) +} + +func TestSetMaxDatagramSizeClampsCongestionWindow(t *testing.T) { + const oldMaxDatagramSize = congestion.ByteCount(1000) + const newMaxDatagramSize = congestion.ByteCount(1400) + + b := NewBbrSender(DefaultClock{}, oldMaxDatagramSize, ProfileStandard) + b.congestionWindow = b.minCongestionWindow + oldMaxDatagramSize + b.recoveryWindow = b.minCongestionWindow + oldMaxDatagramSize + + b.SetMaxDatagramSize(newMaxDatagramSize) + + require.Equal(t, b.minCongestionWindow, b.congestionWindow) + require.Equal(t, b.minCongestionWindow, b.recoveryWindow) +} + +func TestNewBbrSenderAppliesProfiles(t *testing.T) { + testCases := []struct { + name string + profile Profile + highGain float64 + highCwndGain float64 + congestionWindowGainConstant float64 + numStartupRtts int64 + drainToTarget bool + detectOvershooting bool + bytesLostMultiplier uint8 + enableAckAggregationDuringStartup bool + expireAckAggregationInStartup bool + enableOverestimateAvoidance bool + reduceExtraAckedOnBandwidthIncrease bool + }{ + { + name: "standard", + profile: ProfileStandard, + highGain: defaultHighGain, + highCwndGain: derivedHighCWNDGain, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + bytesLostMultiplier: 2, + }, + { + name: "conservative", + profile: ProfileConservative, + highGain: 2.25, + highCwndGain: 1.75, + congestionWindowGainConstant: 1.75, + numStartupRtts: 2, + drainToTarget: true, + detectOvershooting: true, + bytesLostMultiplier: 1, + enableOverestimateAvoidance: true, + reduceExtraAckedOnBandwidthIncrease: true, + }, + { + name: "aggressive", + profile: ProfileAggressive, + highGain: 3.0, + highCwndGain: 2.25, + congestionWindowGainConstant: 2.5, + numStartupRtts: 4, + bytesLostMultiplier: 2, + enableAckAggregationDuringStartup: true, + expireAckAggregationInStartup: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b := NewBbrSender(DefaultClock{}, congestion.InitialPacketSize, tc.profile) + require.Equal(t, tc.profile, b.profile) + require.Equal(t, tc.highGain, b.highGain) + require.Equal(t, tc.highCwndGain, b.highCwndGain) + require.Equal(t, tc.congestionWindowGainConstant, b.congestionWindowGainConstant) + require.Equal(t, tc.numStartupRtts, b.numStartupRtts) + require.Equal(t, tc.drainToTarget, b.drainToTarget) + require.Equal(t, tc.detectOvershooting, b.detectOvershooting) + require.Equal(t, tc.bytesLostMultiplier, b.bytesLostMultiplierWhileDetectingOvershooting) + require.Equal(t, tc.enableAckAggregationDuringStartup, b.enableAckAggregationDuringStartup) + require.Equal(t, tc.expireAckAggregationInStartup, b.expireAckAggregationInStartup) + require.Equal(t, tc.enableOverestimateAvoidance, b.sampler.IsOverestimateAvoidanceEnabled()) + require.Equal(t, tc.reduceExtraAckedOnBandwidthIncrease, b.sampler.maxAckHeightTracker.reduceExtraAckedOnBandwidthIncrease) + require.Equal(t, b.highGain, b.pacingGain) + require.Equal(t, b.highCwndGain, b.congestionWindowGain) + }) + } +} + +func TestParseProfile(t *testing.T) { + profile, err := ParseProfile("") + require.NoError(t, err) + require.Equal(t, ProfileStandard, profile) + + profile, err = ParseProfile("Aggressive") + require.NoError(t, err) + require.Equal(t, ProfileAggressive, profile) + + _, err = ParseProfile("turbo") + require.EqualError(t, err, `unsupported BBR profile "turbo"`) +} diff --git a/transport/internet/hysteria/congestion/utils.go b/transport/internet/hysteria/congestion/utils.go index 1036760e..0f04318d 100644 --- a/transport/internet/hysteria/congestion/utils.go +++ b/transport/internet/hysteria/congestion/utils.go @@ -1,18 +1,55 @@ package congestion import ( + "fmt" + "strings" + "github.com/apernet/quic-go" "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/hysteria/congestion/brutal" ) -func UseBBR(conn *quic.Conn) { +const ( + TypeBBR = "bbr" + TypeReno = "reno" +) + +func NormalizeType(congestionType string) (string, error) { + switch normalized := strings.ToLower(congestionType); normalized { + case "", TypeBBR: + return TypeBBR, nil + case TypeReno: + return TypeReno, nil + default: + return "", fmt.Errorf("unsupported congestion type %q", congestionType) + } +} + +func NormalizeBBRProfile(profile string) (string, error) { + normalized, err := bbr.ParseProfile(profile) + if err != nil { + return "", err + } + return string(normalized), nil +} + +func UseBBR(conn *quic.Conn, profile bbr.Profile) { conn.SetCongestionControl(bbr.NewBbrSender( bbr.DefaultClock{}, bbr.GetInitialPacketSize(conn.RemoteAddr()), + profile, )) } func UseBrutal(conn *quic.Conn, tx uint64) { conn.SetCongestionControl(brutal.NewBrutalSender(tx)) } + +func UseConfigured(conn *quic.Conn, congestionType, bbrProfile string) { + switch congestionType { + case TypeReno: + return + default: + UseBBR(conn, bbr.Profile(bbrProfile)) + } +} diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go index d305b00d..b4ce8e4b 100644 --- a/transport/internet/hysteria/dialer.go +++ b/transport/internet/hysteria/dialer.go @@ -23,6 +23,7 @@ import ( "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/hysteria/udphop" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" @@ -157,10 +158,10 @@ func (c *client) dial() error { quicParams := c.quicParams if quicParams == nil { - quicParams = &internet.QuicParams{} - } - if quicParams.UdpHop == nil { - quicParams.UdpHop = &internet.UdpHop{} + quicParams = &internet.QuicParams{ + BbrProfile: string(bbr.ProfileStandard), + UdpHop: &internet.UdpHop{}, + } } var index int @@ -298,12 +299,12 @@ func (c *client) dial() error { case "reno": errors.LogDebug(c.ctx, "congestion reno") case "bbr": - errors.LogDebug(c.ctx, "congestion bbr") - congestion.UseBBR(quicConn) + errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile) + congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) case "brutal", "": if serverAuto == "auto" || quicParams.BrutalUp == 0 || serverDown == 0 { - errors.LogDebug(c.ctx, "congestion bbr") - congestion.UseBBR(quicConn) + errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile) + congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) } else { errors.LogDebug(c.ctx, "congestion brutal bytes per second ", min(quicParams.BrutalUp, serverDown)) congestion.UseBrutal(quicConn, min(quicParams.BrutalUp, serverDown)) diff --git a/transport/internet/hysteria/hub.go b/transport/internet/hysteria/hub.go index 992680d7..c7a685a1 100644 --- a/transport/internet/hysteria/hub.go +++ b/transport/internet/hysteria/hub.go @@ -23,6 +23,7 @@ import ( hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/tls" ) @@ -188,12 +189,12 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "reno": errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") case "bbr": - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr") - congestion.UseBBR(h.conn) + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile) + congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile)) case "brutal", "": if h.quicParams.BrutalUp == 0 || clientDown == 0 { - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr") - congestion.UseBBR(h.conn) + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile) + congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile)) } else { errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", min(h.quicParams.BrutalUp, clientDown)) congestion.UseBrutal(h.conn, min(h.quicParams.BrutalUp, clientDown)) @@ -389,7 +390,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti quicParams := streamSettings.QuicParams if quicParams == nil { - quicParams = &internet.QuicParams{} + quicParams = &internet.QuicParams{ + BbrProfile: string(bbr.ProfileStandard), + UdpHop: &internet.UdpHop{}, + } } quicConfig := &quic.Config{ diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 0c351a5a..6f4ec1d8 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -26,6 +26,7 @@ import ( "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/browser_dialer" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/hysteria/udphop" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" @@ -158,10 +159,10 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea if httpVersion == "3" { quicParams := streamSettings.QuicParams if quicParams == nil { - quicParams = &internet.QuicParams{} - } - if quicParams.UdpHop == nil { - quicParams.UdpHop = &internet.UdpHop{} + quicParams = &internet.QuicParams{ + BbrProfile: string(bbr.ProfileStandard), + UdpHop: &internet.UdpHop{}, + } } quicConfig := &quic.Config{ @@ -292,8 +293,8 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea case "reno": errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion reno") default: - errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion bbr") - congestion.UseBBR(quicConn) + errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) + congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) } return quicConn, nil diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 1ffdf6f2..8b281457 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -25,6 +25,7 @@ import ( "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" @@ -496,7 +497,10 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet quicParams := streamSettings.QuicParams if quicParams == nil { - quicParams = &internet.QuicParams{} + quicParams = &internet.QuicParams{ + BbrProfile: string(bbr.ProfileStandard), + UdpHop: &internet.UdpHop{}, + } } quicConfig := &quic.Config{ @@ -535,8 +539,8 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet case "reno": errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion reno") default: - errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion bbr") - congestion.UseBBR(conn) + errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) } go func() {