From 1d62941bd2a0f1ca410c7b81c1aeb198dd105825 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Sat, 2 May 2026 20:27:27 +0800 Subject: [PATCH] Hysteria: Upgrade to official v2.8.2 (#6041) https://github.com/XTLS/Xray-core/pull/6041#issuecomment-4357417742 And fixes https://github.com/XTLS/Xray-core/issues/6039 --- app/proxyman/inbound/worker.go | 10 +- go.mod | 2 +- go.sum | 4 +- proxy/hysteria/client.go | 98 ++--- proxy/hysteria/config.go | 9 - proxy/hysteria/ctx/ctx.go | 35 -- proxy/hysteria/frag.go | 73 ---- proxy/hysteria/protocol.go | 79 +++- proxy/hysteria/server.go | 61 +-- transport/internet/hysteria/config.go | 91 ++-- transport/internet/hysteria/conn.go | 293 +++++++++---- transport/internet/hysteria/dialer.go | 398 ++++++------------ transport/internet/hysteria/hub.go | 303 ++++--------- .../internet/hysteria/padding/padding.go | 24 -- transport/internet/hysteria/udphop/addr.go | 65 --- transport/internet/hysteria/udphop/conn.go | 192 ++++----- transport/internet/kcp/dialer.go | 46 +- transport/internet/splithttp/dialer.go | 112 ++--- transport/internet/splithttp/hub.go | 71 ++-- transport/internet/udp/dialer.go | 50 +-- 20 files changed, 845 insertions(+), 1171 deletions(-) delete mode 100644 proxy/hysteria/ctx/ctx.go delete mode 100644 proxy/hysteria/frag.go delete mode 100644 transport/internet/hysteria/padding/padding.go delete mode 100644 transport/internet/hysteria/udphop/addr.go diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 1a7b769a..b91c9ec6 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -18,9 +18,9 @@ import ( "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy" - "github.com/xtls/xray-core/proxy/hysteria/account" - hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" + hysteria_proxy "github.com/xtls/xray-core/proxy/hysteria" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tcp" "github.com/xtls/xray-core/transport/internet/udp" @@ -134,10 +134,8 @@ func (w *tcpWorker) Proxy() proxy.Inbound { func (w *tcpWorker) Start() error { ctx := context.Background() - type HysteriaInboundValidator interface{ HysteriaInboundValidator() *account.Validator } - if v, ok := w.proxy.(HysteriaInboundValidator); ok { - ctx = hyCtx.ContextWithRequireDatagram(ctx, true) - ctx = hyCtx.ContextWithValidator(ctx, v.HysteriaInboundValidator()) + if v, ok := w.proxy.(*hysteria_proxy.Server); ok { + ctx = hysteria.ContextWithValidator(ctx, v.HysteriaInboundValidator()) } hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { diff --git a/go.mod b/go.mod index 107dc09c..2177477a 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/xtls/xray-core go 1.26 require ( - github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 + github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716 github.com/cloudflare/circl v1.6.3 github.com/ghodss/yaml v1.0.1-0.20220118164431-d8423dcdf344 github.com/golang/mock v1.7.0-rc.1 diff --git a/go.sum b/go.sum index 31840072..2ce16899 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 h1:cbF95uMsQwCwAzH2i8+2lNO2TReoELLuqeeMfyBjFbY= -github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA= +github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716 h1:J1O+xpLuJWkdYbw5JPGwBqIHs2J8tiEP7Py9lPqkN2I= +github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= diff --git a/proxy/hysteria/client.go b/proxy/hysteria/client.go index 614a7ff9..7d5730b6 100644 --- a/proxy/hysteria/client.go +++ b/proxy/hysteria/client.go @@ -17,7 +17,6 @@ import ( "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" - hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria" @@ -56,7 +55,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter ob.CanSpliceCopy = 3 target := ob.Target - conn, err := dialer.Dial(hyCtx.ContextWithRequireDatagram(ctx, target.Network == net.Network_UDP), c.server.Destination) + conn, err := dialer.Dial(hysteria.ContextWithDatagram(ctx, target.Network == net.Network_UDP), c.server.Destination) if err != nil { return errors.New("failed to find an available destination").AtWarning().Base(err) } @@ -118,7 +117,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if target.Network == net.Network_UDP { iConn := stat.TryUnwrapStatsConn(conn) - _, ok := iConn.(*hysteria.InterUdpConn) + _, ok := iConn.(*hysteria.InterConn) if !ok { return errors.New("udp requires hysteria udp transport") } @@ -127,8 +126,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) writer := &UDPWriter{ - Writer: conn, - buf: make([]byte, MaxUDPSize), + writer: conn, addr: target.NetAddr(), } @@ -143,8 +141,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) reader := &UDPReader{ - Reader: conn, - buf: make([]byte, MaxUDPSize), + reader: conn, df: &Defragger{}, } @@ -173,28 +170,22 @@ func init() { } type UDPWriter struct { - Writer io.Writer - buf []byte + writer io.Writer addr string + buf [buf.Size]byte } -func (w *UDPWriter) sendMsg(msg *UDPMessage) error { - msgN := msg.Serialize(w.buf) +func (w *UDPWriter) SendMessage(msg *UDPMessage) error { + msgN := msg.Serialize(w.buf[:]) if msgN < 0 { return nil } - _, err := w.Writer.Write(w.buf[:msgN]) + _, err := w.writer.Write(w.buf[:msgN]) return err } func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { - for { - mb2, b := buf.SplitFirst(mb) - mb = mb2 - if b == nil { - break - } - + for i, b := range mb { addr := w.addr if b.UDP != nil { addr = b.UDP.NetAddr() @@ -209,22 +200,20 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { Data: b.Bytes(), } - err := w.sendMsg(msg) + err := w.SendMessage(msg) var errTooLarge *quic.DatagramTooLargeError if go_errors.As(err, &errTooLarge) { msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize)) for _, fMsg := range fMsgs { - err := w.sendMsg(&fMsg) + err := w.SendMessage(&fMsg) if err != nil { - b.Release() - buf.ReleaseMulti(mb) + buf.ReleaseMulti(mb[i:]) return err } } } else if err != nil { - b.Release() - buf.ReleaseMulti(mb) + buf.ReleaseMulti(mb[i:]) return err } @@ -235,34 +224,21 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { } type UDPReader struct { - Reader io.Reader - buf []byte - df *Defragger - firstMsg *UDPMessage - firstDest *net.Destination + reader io.Reader + df *Defragger + firstBuf *buf.Buffer } -func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { - if r.firstMsg != nil { - buffer := buf.New() - _, err := buffer.Write(r.firstMsg.Data) - if err != nil { - return nil, err - } - buffer.UDP = r.firstDest - - r.firstMsg = nil - r.firstDest = nil - - return buf.MultiBuffer{buffer}, nil - } +func (r *UDPReader) ReadFrom(p []byte) (n int, addr *net.Destination, err error) { for { - n, err := r.Reader.Read(r.buf) + var buf [hysteria.MaxDatagramFrameSize]byte + + n, err := r.reader.Read(buf[:]) if err != nil { - return nil, err + return 0, nil, err } - msg, err := ParseUDPMessage(r.buf[:n]) + msg, err := ParseUDPMessage(buf[:n]) if err != nil { continue } @@ -274,17 +250,31 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { dest, err := net.ParseDestination("udp:" + dfMsg.Addr) if err != nil { - errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err) continue } - buffer := buf.New() - if _, err := buffer.Write(dfMsg.Data); err != nil { - return nil, err + if len(p) < len(dfMsg.Data) { + continue } - buffer.UDP = &dest - - return buf.MultiBuffer{buffer}, nil + return copy(p, dfMsg.Data), &dest, nil } } + +func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + if r.firstBuf != nil { + mb := buf.MultiBuffer{r.firstBuf} + r.firstBuf = nil + return mb, nil + } + b := buf.New() + b.Resize(0, buf.Size) + n, addr, err := r.ReadFrom(b.Bytes()) + if err != nil { + b.Release() + return nil, err + } + b.Resize(0, int32(n)) + b.UDP = addr + return buf.MultiBuffer{b}, nil +} diff --git a/proxy/hysteria/config.go b/proxy/hysteria/config.go index 1daedf03..151f93c3 100644 --- a/proxy/hysteria/config.go +++ b/proxy/hysteria/config.go @@ -1,10 +1 @@ package hysteria - -import ( - "github.com/xtls/xray-core/transport/internet/hysteria/padding" -) - -var ( - tcpRequestPadding = padding.Padding{Min: 64, Max: 512} - tcpResponsePadding = padding.Padding{Min: 128, Max: 1024} -) diff --git a/proxy/hysteria/ctx/ctx.go b/proxy/hysteria/ctx/ctx.go deleted file mode 100644 index 4e1b290c..00000000 --- a/proxy/hysteria/ctx/ctx.go +++ /dev/null @@ -1,35 +0,0 @@ -package ctx - -import ( - "context" - - "github.com/xtls/xray-core/proxy/hysteria/account" -) - -type key int - -const ( - requireDatagram key = iota - validator -) - -func ContextWithRequireDatagram(ctx context.Context, udp bool) context.Context { - if !udp { - return ctx - } - return context.WithValue(ctx, requireDatagram, struct{}{}) -} - -func RequireDatagramFromContext(ctx context.Context) bool { - _, ok := ctx.Value(requireDatagram).(struct{}) - return ok -} - -func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context { - return context.WithValue(ctx, validator, v) -} - -func ValidatorFromContext(ctx context.Context) *account.Validator { - v, _ := ctx.Value(validator).(*account.Validator) - return v -} diff --git a/proxy/hysteria/frag.go b/proxy/hysteria/frag.go deleted file mode 100644 index 64a6b0e1..00000000 --- a/proxy/hysteria/frag.go +++ /dev/null @@ -1,73 +0,0 @@ -package hysteria - -func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage { - if m.Size() <= maxSize { - return []UDPMessage{*m} - } - fullPayload := m.Data - maxPayloadSize := maxSize - m.HeaderSize() - off := 0 - fragID := uint8(0) - fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up - frags := make([]UDPMessage, fragCount) - for off < len(fullPayload) { - payloadSize := len(fullPayload) - off - if payloadSize > maxPayloadSize { - payloadSize = maxPayloadSize - } - frag := *m - frag.FragID = fragID - frag.FragCount = fragCount - frag.Data = fullPayload[off : off+payloadSize] - frags[fragID] = frag - off += payloadSize - fragID++ - } - return frags -} - -// Defragger handles the defragmentation of UDP messages. -// The current implementation can only handle one packet ID at a time. -// If another packet arrives before a packet has received all fragments -// in their entirety, any previous state is discarded. -type Defragger struct { - pktID uint16 - frags []*UDPMessage - count uint8 - size int // data size -} - -func (d *Defragger) Feed(m *UDPMessage) *UDPMessage { - if m.FragCount <= 1 { - return m - } - if m.FragID >= m.FragCount { - // wtf is this? - return nil - } - if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { - // new message, clear previous state - d.pktID = m.PacketID - d.frags = make([]*UDPMessage, m.FragCount) - d.frags[m.FragID] = m - d.count = 1 - d.size = len(m.Data) - } else if d.frags[m.FragID] == nil { - d.frags[m.FragID] = m - d.count++ - d.size += len(m.Data) - if int(d.count) == len(d.frags) { - // all fragments received, assemble - data := make([]byte, d.size) - off := 0 - for _, frag := range d.frags { - off += copy(data[off:], frag.Data) - } - m.Data = data - m.FragID = 0 - m.FragCount = 1 - return m - } - } - return nil -} diff --git a/proxy/hysteria/protocol.go b/proxy/hysteria/protocol.go index b838d15a..5434bd00 100644 --- a/proxy/hysteria/protocol.go +++ b/proxy/hysteria/protocol.go @@ -8,6 +8,7 @@ import ( "github.com/apernet/quic-go/quicvarint" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/hysteria" ) const ( @@ -17,8 +18,6 @@ const ( MaxMessageLength = 2048 MaxPaddingLength = 4096 - MaxUDPSize = 4096 - maxVarInt1 = 63 maxVarInt2 = 16383 maxVarInt4 = 1073741823 @@ -62,7 +61,7 @@ func ReadTCPRequest(r io.Reader) (string, error) { } func WriteTCPRequest(w io.Writer, addr string) error { - padding := tcpRequestPadding.String() + padding := hysteria.TcpRequestPadding.String() paddingLen := len(padding) addrLen := len(addr) sz := int(quicvarint.Len(uint64(addrLen))) + addrLen + @@ -122,7 +121,7 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) { } func WriteTCPResponse(w io.Writer, ok bool, msg string) error { - padding := tcpResponsePadding.String() + padding := hysteria.TcpResponsePadding.String() paddingLen := len(padding) msgLen := len(msg) sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + @@ -247,3 +246,75 @@ func varintPut(b []byte, i uint64) int { } panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) } + +func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage { + if m.Size() <= maxSize { + return []UDPMessage{*m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + frags := make([]UDPMessage, fragCount) + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := *m + frag.FragID = fragID + frag.FragCount = fragCount + frag.Data = fullPayload[off : off+payloadSize] + frags[fragID] = frag + off += payloadSize + fragID++ + } + return frags +} + +// Defragger handles the defragmentation of UDP messages. +// The current implementation can only handle one packet ID at a time. +// If another packet arrives before a packet has received all fragments +// in their entirety, any previous state is discarded. +type Defragger struct { + pktID uint16 + frags []*UDPMessage + count uint8 + size int // data size +} + +func (d *Defragger) Feed(m *UDPMessage) *UDPMessage { + if m.FragCount <= 1 { + return m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { + // new message, clear previous state + d.pktID = m.PacketID + d.frags = make([]*UDPMessage, m.FragCount) + d.frags[m.FragID] = m + d.count = 1 + d.size = len(m.Data) + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = m + d.count++ + d.size += len(m.Data) + if int(d.count) == len(d.frags) { + // all fragments received, assemble + data := make([]byte, d.size) + off := 0 + for _, frag := range d.frags { + off += copy(data[off:], frag.Data) + } + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return m + } + } + return nil +} diff --git a/proxy/hysteria/server.go b/proxy/hysteria/server.go index 3509a44c..815faca1 100644 --- a/proxy/hysteria/server.go +++ b/proxy/hysteria/server.go @@ -2,7 +2,6 @@ package hysteria import ( "context" - "io" "time" "github.com/xtls/xray-core/common" @@ -91,54 +90,30 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound.User = v.User() } - if _, ok := iConn.(*hysteria.InterUdpConn); ok { - r := io.Reader(conn) - b := make([]byte, MaxUDPSize) - df := &Defragger{} - var firstMsg *UDPMessage - var firstDest net.Destination - - for { - n, err := r.Read(b) - if err != nil { - return err - } - - msg, err := ParseUDPMessage(b[:n]) - if err != nil { - continue - } - - dfMsg := df.Feed(msg) - if dfMsg == nil { - continue - } - - firstMsg = dfMsg - firstDest, err = net.ParseDestination("udp:" + firstMsg.Addr) - if err != nil { - errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err) - continue - } - - break - } - + if _, ok := iConn.(*hysteria.InterConn); ok { reader := &UDPReader{ - Reader: r, - buf: b, - df: df, - firstMsg: firstMsg, - firstDest: &firstDest, + reader: conn, + df: &Defragger{}, } + b := buf.New() + b.Resize(0, buf.Size) + n, addr, err := reader.ReadFrom(b.Bytes()) + if err != nil { + b.Release() + return err + } + b.Resize(0, int32(n)) + b.UDP = addr + + reader.firstBuf = b + writer := &UDPWriter{ - Writer: conn, - buf: make([]byte, MaxUDPSize), - addr: firstMsg.Addr, + writer: conn, + addr: addr.NetAddr(), } - return dispatcher.DispatchLink(ctx, firstDest, &transport.Link{ + return dispatcher.DispatchLink(ctx, *addr, &transport.Link{ Reader: reader, Writer: writer, }) diff --git a/transport/internet/hysteria/config.go b/transport/internet/hysteria/config.go index 37bff909..bf87d604 100644 --- a/transport/internet/hysteria/config.go +++ b/transport/internet/hysteria/config.go @@ -1,45 +1,82 @@ package hysteria import ( + "context" + "math/rand" "time" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/proxy/hysteria/account" "github.com/xtls/xray-core/transport/internet" - "github.com/xtls/xray-core/transport/internet/hysteria/padding" ) const ( closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError - - MaxDatagramFrameSize = 1200 - - URLHost = "hysteria" - URLPath = "/auth" - - RequestHeaderAuth = "Hysteria-Auth" - ResponseHeaderUDPEnabled = "Hysteria-UDP" - CommonHeaderCCRX = "Hysteria-CC-RX" - CommonHeaderPadding = "Hysteria-Padding" - - StatusAuthOK = 233 - - udpMessageChanSize = 1024 - - FrameTypeTCPRequest = 0x401 - - idleCleanupInterval = 1 * time.Second + URLHost = "hysteria" + URLPath = "/auth" + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" + StatusAuthOK = 233 + FrameTypeTCPRequest = 0x401 + MaxDatagramFrameSize = 1200 + udpMessageChanSize = 1024 + idleCleanupInterval = 1 * time.Second ) -var ( - authRequestPadding = padding.Padding{Min: 256, Max: 2048} - authResponsePadding = padding.Padding{Min: 256, Max: 2048} -) - -type Status int - const ( - StatusUnknown Status = iota + paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +type padding struct { + Min int + Max int +} + +func (p padding) String() string { + n := p.Min + rand.Intn(p.Max-p.Min) + bs := make([]byte, n) + for i := range bs { + bs[i] = paddingChars[rand.Intn(len(paddingChars))] + } + return string(bs) +} + +var ( + AuthRequestPadding = padding{Min: 256, Max: 2048} + AuthResponsePadding = padding{Min: 256, Max: 2048} + TcpRequestPadding = padding{Min: 64, Max: 512} + TcpResponsePadding = padding{Min: 128, Max: 1024} +) + +type datagramKey struct{} + +func ContextWithDatagram(ctx context.Context, v bool) context.Context { + return context.WithValue(ctx, datagramKey{}, v) +} + +func DatagramFromContext(ctx context.Context) bool { + v, _ := ctx.Value(datagramKey{}).(bool) + return v +} + +type validatorKey struct{} + +func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context { + return context.WithValue(ctx, validatorKey{}, v) +} + +func ValidatorFromContext(ctx context.Context) *account.Validator { + v, _ := ctx.Value(validatorKey{}).(*account.Validator) + return v +} + +type status int + +const ( + StatusNull status = iota StatusActive StatusInactive ) diff --git a/transport/internet/hysteria/conn.go b/transport/internet/hysteria/conn.go index cf0920d8..ce2a4af3 100644 --- a/transport/internet/hysteria/conn.go +++ b/transport/internet/hysteria/conn.go @@ -1,6 +1,7 @@ package hysteria import ( + "context" "encoding/binary" "io" "sync" @@ -8,8 +9,10 @@ import ( "github.com/apernet/quic-go" "github.com/apernet/quic-go/quicvarint" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/transport/internet" ) type interConn struct { @@ -18,144 +21,278 @@ type interConn struct { remote net.Addr client bool - mutex sync.Mutex - - user *protocol.MemoryUser + user *protocol.MemoryUser } -func (i *interConn) User() *protocol.MemoryUser { - return i.user +func (c *interConn) User() *protocol.MemoryUser { + return c.user } -func (i *interConn) Read(b []byte) (int, error) { - return i.stream.Read(b) +func (c *interConn) Read(b []byte) (int, error) { + return c.stream.Read(b) } -func (i *interConn) Write(b []byte) (int, error) { - if i.client { - i.mutex.Lock() - defer i.mutex.Unlock() - if i.client { - buf := make([]byte, 0, quicvarint.Len(FrameTypeTCPRequest)+len(b)) - buf = quicvarint.Append(buf, FrameTypeTCPRequest) - buf = append(buf, b...) - _, err := i.stream.Write(buf) - if err != nil { - return 0, err - } - i.client = false - return len(b), nil +func (c *interConn) Write(b []byte) (int, error) { + if c.client { + c.client = false + if _, err := c.stream.Write(append(quicvarint.Append(nil, FrameTypeTCPRequest), b...)); err != nil { + return 0, err } + return len(b), nil } - return i.stream.Write(b) + return c.stream.Write(b) } -func (i *interConn) Close() error { - i.stream.CancelRead(0) - return i.stream.Close() +func (c *interConn) Close() error { + c.stream.CancelRead(0) + return c.stream.Close() } -func (i *interConn) LocalAddr() net.Addr { - return i.local +func (c *interConn) LocalAddr() net.Addr { + return c.local } -func (i *interConn) RemoteAddr() net.Addr { - return i.remote +func (c *interConn) RemoteAddr() net.Addr { + return c.remote } -func (i *interConn) SetDeadline(t time.Time) error { - return i.stream.SetDeadline(t) +func (c *interConn) SetDeadline(t time.Time) error { + return c.stream.SetDeadline(t) } -func (i *interConn) SetReadDeadline(t time.Time) error { - return i.stream.SetReadDeadline(t) +func (c *interConn) SetReadDeadline(t time.Time) error { + return c.stream.SetReadDeadline(t) } -func (i *interConn) SetWriteDeadline(t time.Time) error { - return i.stream.SetWriteDeadline(t) +func (c *interConn) SetWriteDeadline(t time.Time) error { + return c.stream.SetWriteDeadline(t) } -type InterUdpConn struct { - conn *quic.Conn +type InterConn struct { local net.Addr remote net.Addr - id uint32 - ch chan []byte + id uint32 + ch chan []byte + time time.Time + mutex sync.Mutex + closed bool - closed bool - closeFunc func() - - last time.Time - mutex sync.Mutex - - user *protocol.MemoryUser + write func(p []byte) error + close func() + user *protocol.MemoryUser } -func (i *InterUdpConn) User() *protocol.MemoryUser { +func (i *InterConn) User() *protocol.MemoryUser { return i.user } -func (i *InterUdpConn) SetLast() { - i.mutex.Lock() - defer i.mutex.Unlock() - - i.last = time.Now() +func (c *InterConn) Time() time.Time { + c.mutex.Lock() + v := c.time + c.mutex.Unlock() + return v } -func (i *InterUdpConn) GetLast() time.Time { - i.mutex.Lock() - defer i.mutex.Unlock() - - return i.last +func (c *InterConn) Update() { + c.mutex.Lock() + c.time = time.Now() + c.mutex.Unlock() } -func (i *InterUdpConn) Read(p []byte) (int, error) { - b, ok := <-i.ch +func (c *InterConn) Read(p []byte) (int, error) { + b, ok := <-c.ch if !ok { return 0, io.EOF } - n := copy(p, b) - if n != len(b) { + if len(p) < len(b) { return 0, io.ErrShortBuffer } - - i.SetLast() - return n, nil + c.Update() + return copy(p, b), nil } -func (i *InterUdpConn) Write(p []byte) (int, error) { - i.SetLast() - - binary.BigEndian.PutUint32(p, i.id) - if err := i.conn.SendDatagram(p); err != nil { +func (c *InterConn) Write(p []byte) (int, error) { + if c.closed { + return 0, io.ErrClosedPipe + } + binary.BigEndian.PutUint32(p, c.id) + if err := c.write(p); err != nil { return 0, err } + c.Update() return len(p), nil } -func (i *InterUdpConn) Close() error { - i.closeFunc() +func (c *InterConn) Close() error { + c.close() return nil } -func (i *InterUdpConn) LocalAddr() net.Addr { - return i.local +func (c *InterConn) LocalAddr() net.Addr { + return c.local } -func (i *InterUdpConn) RemoteAddr() net.Addr { - return i.remote +func (c *InterConn) RemoteAddr() net.Addr { + return c.remote } -func (i *InterUdpConn) SetDeadline(t time.Time) error { +func (c *InterConn) SetDeadline(t time.Time) error { return nil } -func (i *InterUdpConn) SetReadDeadline(t time.Time) error { +func (c *InterConn) SetReadDeadline(t time.Time) error { return nil } -func (i *InterUdpConn) SetWriteDeadline(t time.Time) error { +func (c *InterConn) SetWriteDeadline(t time.Time) error { return nil } + +type udpSessionManager struct { + sync.RWMutex + + conn *quic.Conn + m map[uint32]*InterConn + next uint32 + closed bool + + addConn internet.ConnHandler + udpIdleTimeout time.Duration + user *protocol.MemoryUser +} + +func (m *udpSessionManager) close(udpConn *InterConn) { + if !udpConn.closed { + udpConn.closed = true + close(udpConn.ch) + delete(m.m, udpConn.id) + } +} + +func (m *udpSessionManager) clean() { + ticker := time.NewTicker(idleCleanupInterval) + defer ticker.Stop() + + for range ticker.C { + if m.closed { + return + } + + m.RLock() + now := time.Now() + timeoutConn := make([]*InterConn, 0, len(m.m)) + for _, udpConn := range m.m { + if now.Sub(udpConn.Time()) > m.udpIdleTimeout { + timeoutConn = append(timeoutConn, udpConn) + } + } + m.RUnlock() + + for _, udpConn := range timeoutConn { + m.Lock() + m.close(udpConn) + m.Unlock() + } + } +} + +func (m *udpSessionManager) run() { + for { + d, err := m.conn.ReceiveDatagram(context.Background()) + if err != nil { + break + } + + if len(d) < 4 { + continue + } + id := binary.BigEndian.Uint32(d[:4]) + + m.feed(id, d) + } + + m.Lock() + defer m.Unlock() + + m.closed = true + + for _, udpConn := range m.m { + m.close(udpConn) + } +} + +func (m *udpSessionManager) udp() (*InterConn, error) { + m.Lock() + defer m.Unlock() + + if m.closed { + return nil, errors.New("closed") + } + + udpConn := &InterConn{ + local: m.conn.LocalAddr(), + remote: m.conn.RemoteAddr(), + + id: m.next, + ch: make(chan []byte, udpMessageChanSize), + } + udpConn.write = m.conn.SendDatagram + udpConn.close = func() { + m.Lock() + m.close(udpConn) + m.Unlock() + } + m.m[m.next] = udpConn + m.next++ + + return udpConn, nil +} + +func (m *udpSessionManager) feed(id uint32, d []byte) { + m.RLock() + udpConn, ok := m.m[id] + if ok { + select { + case udpConn.ch <- d: + default: + } + m.RUnlock() + return + } + m.RUnlock() + + if m.addConn == nil { + return + } + + m.Lock() + defer m.Unlock() + + udpConn, ok = m.m[id] + if !ok { + udpConn = &InterConn{ + local: m.conn.LocalAddr(), + remote: m.conn.RemoteAddr(), + + id: id, + ch: make(chan []byte, udpMessageChanSize), + time: time.Now(), + } + udpConn.write = m.conn.SendDatagram + udpConn.close = func() { + m.Lock() + m.close(udpConn) + m.Unlock() + } + udpConn.user = m.user + m.m[id] = udpConn + m.addConn(udpConn) + } + + select { + case udpConn.ch <- d: + default: + } +} diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go index 1b2b5876..408550a1 100644 --- a/transport/internet/hysteria/dialer.go +++ b/transport/internet/hysteria/dialer.go @@ -3,11 +3,10 @@ package hysteria import ( "context" go_tls "crypto/tls" - "encoding/binary" - "math/rand" "net/http" "net/url" "reflect" + "runtime" "strconv" "sync" "time" @@ -18,8 +17,6 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net/cnc" - "github.com/xtls/xray-core/common/task" - hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" @@ -29,107 +26,25 @@ import ( "github.com/xtls/xray-core/transport/internet/tls" ) -type udpSessionManagerClient struct { - conn *quic.Conn - m map[uint32]*InterUdpConn - next uint32 - closed bool - mutex sync.RWMutex -} - -func (m *udpSessionManagerClient) close(udpConn *InterUdpConn) { - if !udpConn.closed { - udpConn.closed = true - close(udpConn.ch) - delete(m.m, udpConn.id) - } -} - -func (m *udpSessionManagerClient) run() { - for { - d, err := m.conn.ReceiveDatagram(context.Background()) - if err != nil { - break - } - - if len(d) < 4 { - continue - } - id := binary.BigEndian.Uint32(d[:4]) - - m.feed(id, d) - } - - m.mutex.Lock() - defer m.mutex.Unlock() - - m.closed = true - - for _, udpConn := range m.m { - m.close(udpConn) - } -} - -func (m *udpSessionManagerClient) udp() (*InterUdpConn, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closed { - return nil, errors.New("closed") - } - - udpConn := &InterUdpConn{ - conn: m.conn, - local: m.conn.LocalAddr(), - remote: m.conn.RemoteAddr(), - - id: m.next, - ch: make(chan []byte, udpMessageChanSize), - } - udpConn.closeFunc = func() { - m.mutex.Lock() - defer m.mutex.Unlock() - m.close(udpConn) - } - m.m[m.next] = udpConn - m.next++ - - return udpConn, nil -} - -func (m *udpSessionManagerClient) feed(id uint32, d []byte) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - udpConn, ok := m.m[id] - if !ok { - return - } - - select { - case udpConn.ch <- d: - default: - } -} - type client struct { - ctx context.Context + sync.Mutex + dest net.Destination - pktConn net.PacketConn - conn *quic.Conn config *Config tlsConfig *go_tls.Config socketConfig *internet.SocketConfig udpmaskManager *finalmask.UdpmaskManager quicParams *internet.QuicParams - udpSM *udpSessionManagerClient - mutex sync.Mutex + conn *quic.Conn + tr *quic.Transport + pktConn net.PacketConn + udpSM *udpSessionManager } -func (c *client) status() Status { +func (c *client) status() status { if c.conn == nil { - return StatusUnknown + return StatusNull } select { case <-c.conn.Context().Done(): @@ -140,10 +55,12 @@ func (c *client) status() Status { } func (c *client) close() { - _ = c.conn.CloseWithError(closeErrCodeOK, "") - _ = c.pktConn.Close() - c.pktConn = nil + c.conn.CloseWithError(closeErrCodeOK, "") + c.tr.Close() + c.pktConn.Close() c.conn = nil + c.tr = nil + c.pktConn = nil c.udpSM = nil } @@ -164,61 +81,6 @@ func (c *client) dial() error { } } - var index int - if len(quicParams.UdpHop.Ports) > 0 { - index = rand.Intn(len(quicParams.UdpHop.Ports)) - c.dest.Port = net.Port(quicParams.UdpHop.Ports[index]) - } - - raw, err := internet.DialSystem(c.ctx, c.dest, c.socketConfig) - if err != nil { - return errors.New("failed to dial to dest").Base(err) - } - - var pktConn net.PacketConn - var remote *net.UDPAddr - - switch conn := raw.(type) { - case *internet.PacketConnWrapper: - pktConn = conn.PacketConn - remote = conn.RemoteAddr().(*net.UDPAddr) - case *net.UDPConn: - pktConn = conn - remote = conn.RemoteAddr().(*net.UDPAddr) - case *cnc.Connection: - fakeConn := &internet.FakePacketConn{Conn: conn} - pktConn = fakeConn - remote = fakeConn.RemoteAddr().(*net.UDPAddr) - - if len(quicParams.UdpHop.Ports) > 0 { - raw.Close() - return errors.New("udphop requires being at the outermost level") - } - default: - raw.Close() - return errors.New("unknown conn ", reflect.TypeOf(conn)) - } - - if len(quicParams.UdpHop.Ports) > 0 { - addr := &udphop.UDPHopAddr{ - IP: remote.IP, - Ports: quicParams.UdpHop.Ports, - } - pktConn, err = udphop.NewUDPHopPacketConn(addr, index, quicParams.UdpHop.IntervalMin, quicParams.UdpHop.IntervalMax, c.udphopDialer, pktConn) - if err != nil { - raw.Close() - return errors.New("udphop err").Base(err) - } - } - - if c.udpmaskManager != nil { - pktConn, err = c.udpmaskManager.WrapPacketConnClient(pktConn) - if err != nil { - raw.Close() - return errors.New("mask err").Base(err) - } - } - quicConfig := &quic.Config{ InitialStreamReceiveWindow: quicParams.InitStreamReceiveWindow, MaxStreamReceiveWindow: quicParams.MaxStreamReceiveWindow, @@ -226,9 +88,10 @@ func (c *client) dial() error { MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second, - DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, + DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"), EnableDatagrams: true, MaxDatagramFrameSize: MaxDatagramFrameSize, + OmitMaxDatagramFrameSize: time.Now().After(time.Date(2026, 9, 1, 0, 0, 0, 0, time.UTC)), DisablePathManager: true, } if quicParams.InitStreamReceiveWindow == 0 { @@ -250,16 +113,56 @@ func (c *client) dial() error { // quicConfig.KeepAlivePeriod = 10 * time.Second // } - var quicConn *quic.Conn + var pktConn net.PacketConn + var udpAddr *net.UDPAddr + var err error + udpAddr, err = net.ResolveUDPAddr("udp", c.dest.NetAddr()) + if err != nil { + return err + } + if len(quicParams.UdpHop.Ports) > 0 { + pktConn, err = udphop.NewUDPHopPacketConn(udphop.ToAddrs(udpAddr.IP, quicParams.UdpHop.Ports), time.Duration(quicParams.UdpHop.IntervalMin)*time.Second, time.Duration(quicParams.UdpHop.IntervalMax)*time.Second, c.udpHopDialer) + if err != nil { + return err + } + } else { + conn, err := internet.DialSystem(context.Background(), c.dest, c.socketConfig) + if err != nil { + return err + } + switch c := conn.(type) { + case *internet.PacketConnWrapper: + pktConn = c.PacketConn + case *net.UDPConn: + pktConn = c + case *cnc.Connection: + pktConn = &internet.FakePacketConn{Conn: c} + default: + panic(reflect.TypeOf(c)) + } + } + + if c.udpmaskManager != nil { + newConn, err := c.udpmaskManager.WrapPacketConnClient(pktConn) + if err != nil { + pktConn.Close() + return errors.New("mask err").Base(err) + } + pktConn = newConn + } + + tr := &quic.Transport{Conn: pktConn} + + var conn *quic.Conn rt := &http3.Transport{ TLSClientConfig: c.tlsConfig, QUICConfig: quicConfig, Dial: func(ctx context.Context, _ string, tlsCfg *go_tls.Config, cfg *quic.Config) (*quic.Conn, error) { - qc, err := quic.DialEarly(ctx, pktConn, remote, tlsCfg, cfg) + qc, err := tr.DialEarly(ctx, udpAddr, tlsCfg, cfg) if err != nil { return nil, err } - quicConn = qc + conn = qc return qc, nil }, } @@ -273,75 +176,61 @@ func (c *client) dial() error { Header: http.Header{ RequestHeaderAuth: []string{c.config.Auth}, CommonHeaderCCRX: []string{strconv.FormatUint(quicParams.BrutalDown, 10)}, - CommonHeaderPadding: []string{authRequestPadding.String()}, + CommonHeaderPadding: []string{AuthRequestPadding.String()}, }, } resp, err := rt.RoundTrip(req) if err != nil { - if quicConn != nil { - _ = quicConn.CloseWithError(closeErrCodeProtocolError, "") + if conn != nil { + _ = conn.CloseWithError(closeErrCodeProtocolError, "") } + _ = tr.Close() _ = pktConn.Close() - return errors.New("RoundTrip err").Base(err) + return err } if resp.StatusCode != StatusAuthOK { - _ = quicConn.CloseWithError(closeErrCodeProtocolError, "") + _ = conn.CloseWithError(closeErrCodeProtocolError, "") + _ = tr.Close() _ = pktConn.Close() - return errors.New("auth failed") + return errors.New("auth failed code ", resp.StatusCode) } _ = resp.Body.Close() - serverUdp, _ := strconv.ParseBool(resp.Header.Get(ResponseHeaderUDPEnabled)) - serverAuto := resp.Header.Get(CommonHeaderCCRX) - serverDown, _ := strconv.ParseUint(serverAuto, 10, 64) + // udp, _ := strconv.ParseBool(resp.Header.Get(ResponseHeaderUDPEnabled)) + down, _ := strconv.ParseUint(resp.Header.Get(CommonHeaderCCRX), 10, 64) switch quicParams.Congestion { case "reno": - errors.LogDebug(c.ctx, "congestion reno") case "bbr": - errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile) - congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) - case "brutal", "": - if serverAuto == "auto" || quicParams.BrutalUp == 0 || serverDown == 0 { - errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile) - congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) + case "", "brutal": + if quicParams.BrutalUp == 0 || down == 0 { + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) } else { - errors.LogDebug(c.ctx, "congestion brutal bytes per second ", min(quicParams.BrutalUp, serverDown)) - congestion.UseBrutal(quicConn, min(quicParams.BrutalUp, serverDown)) + congestion.UseBrutal(conn, min(quicParams.BrutalUp, down)) } case "force-brutal": - errors.LogDebug(c.ctx, "congestion brutal bytes per second ", quicParams.BrutalUp) - congestion.UseBrutal(quicConn, quicParams.BrutalUp) + congestion.UseBrutal(conn, quicParams.BrutalUp) default: - errors.LogDebug(c.ctx, "congestion reno") + panic(quicParams.Congestion) } c.pktConn = pktConn - c.conn = quicConn - if serverUdp { - c.udpSM = &udpSessionManagerClient{ - conn: quicConn, - m: make(map[uint32]*InterUdpConn), - next: 1, - } - go c.udpSM.run() + c.tr = tr + c.conn = conn + c.udpSM = &udpSessionManager{ + conn: conn, + m: make(map[uint32]*InterConn), + next: 1, } + go c.udpSM.run() return nil } -func (c *client) clean() { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.status() == StatusInactive { - c.close() - } -} - func (c *client) tcp() (stat.Connection, error) { - c.mutex.Lock() - defer c.mutex.Unlock() + c.Lock() + defer c.Unlock() err := c.dial() if err != nil { @@ -363,59 +252,43 @@ func (c *client) tcp() (stat.Connection, error) { } func (c *client) udp() (stat.Connection, error) { - c.mutex.Lock() - defer c.mutex.Unlock() + c.Lock() + defer c.Unlock() err := c.dial() if err != nil { return nil, err } - if c.udpSM == nil { - return nil, errors.New("server does not support udp") - } - return c.udpSM.udp() } -func (c *client) setCtx(ctx context.Context) { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.ctx = ctx +func (c *client) clean() { + c.Lock() + if c.status() == StatusInactive { + c.close() + } + c.Unlock() } -func (c *client) udphopDialer(addr *net.UDPAddr) (net.PacketConn, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.status() != StatusActive { - errors.LogDebug(context.Background(), "skip hop: disconnected QUIC") - return nil, errors.New() - } - - raw, err := internet.DialSystem(c.ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), c.socketConfig) +func (c *client) udpHopDialer(addr *net.UDPAddr) (net.PacketConn, error) { + conn, err := internet.DialSystem(context.Background(), net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), c.socketConfig) if err != nil { - errors.LogDebug(context.Background(), "skip hop: failed to dial to dest") - raw.Close() - return nil, errors.New() + errors.LogInfoInner(context.Background(), err, "skip hop: failed to dial to dest") + return nil, errors.New("failed to dial to dest").Base(err) } var pktConn net.PacketConn - switch conn := raw.(type) { + switch c := conn.(type) { case *internet.PacketConnWrapper: - pktConn = conn.PacketConn + pktConn = c.PacketConn case *net.UDPConn: - pktConn = conn - case *cnc.Connection: - errors.LogDebug(context.Background(), "skip hop: udphop requires being at the outermost level") - raw.Close() - return nil, errors.New() + pktConn = c default: - errors.LogDebug(context.Background(), "skip hop: unknown conn ", reflect.TypeOf(conn)) - raw.Close() - return nil, errors.New() + errors.LogInfo(context.Background(), "skip hop: invalid conn ", reflect.TypeOf(c)) + conn.Close() + return nil, errors.New("invalid conn ", reflect.TypeOf(c)) } return pktConn, nil @@ -427,16 +300,18 @@ type dialerConf struct { } type clientManager struct { - m map[dialerConf]*client - mutex sync.Mutex + sync.RWMutex + m map[dialerConf]*client } func (m *clientManager) clean() { - m.mutex.Lock() - defer m.mutex.Unlock() - - for _, c := range m.m { - c.clean() + ticker := time.NewTicker(idleCleanupInterval) + for range ticker.C { + m.RLock() + for _, c := range m.m { + c.clean() + } + m.RUnlock() } } @@ -449,41 +324,38 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return nil, errors.New("tls config is nil") } - requireDatagram := hyCtx.RequireDatagramFromContext(ctx) + datagram := DatagramFromContext(ctx) dest.Network = net.Network_UDP - config := streamSettings.ProtocolSettings.(*Config) initmanager.Do(func() { manager = &clientManager{ m: make(map[dialerConf]*client), } - (&task.Periodic{ - Interval: 30 * time.Second, - Execute: func() error { - manager.clean() - return nil - }, - }).Start() + go manager.clean() }) - manager.mutex.Lock() - c, ok := manager.m[dialerConf{Destination: dest, MemoryStreamConfig: streamSettings}] - if !ok { - c = &client{ - ctx: ctx, - dest: dest, - config: config, - tlsConfig: tlsConfig.GetTLSConfig(), - socketConfig: streamSettings.SocketSettings, - udpmaskManager: streamSettings.UdpmaskManager, - quicParams: streamSettings.QuicParams, - } - manager.m[dialerConf{Destination: dest, MemoryStreamConfig: streamSettings}] = c - } - c.setCtx(ctx) - manager.mutex.Unlock() + manager.RLock() + c := manager.m[dialerConf{dest, streamSettings}] + manager.RUnlock() - if requireDatagram { + if c == nil { + manager.Lock() + c = manager.m[dialerConf{dest, streamSettings}] + if c == nil { + c = &client{ + dest: dest, + config: streamSettings.ProtocolSettings.(*Config), + tlsConfig: tlsConfig.GetTLSConfig(), + socketConfig: streamSettings.SocketSettings, + udpmaskManager: streamSettings.UdpmaskManager, + quicParams: streamSettings.QuicParams, + } + manager.m[dialerConf{dest, streamSettings}] = c + } + manager.Unlock() + } + + if datagram { return c.udp() } return c.tcp() diff --git a/transport/internet/hysteria/hub.go b/transport/internet/hysteria/hub.go index 89e18dac..17088938 100644 --- a/transport/internet/hysteria/hub.go +++ b/transport/internet/hysteria/hub.go @@ -3,10 +3,10 @@ package hysteria import ( "context" gotls "crypto/tls" - "encoding/binary" "net/http" "net/http/httputil" "net/url" + "runtime" "strconv" "strings" "sync" @@ -20,158 +20,41 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/proxy/hysteria/account" - hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/tls" ) -type udpSessionManagerServer struct { - conn *quic.Conn - m map[uint32]*InterUdpConn - addConn internet.ConnHandler - stopCh chan struct{} - udpIdleTimeout time.Duration - mutex sync.RWMutex +type httpHandler struct { + sync.Mutex + validator *account.Validator + config *Config + masqHandler http.Handler + quicParams *internet.QuicParams + addConn internet.ConnHandler + conn *quic.Conn + + auth bool user *protocol.MemoryUser } -func (m *udpSessionManagerServer) close(udpConn *InterUdpConn) { - if !udpConn.closed { - udpConn.closed = true - close(udpConn.ch) - delete(m.m, udpConn.id) - } -} - -func (m *udpSessionManagerServer) clean() { - ticker := time.NewTicker(idleCleanupInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - m.mutex.RLock() - now := time.Now() - timeoutConn := make([]*InterUdpConn, 0, len(m.m)) - for _, udpConn := range m.m { - if now.Sub(udpConn.GetLast()) > m.udpIdleTimeout { - timeoutConn = append(timeoutConn, udpConn) - } - } - m.mutex.RUnlock() - - for _, udpConn := range timeoutConn { - m.mutex.Lock() - m.close(udpConn) - m.mutex.Unlock() - } - case <-m.stopCh: - return - } - } -} - -func (m *udpSessionManagerServer) run() { - for { - d, err := m.conn.ReceiveDatagram(context.Background()) - if err != nil { - break - } - - if len(d) < 4 { - continue - } - id := binary.BigEndian.Uint32(d[:4]) - - m.feed(id, d) - } - - m.mutex.Lock() - defer m.mutex.Unlock() - - close(m.stopCh) - - for _, udpConn := range m.m { - m.close(udpConn) - } -} - -func (m *udpSessionManagerServer) feed(id uint32, d []byte) { - m.mutex.RLock() - udpConn, ok := m.m[id] - if ok { - select { - case udpConn.ch <- d: - default: - } - m.mutex.RUnlock() - return - } - m.mutex.RUnlock() - - m.mutex.Lock() - defer m.mutex.Unlock() - - udpConn, ok = m.m[id] - if !ok { - udpConn = &InterUdpConn{ - conn: m.conn, - local: m.conn.LocalAddr(), - remote: m.conn.RemoteAddr(), - - id: id, - ch: make(chan []byte, udpMessageChanSize), - last: time.Now(), - - user: m.user, - } - udpConn.closeFunc = func() { - m.mutex.Lock() - m.close(udpConn) - m.mutex.Unlock() - } - m.m[id] = udpConn - m.addConn(udpConn) - } - - select { - case udpConn.ch <- d: - default: - } -} - -type httpHandler struct { - ctx context.Context - conn *quic.Conn - addConn internet.ConnHandler - - config *Config - quicParams *internet.QuicParams - validator *account.Validator - masqHandler http.Handler - - auth bool - mutex sync.Mutex - user *protocol.MemoryUser -} - -func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *httpHandler) AuthHTTP(w http.ResponseWriter, r *http.Request) bool { if r.Method == http.MethodPost && r.Host == URLHost && r.URL.Path == URLPath { - h.mutex.Lock() - defer h.mutex.Unlock() + h.Lock() + defer h.Unlock() if h.auth { - w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) + w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(h.validator != nil)) w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10)) - w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) + w.Header().Set(CommonHeaderPadding, AuthResponsePadding.String()) w.WriteHeader(StatusAuthOK) - return + return true } auth := r.Header.Get(RequestHeaderAuth) - clientDown, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64) + down, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64) var user *protocol.MemoryUser var ok bool @@ -185,49 +68,51 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.auth = true h.user = user - switch h.quicParams.Congestion { + conn := h.conn + quicParams := h.quicParams + switch quicParams.Congestion { case "reno": - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") case "bbr": - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile) - congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile)) - case "brutal", "": - if h.quicParams.BrutalUp == 0 || clientDown == 0 { - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile) - congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile)) + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) + case "", "brutal": + if quicParams.BrutalUp == 0 || down == 0 { + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) } else { - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", min(h.quicParams.BrutalUp, clientDown)) - congestion.UseBrutal(h.conn, min(h.quicParams.BrutalUp, clientDown)) + congestion.UseBrutal(conn, min(quicParams.BrutalUp, down)) } case "force-brutal": - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", h.quicParams.BrutalUp) - congestion.UseBrutal(h.conn, h.quicParams.BrutalUp) + congestion.UseBrutal(conn, quicParams.BrutalUp) default: - errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") + panic(quicParams.Congestion) } - if hyCtx.RequireDatagramFromContext(h.ctx) { - udpSM := &udpSessionManagerServer{ - conn: h.conn, - m: make(map[uint32]*InterUdpConn), - addConn: h.addConn, - stopCh: make(chan struct{}), - udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second, + if h.validator != nil { + udpSM := &udpSessionManager{ + conn: h.conn, + m: make(map[uint32]*InterConn), - user: h.user, + addConn: h.addConn, + udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second, + user: h.user, } go udpSM.clean() go udpSM.run() } - w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) + w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(h.validator != nil)) w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10)) - w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) + w.Header().Set(CommonHeaderPadding, AuthResponsePadding.String()) w.WriteHeader(StatusAuthOK) - return + return true } } + return false +} +func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.AuthHTTP(w, r) { + return + } h.masqHandler.ServeHTTP(w, r) } @@ -256,42 +141,41 @@ func (h *httpHandler) StreamDispatcher(ft http3.FrameType, stream *quic.Stream, } type Listener struct { - ctx context.Context - pktConn net.PacketConn - listener *quic.Listener - addConn internet.ConnHandler - - config *Config - quicParams *internet.QuicParams validator *account.Validator + config *Config masqHandler http.Handler + quicParams *internet.QuicParams + addConn internet.ConnHandler + + pktConn net.PacketConn + tr *quic.Transport + listener *quic.Listener } func (l *Listener) handleClient(conn *quic.Conn) { handler := &httpHandler{ - ctx: l.ctx, - conn: conn, - addConn: l.addConn, - - config: l.config, - quicParams: l.quicParams, validator: l.validator, + config: l.config, masqHandler: l.masqHandler, + quicParams: l.quicParams, + addConn: l.addConn, + conn: conn, } - h3 := http3.Server{ + h3s := http3.Server{ Handler: handler, StreamDispatcher: handler.StreamDispatcher, } - err := h3.ServeQUICConn(conn) + _ = h3s.ServeQUICConn(conn) _ = conn.CloseWithError(closeErrCodeOK, "") - errors.LogDebug(context.Background(), conn.RemoteAddr(), " disconnected with err ", err) } func (l *Listener) keepAccepting() { for { conn, err := l.listener.Accept(context.Background()) if err != nil { - errors.LogInfoInner(context.Background(), err, "failed to accept QUIC connection") + if err != quic.ErrServerClosed { + errors.LogErrorInner(context.Background(), err, "failed to serve hysteria") + } break } go l.handleClient(conn) @@ -303,9 +187,7 @@ func (l *Listener) Addr() net.Addr { } func (l *Listener) Close() error { - err := l.listener.Close() - _ = l.pktConn.Close() - return err + return errors.Combine(l.listener.Close(), l.tr.Close(), l.pktConn.Close()) } func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { @@ -318,11 +200,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti return nil, errors.New("tls config is nil") } + validator := ValidatorFromContext(ctx) config := streamSettings.ProtocolSettings.(*Config) - validator := hyCtx.ValidatorFromContext(ctx) - - if config.Auth == "" && validator == nil { + if validator == nil && config.Auth == "" { return nil, errors.New("validator is nil") } @@ -372,22 +253,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti return nil, errors.New("unknown masq type") } - raw, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings) - if err != nil { - return nil, err - } - - var pktConn net.PacketConn - pktConn = raw - - if streamSettings.UdpmaskManager != nil { - pktConn, err = streamSettings.UdpmaskManager.WrapPacketConnServer(raw) - if err != nil { - raw.Close() - return nil, errors.New("mask err").Base(err) - } - } - quicParams := streamSettings.QuicParams if quicParams == nil { quicParams = &internet.QuicParams{ @@ -403,9 +268,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIncomingStreams: quicParams.MaxIncomingStreams, - DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, + DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"), EnableDatagrams: true, MaxDatagramFrameSize: MaxDatagramFrameSize, + AssumePeerMaxDatagramFrameSize: MaxDatagramFrameSize, DisablePathManager: true, } if quicParams.InitStreamReceiveWindow == 0 { @@ -427,27 +293,44 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti quicConfig.MaxIncomingStreams = 1024 } - qListener, err := quic.Listen(pktConn, tlsConfig.GetTLSConfig(), quicConfig) + pktConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings) if err != nil { + return nil, err + } + + if streamSettings.UdpmaskManager != nil { + newConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(pktConn) + if err != nil { + pktConn.Close() + return nil, errors.New("mask err").Base(err) + } + pktConn = newConn + } + + tr := &quic.Transport{Conn: pktConn} + + listener, err := tr.Listen(tlsConfig.GetTLSConfig(), quicConfig) + if err != nil { + _ = tr.Close() _ = pktConn.Close() return nil, err } - listener := &Listener{ - ctx: ctx, - pktConn: pktConn, - listener: qListener, - addConn: handler, - - config: config, - quicParams: quicParams, + l := &Listener{ validator: validator, + config: config, masqHandler: masqHandler, + quicParams: quicParams, + addConn: handler, + + pktConn: pktConn, + tr: tr, + listener: listener, } - go listener.keepAccepting() + go l.keepAccepting() - return listener, nil + return l, nil } func init() { diff --git a/transport/internet/hysteria/padding/padding.go b/transport/internet/hysteria/padding/padding.go deleted file mode 100644 index b134601e..00000000 --- a/transport/internet/hysteria/padding/padding.go +++ /dev/null @@ -1,24 +0,0 @@ -package padding - -import ( - "math/rand" -) - -const ( - paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -) - -// padding specifies a half-open range [Min, Max). -type Padding struct { - Min int - Max int -} - -func (p Padding) String() string { - n := p.Min + rand.Intn(p.Max-p.Min) - bs := make([]byte, n) - for i := range bs { - bs[i] = paddingChars[rand.Intn(len(paddingChars))] - } - return string(bs) -} diff --git a/transport/internet/hysteria/udphop/addr.go b/transport/internet/hysteria/udphop/addr.go deleted file mode 100644 index 70dae2a2..00000000 --- a/transport/internet/hysteria/udphop/addr.go +++ /dev/null @@ -1,65 +0,0 @@ -package udphop - -import ( - "fmt" - "net" -) - -type InvalidPortError struct { - PortStr string -} - -func (e InvalidPortError) Error() string { - return fmt.Sprintf("%s is not a valid port number or range", e.PortStr) -} - -// UDPHopAddr contains an IP address and a list of ports. -type UDPHopAddr struct { - IP net.IP - Ports []uint32 - PortStr string -} - -func (a *UDPHopAddr) Network() string { - return "udphop" -} - -func (a *UDPHopAddr) String() string { - return net.JoinHostPort(a.IP.String(), a.PortStr) -} - -// addrs returns a list of net.Addr's, one for each port. -func (a *UDPHopAddr) addrs() ([]net.Addr, error) { - var addrs []net.Addr - for _, port := range a.Ports { - addr := &net.UDPAddr{ - IP: a.IP, - Port: int(port), - } - addrs = append(addrs, addr) - } - return addrs, nil -} - -// func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) { -// host, portStr, err := net.SplitHostPort(addr) -// if err != nil { -// return nil, err -// } -// ip, err := net.ResolveIPAddr("ip", host) -// if err != nil { -// return nil, err -// } -// result := &UDPHopAddr{ -// IP: ip.IP, -// PortStr: portStr, -// } - -// pu := utils.ParsePortUnion(portStr) -// if pu == nil { -// return nil, InvalidPortError{portStr} -// } -// result.Ports = pu.Ports() - -// return result, nil -// } diff --git a/transport/internet/hysteria/udphop/conn.go b/transport/internet/hysteria/udphop/conn.go index 50dcc36d..e60755f9 100644 --- a/transport/internet/hysteria/udphop/conn.go +++ b/transport/internet/hysteria/udphop/conn.go @@ -8,7 +8,6 @@ import ( "syscall" "time" - "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/transport/internet/finalmask" ) @@ -20,19 +19,19 @@ const ( ) type UdpHopPacketConn struct { - Addr net.Addr Addrs []net.Addr - HopIntervalMin int64 - HopIntervalMax int64 - ListenUDPFunc ListenUDPFunc + HopIntervalMin time.Duration + HopIntervalMax time.Duration + ListenUDPFunc func(addr *net.UDPAddr) (net.PacketConn, error) connMutex sync.RWMutex prevConn net.PacketConn currentConn net.PacketConn addrIndex int - readBufferSize int - writeBufferSize int + deadline time.Time + readDeadline time.Time + writeDeadline time.Time recvQueue chan *udpPacket closeChan chan struct{} @@ -48,41 +47,36 @@ type udpPacket struct { Err error } -type ListenUDPFunc = func(*net.UDPAddr) (net.PacketConn, error) - -func NewUDPHopPacketConn(addr *UDPHopAddr, index int, intervalMin int64, intervalMax int64, listenUDPFunc ListenUDPFunc, pktConn net.PacketConn) (net.PacketConn, error) { - if intervalMin == 0 || intervalMax == 0 { - intervalMin = int64(defaultHopInterval) - intervalMax = int64(defaultHopInterval) +func NewUDPHopPacketConn(addrs []net.Addr, hopIntervalMin time.Duration, hopIntervalMax time.Duration, listenUDPFunc func(addr *net.UDPAddr) (net.PacketConn, error)) (net.PacketConn, error) { + if len(addrs) == 0 { + panic("len(addrs) == 0") } - if intervalMin < 5 || intervalMax < 5 { - return nil, errors.New("hop interval must be at least 5 seconds") + if hopIntervalMin == 0 { + hopIntervalMin = defaultHopInterval + } + if hopIntervalMax == 0 { + hopIntervalMax = defaultHopInterval + } + if hopIntervalMin < 5*time.Second { + panic("hopIntervalMin < 5*time.Second") + } + if hopIntervalMax < 5*time.Second { + panic("hopIntervalMax < 5*time.Second") + } + if hopIntervalMax < hopIntervalMin { + panic("hopIntervalMax < hopIntervalMin") } - // if listenUDPFunc == nil { - // listenUDPFunc = func() (net.PacketConn, error) { - // return net.ListenUDP("udp", nil) - // } - // } if listenUDPFunc == nil { - return nil, errors.New("nil listenUDPFunc") + panic("listenUDPFunc is nil") } - addrs, err := addr.addrs() - if err != nil { - return nil, err - } - // curConn, err := listenUDPFunc() - // if err != nil { - // return nil, err - // } hConn := &UdpHopPacketConn{ - Addr: addr, Addrs: addrs, - HopIntervalMin: intervalMin, - HopIntervalMax: intervalMax, + HopIntervalMin: hopIntervalMin, + HopIntervalMax: hopIntervalMax, ListenUDPFunc: listenUDPFunc, prevConn: nil, - currentConn: pktConn, - addrIndex: index, + currentConn: nil, + addrIndex: rand.Intn(len(addrs)), recvQueue: make(chan *udpPacket, packetQueueSize), closeChan: make(chan struct{}), bufPool: sync.Pool{ @@ -91,7 +85,12 @@ func NewUDPHopPacketConn(addr *UDPHopAddr, index int, intervalMin int64, interva }, }, } - go hConn.recvLoop(pktConn) + var err error + hConn.currentConn, err = listenUDPFunc(hConn.Addrs[hConn.addrIndex].(*net.UDPAddr)) + if err != nil { + return nil, err + } + go hConn.recvLoop(hConn.currentConn) go hConn.hopLoop() return hConn, nil } @@ -104,69 +103,64 @@ func (u *UdpHopPacketConn) recvLoop(conn net.PacketConn) { u.bufPool.Put(buf) var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - // Only pass through timeout errors here, not permanent errors - // like connection closed. Connection close is normal as we close - // the old connection to exit this loop every time we hop. u.recvQueue <- &udpPacket{nil, 0, nil, netErr} + continue } return } select { case u.recvQueue <- &udpPacket{buf, n, addr, nil}: - // Packet successfully queued default: - // Queue is full, drop the packet u.bufPool.Put(buf) } } } func (u *UdpHopPacketConn) hopLoop() { - ticker := time.NewTicker(time.Duration(crypto.RandBetween(u.HopIntervalMin, u.HopIntervalMax)) * time.Second) - defer ticker.Stop() + timer := time.NewTimer(u.nextHopInterval()) + defer timer.Stop() for { select { - case <-ticker.C: + case <-timer.C: u.hop() - ticker.Reset(time.Duration(crypto.RandBetween(u.HopIntervalMin, u.HopIntervalMax)) * time.Second) + timer.Reset(u.nextHopInterval()) case <-u.closeChan: return } } } +func (u *UdpHopPacketConn) nextHopInterval() time.Duration { + if u.HopIntervalMin == u.HopIntervalMax { + return u.HopIntervalMin + } + return u.HopIntervalMin + time.Duration(rand.Int63n(int64(u.HopIntervalMax-u.HopIntervalMin)+1)) +} + func (u *UdpHopPacketConn) hop() { u.connMutex.Lock() defer u.connMutex.Unlock() if u.closed { return } - // Update addrIndex to a new random value u.addrIndex = rand.Intn(len(u.Addrs)) newConn, err := u.ListenUDPFunc(u.Addrs[u.addrIndex].(*net.UDPAddr)) if err != nil { - // Could be temporary, just skip this hop return } - // We need to keep receiving packets from the previous connection, - // because otherwise there will be packet loss due to the time gap - // between we hop to a new port and the server acknowledges this change. - // So we do the following: - // Close prevConn, - // move currentConn to prevConn, - // set newConn as currentConn, - // start recvLoop on newConn. if u.prevConn != nil { - _ = u.prevConn.Close() // recvLoop for this conn will exit + _ = u.prevConn.Close() } u.prevConn = u.currentConn u.currentConn = newConn - // Set buffer sizes if previously set - if u.readBufferSize > 0 { - _ = trySetReadBuffer(u.currentConn, u.readBufferSize) + if !u.deadline.IsZero() { + _ = u.currentConn.SetDeadline(u.deadline) } - if u.writeBufferSize > 0 { - _ = trySetWriteBuffer(u.currentConn, u.writeBufferSize) + if !u.readDeadline.IsZero() { + _ = u.currentConn.SetReadDeadline(u.readDeadline) + } + if !u.writeDeadline.IsZero() { + _ = u.currentConn.SetWriteDeadline(u.writeDeadline) } go u.recvLoop(newConn) } @@ -178,11 +172,9 @@ func (u *UdpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) if p.Err != nil { return 0, nil, p.Err } - // Currently we do not check whether the packet is from - // the server or not due to performance reasons. n := copy(b, p.Buf[:p.N]) u.bufPool.Put(p.Buf) - return n, u.Addr, nil + return n, p.Addr, nil case <-u.closeChan: return 0, nil, net.ErrClosed } @@ -195,8 +187,6 @@ func (u *UdpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { if u.closed { return 0, net.ErrClosed } - // Skip the check for now, always write to the server, - // for the same reason as in ReadFrom. return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex]) } @@ -206,16 +196,13 @@ func (u *UdpHopPacketConn) Close() error { if u.closed { return nil } - // Close prevConn and currentConn - // Close closeChan to unblock ReadFrom & hopLoop - // Set closed flag to true to prevent double close if u.prevConn != nil { _ = u.prevConn.Close() } err := u.currentConn.Close() close(u.closeChan) u.closed = true - u.Addrs = nil // For GC + u.Addrs = nil return err } @@ -226,8 +213,11 @@ func (u *UdpHopPacketConn) LocalAddr() net.Addr { } func (u *UdpHopPacketConn) SetDeadline(t time.Time) error { - u.connMutex.RLock() - defer u.connMutex.RUnlock() + u.connMutex.Lock() + defer u.connMutex.Unlock() + u.deadline = t + u.readDeadline = t + u.writeDeadline = t if u.prevConn != nil { _ = u.prevConn.SetDeadline(t) } @@ -235,8 +225,10 @@ func (u *UdpHopPacketConn) SetDeadline(t time.Time) error { } func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error { - u.connMutex.RLock() - defer u.connMutex.RUnlock() + u.connMutex.Lock() + defer u.connMutex.Unlock() + u.deadline = time.Time{} + u.readDeadline = t if u.prevConn != nil { _ = u.prevConn.SetReadDeadline(t) } @@ -244,36 +236,16 @@ func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error { } func (u *UdpHopPacketConn) SetWriteDeadline(t time.Time) error { - u.connMutex.RLock() - defer u.connMutex.RUnlock() + u.connMutex.Lock() + defer u.connMutex.Unlock() + u.deadline = time.Time{} + u.writeDeadline = t if u.prevConn != nil { _ = u.prevConn.SetWriteDeadline(t) } return u.currentConn.SetWriteDeadline(t) } -// UDP-specific methods below - -func (u *UdpHopPacketConn) SetReadBuffer(bytes int) error { - u.connMutex.Lock() - defer u.connMutex.Unlock() - u.readBufferSize = bytes - if u.prevConn != nil { - _ = trySetReadBuffer(u.prevConn, bytes) - } - return trySetReadBuffer(u.currentConn, bytes) -} - -func (u *UdpHopPacketConn) SetWriteBuffer(bytes int) error { - u.connMutex.Lock() - defer u.connMutex.Unlock() - u.writeBufferSize = bytes - if u.prevConn != nil { - _ = trySetWriteBuffer(u.prevConn, bytes) - } - return trySetWriteBuffer(u.currentConn, bytes) -} - func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) { u.connMutex.RLock() defer u.connMutex.RUnlock() @@ -284,22 +256,14 @@ func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) { return sc.SyscallConn() } -func trySetReadBuffer(pc net.PacketConn, bytes int) error { - sc, ok := pc.(interface { - SetReadBuffer(bytes int) error - }) - if ok { - return sc.SetReadBuffer(bytes) +func ToAddrs(ip net.IP, ports []uint32) []net.Addr { + var addrs []net.Addr + for _, port := range ports { + addr := &net.UDPAddr{ + IP: ip, + Port: int(port), + } + addrs = append(addrs, addr) } - return nil -} - -func trySetWriteBuffer(pc net.PacketConn, bytes int) error { - sc, ok := pc.(interface { - SetWriteBuffer(bytes int) error - }) - if ok { - return sc.SetWriteBuffer(bytes) - } - return nil + return addrs } diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 8586d3e6..3690f9e0 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -57,41 +57,27 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet } if streamSettings.UdpmaskManager != nil { + var pktConn net.PacketConn + var udpAddr = conn.RemoteAddr().(*net.UDPAddr) switch c := conn.(type) { case *internet.PacketConnWrapper: - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - c.PacketConn = pktConn + pktConn = c.PacketConn case *net.UDPConn: - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - conn = &internet.PacketConnWrapper{ - PacketConn: pktConn, - Dest: c.RemoteAddr().(*net.UDPAddr), - } + pktConn = c case *cnc.Connection: - fakeConn := &internet.FakePacketConn{Conn: c} - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - conn = &internet.PacketConnWrapper{ - PacketConn: pktConn, - Dest: &net.UDPAddr{ - IP: []byte{0, 0, 0, 0}, - Port: 0, - }, - } + pktConn = &internet.FakePacketConn{Conn: c} default: - conn.Close() - return nil, errors.New("unknown conn ", reflect.TypeOf(c)) + panic(reflect.TypeOf(c)) + } + newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn) + if err != nil { + pktConn.Close() + return nil, errors.New("mask err").Base(err) + } + pktConn = newConn + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: udpAddr, } } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 19b5c3f3..e85bd7b3 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -5,11 +5,11 @@ import ( gotls "crypto/tls" "fmt" "io" - "math/rand" "net/http" "net/http/httptrace" "net/url" reflect "reflect" + "runtime" "strconv" "sync" "sync/atomic" @@ -21,6 +21,7 @@ import ( "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/transport/internet" @@ -173,7 +174,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second, MaxIncomingStreams: quicParams.MaxIncomingStreams, - DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, + DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"), } if quicParams.MaxIdleTimeout == 0 { quicConfig.MaxIdleTimeout = net.ConnIdleTimeout @@ -194,110 +195,83 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea QUICConfig: quicConfig, TLSClientConfig: gotlsConfig, Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (*quic.Conn, error) { - udphopDialer := func(addr *net.UDPAddr) (net.PacketConn, error) { + udpHopDialer := func(addr *net.UDPAddr) (net.PacketConn, error) { conn, err := internet.DialSystem(ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), streamSettings.SocketSettings) if err != nil { - errors.LogDebug(context.Background(), "skip hop: failed to dial to dest") - conn.Close() - return nil, errors.New() + errors.LogInfoInner(context.Background(), err, "skip hop: failed to dial to dest") + return nil, errors.New("failed to dial to dest").Base(err) } - var udpConn net.PacketConn + var pktConn net.PacketConn switch c := conn.(type) { case *internet.PacketConnWrapper: - udpConn = c.PacketConn + pktConn = c.PacketConn case *net.UDPConn: - udpConn = c + pktConn = c default: - errors.LogDebug(context.Background(), "skip hop: udphop requires being at the outermost level ", reflect.TypeOf(c)) + errors.LogInfo(context.Background(), "skip hop: invalid conn ", reflect.TypeOf(c)) conn.Close() - return nil, errors.New() + return nil, errors.New("invalid conn ", reflect.TypeOf(c)) } - return udpConn, nil + return pktConn, nil } - var index int - if len(quicParams.UdpHop.Ports) > 0 { - index = rand.Intn(len(quicParams.UdpHop.Ports)) - dest.Port = net.Port(quicParams.UdpHop.Ports[index]) - } - - conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + var pktConn net.PacketConn + var udpAddr *net.UDPAddr + var err error + udpAddr, err = net.ResolveUDPAddr("udp", dest.NetAddr()) if err != nil { return nil, err } - - var udpConn net.PacketConn - var udpAddr *net.UDPAddr - - switch c := conn.(type) { - case *internet.PacketConnWrapper: - udpConn = c.PacketConn - udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String()) - if err != nil { - conn.Close() - return nil, err - } - case *net.UDPConn: - udpConn = c - udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) - if err != nil { - conn.Close() - return nil, err - } - default: - udpConn = &internet.FakePacketConn{Conn: c} - udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) - if err != nil { - conn.Close() - return nil, err - } - - if len(quicParams.UdpHop.Ports) > 0 { - conn.Close() - return nil, errors.New("udphop requires being at the outermost level ", reflect.TypeOf(c)) - } - } - if len(quicParams.UdpHop.Ports) > 0 { - addr := &udphop.UDPHopAddr{ - IP: udpAddr.IP, - Ports: quicParams.UdpHop.Ports, - } - udpConn, err = udphop.NewUDPHopPacketConn(addr, index, quicParams.UdpHop.IntervalMin, quicParams.UdpHop.IntervalMax, udphopDialer, udpConn) + pktConn, err = udphop.NewUDPHopPacketConn(udphop.ToAddrs(udpAddr.IP, quicParams.UdpHop.Ports), time.Duration(quicParams.UdpHop.IntervalMin)*time.Second, time.Duration(quicParams.UdpHop.IntervalMax)*time.Second, udpHopDialer) if err != nil { - conn.Close() - return nil, errors.New("udphop err").Base(err) + return nil, err + } + } else { + conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + if err != nil { + return nil, err + } + switch c := conn.(type) { + case *internet.PacketConnWrapper: + pktConn = c.PacketConn + case *net.UDPConn: + pktConn = c + case *cnc.Connection: + pktConn = &internet.FakePacketConn{Conn: c} + default: + panic(reflect.TypeOf(c)) } } if streamSettings.UdpmaskManager != nil { - udpConn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(udpConn) + newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn) if err != nil { - conn.Close() + pktConn.Close() return nil, errors.New("mask err").Base(err) } + pktConn = newConn } - quicConn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) + conn, err := quic.DialEarly(ctx, pktConn, udpAddr, tlsCfg, cfg) if err != nil { return nil, err } switch quicParams.Congestion { - case "force-brutal": - errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion brutal bytes per second ", quicParams.BrutalUp) - congestion.UseBrutal(quicConn, quicParams.BrutalUp) case "reno": - errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion reno") + case "", "bbr": + congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) + case "force-brutal": + congestion.UseBrutal(conn, quicParams.BrutalUp) default: - errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) - congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) + panic(quicParams.Congestion) } - return quicConn, nil + return conn, nil }, } } else if httpVersion == "2" { diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 8b281457..10b6005e 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "runtime" "slices" "strconv" "strings" @@ -440,7 +441,7 @@ type Listener struct { server http.Server h3server *http3.Server listener net.Listener - h3listener *quic.EarlyListener + h3listener Qface config *Config addConn internet.ConnHandler isH3 bool @@ -487,12 +488,12 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } if streamSettings.UdpmaskManager != nil { - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn) + newConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn) if err != nil { Conn.Close() return nil, errors.New("mask err").Base(err) } - Conn = pktConn + Conn = newConn } quicParams := streamSettings.QuicParams @@ -510,13 +511,17 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIncomingStreams: quicParams.MaxIncomingStreams, - DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, + DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"), } l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, quicConfig) if err != nil { return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) } + l.h3listener = &QListener{ + Qface: l.h3listener, + quicParams: quicParams, + } errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) handler.localAddr = l.h3listener.Addr() @@ -525,30 +530,8 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet Handler: handler, } go func() { - for { - conn, err := l.h3listener.Accept(context.Background()) - if err != nil { - errors.LogInfoInner(ctx, err, "XHTTP/3 listener closed") - return - } - - switch quicParams.Congestion { - case "force-brutal": - errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion brutal bytes per second ", quicParams.BrutalUp) - congestion.UseBrutal(conn, quicParams.BrutalUp) - case "reno": - errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion reno") - default: - errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) - congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) - } - - go func() { - if err := l.h3server.ServeQUICConn(conn); err != nil { - errors.LogDebugInner(ctx, err, "XHTTP/3 connection ended") - } - _ = conn.CloseWithError(0, "") - }() + if err := l.h3server.ServeListener(l.h3listener); err != nil { + errors.LogErrorInner(ctx, err, "failed to serve HTTP/3 for XHTTP/3") } }() } else { // tcp @@ -614,10 +597,8 @@ func (ln *Listener) Addr() net.Addr { func (ln *Listener) Close() error { if ln.h3server != nil { if err := ln.h3server.Close(); err != nil { - _ = ln.h3listener.Close() return err } - return ln.h3listener.Close() } else if ln.listener != nil { return ln.listener.Close() } @@ -633,3 +614,33 @@ func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config { func init() { common.Must(internet.RegisterTransportListener(protocolName, ListenXH)) } + +type Qface interface { + Accept(ctx context.Context) (*quic.Conn, error) + Addr() net.Addr + Close() error +} + +var _ Qface = (*quic.EarlyListener)(nil) + +type QListener struct { + Qface + quicParams *internet.QuicParams +} + +func (l *QListener) Accept(ctx context.Context) (*quic.Conn, error) { + conn, err := l.Qface.Accept(ctx) + if err != nil { + return nil, err + } + switch l.quicParams.Congestion { + case "reno": + case "", "bbr": + congestion.UseBBR(conn, bbr.Profile(l.quicParams.BbrProfile)) + case "force-brutal": + congestion.UseBrutal(conn, l.quicParams.BrutalUp) + default: + panic(l.quicParams.Congestion) + } + return conn, nil +} diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index 06f4f474..79649ed7 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -25,48 +25,30 @@ func init() { } if streamSettings != nil && streamSettings.UdpmaskManager != nil { + var pktConn net.PacketConn + var udpAddr = conn.RemoteAddr().(*net.UDPAddr) switch c := conn.(type) { case *internet.PacketConnWrapper: - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - c.PacketConn = pktConn - errors.LogInfo(ctx, "finalmask udp dialer: wrapped existing PacketConnWrapper with ", reflect.TypeOf(pktConn)) + pktConn = c.PacketConn case *net.UDPConn: - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - conn = &internet.PacketConnWrapper{ - PacketConn: pktConn, - Dest: c.RemoteAddr().(*net.UDPAddr), - } - errors.LogInfo(ctx, "finalmask udp dialer: wrapped UDPConn with ", reflect.TypeOf(pktConn)) + pktConn = c case *cnc.Connection: - fakeConn := &internet.FakePacketConn{Conn: c} - pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) - if err != nil { - conn.Close() - return nil, errors.New("mask err").Base(err) - } - conn = &internet.PacketConnWrapper{ - PacketConn: pktConn, - Dest: &net.UDPAddr{ - IP: []byte{0, 0, 0, 0}, - Port: 0, - }, - } - errors.LogInfo(ctx, "finalmask udp dialer: wrapped cnc.Connection with ", reflect.TypeOf(pktConn)) + pktConn = &internet.FakePacketConn{Conn: c} default: - conn.Close() - return nil, errors.New("unknown conn ", reflect.TypeOf(c)) + panic(reflect.TypeOf(c)) + } + newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn) + if err != nil { + pktConn.Close() + return nil, errors.New("mask err").Base(err) + } + pktConn = newConn + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: udpAddr, } } - // TODO: handle dialer options return conn, nil })) }