LjhAUMEM
2026-05-02 20:27:27 +08:00
committed by GitHub
parent 52cf9ef5d6
commit 1d62941bd2
20 changed files with 845 additions and 1171 deletions

View File

@@ -18,9 +18,9 @@ import (
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/hysteria/account" hysteria_proxy "github.com/xtls/xray-core/proxy/hysteria"
hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tcp" "github.com/xtls/xray-core/transport/internet/tcp"
"github.com/xtls/xray-core/transport/internet/udp" "github.com/xtls/xray-core/transport/internet/udp"
@@ -134,10 +134,8 @@ func (w *tcpWorker) Proxy() proxy.Inbound {
func (w *tcpWorker) Start() error { func (w *tcpWorker) Start() error {
ctx := context.Background() ctx := context.Background()
type HysteriaInboundValidator interface{ HysteriaInboundValidator() *account.Validator } if v, ok := w.proxy.(*hysteria_proxy.Server); ok {
if v, ok := w.proxy.(HysteriaInboundValidator); ok { ctx = hysteria.ContextWithValidator(ctx, v.HysteriaInboundValidator())
ctx = hyCtx.ContextWithRequireDatagram(ctx, true)
ctx = hyCtx.ContextWithValidator(ctx, v.HysteriaInboundValidator())
} }
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {

2
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/xtls/xray-core
go 1.26 go 1.26
require ( require (
github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716
github.com/cloudflare/circl v1.6.3 github.com/cloudflare/circl v1.6.3
github.com/ghodss/yaml v1.0.1-0.20220118164431-d8423dcdf344 github.com/ghodss/yaml v1.0.1-0.20220118164431-d8423dcdf344
github.com/golang/mock v1.7.0-rc.1 github.com/golang/mock v1.7.0-rc.1

4
go.sum
View File

@@ -1,7 +1,7 @@
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6 h1:cbF95uMsQwCwAzH2i8+2lNO2TReoELLuqeeMfyBjFbY= github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716 h1:J1O+xpLuJWkdYbw5JPGwBqIHs2J8tiEP7Py9lPqkN2I=
github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA= github.com/apernet/quic-go v0.59.1-0.20260425001925-6c6cc9bcb716/go.mod h1:Npbg8qBtAZlsAB3FWmqwlVh5jtVG6a4DlYsOylUpvzA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=

View File

@@ -17,7 +17,6 @@ import (
"github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria" "github.com/xtls/xray-core/transport/internet/hysteria"
@@ -56,7 +55,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
ob.CanSpliceCopy = 3 ob.CanSpliceCopy = 3
target := ob.Target target := ob.Target
conn, err := dialer.Dial(hyCtx.ContextWithRequireDatagram(ctx, target.Network == net.Network_UDP), c.server.Destination) conn, err := dialer.Dial(hysteria.ContextWithDatagram(ctx, target.Network == net.Network_UDP), c.server.Destination)
if err != nil { if err != nil {
return errors.New("failed to find an available destination").AtWarning().Base(err) return errors.New("failed to find an available destination").AtWarning().Base(err)
} }
@@ -118,7 +117,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if target.Network == net.Network_UDP { if target.Network == net.Network_UDP {
iConn := stat.TryUnwrapStatsConn(conn) iConn := stat.TryUnwrapStatsConn(conn)
_, ok := iConn.(*hysteria.InterUdpConn) _, ok := iConn.(*hysteria.InterConn)
if !ok { if !ok {
return errors.New("udp requires hysteria udp transport") return errors.New("udp requires hysteria udp transport")
} }
@@ -127,8 +126,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
writer := &UDPWriter{ writer := &UDPWriter{
Writer: conn, writer: conn,
buf: make([]byte, MaxUDPSize),
addr: target.NetAddr(), addr: target.NetAddr(),
} }
@@ -143,8 +141,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
reader := &UDPReader{ reader := &UDPReader{
Reader: conn, reader: conn,
buf: make([]byte, MaxUDPSize),
df: &Defragger{}, df: &Defragger{},
} }
@@ -173,28 +170,22 @@ func init() {
} }
type UDPWriter struct { type UDPWriter struct {
Writer io.Writer writer io.Writer
buf []byte
addr string addr string
buf [buf.Size]byte
} }
func (w *UDPWriter) sendMsg(msg *UDPMessage) error { func (w *UDPWriter) SendMessage(msg *UDPMessage) error {
msgN := msg.Serialize(w.buf) msgN := msg.Serialize(w.buf[:])
if msgN < 0 { if msgN < 0 {
return nil return nil
} }
_, err := w.Writer.Write(w.buf[:msgN]) _, err := w.writer.Write(w.buf[:msgN])
return err return err
} }
func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for { for i, b := range mb {
mb2, b := buf.SplitFirst(mb)
mb = mb2
if b == nil {
break
}
addr := w.addr addr := w.addr
if b.UDP != nil { if b.UDP != nil {
addr = b.UDP.NetAddr() addr = b.UDP.NetAddr()
@@ -209,22 +200,20 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
Data: b.Bytes(), Data: b.Bytes(),
} }
err := w.sendMsg(msg) err := w.SendMessage(msg)
var errTooLarge *quic.DatagramTooLargeError var errTooLarge *quic.DatagramTooLargeError
if go_errors.As(err, &errTooLarge) { if go_errors.As(err, &errTooLarge) {
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize)) fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize))
for _, fMsg := range fMsgs { for _, fMsg := range fMsgs {
err := w.sendMsg(&fMsg) err := w.SendMessage(&fMsg)
if err != nil { if err != nil {
b.Release() buf.ReleaseMulti(mb[i:])
buf.ReleaseMulti(mb)
return err return err
} }
} }
} else if err != nil { } else if err != nil {
b.Release() buf.ReleaseMulti(mb[i:])
buf.ReleaseMulti(mb)
return err return err
} }
@@ -235,34 +224,21 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
} }
type UDPReader struct { type UDPReader struct {
Reader io.Reader reader io.Reader
buf []byte
df *Defragger df *Defragger
firstMsg *UDPMessage firstBuf *buf.Buffer
firstDest *net.Destination
} }
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (r *UDPReader) ReadFrom(p []byte) (n int, addr *net.Destination, err error) {
if r.firstMsg != nil {
buffer := buf.New()
_, err := buffer.Write(r.firstMsg.Data)
if err != nil {
return nil, err
}
buffer.UDP = r.firstDest
r.firstMsg = nil
r.firstDest = nil
return buf.MultiBuffer{buffer}, nil
}
for { for {
n, err := r.Reader.Read(r.buf) var buf [hysteria.MaxDatagramFrameSize]byte
n, err := r.reader.Read(buf[:])
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
msg, err := ParseUDPMessage(r.buf[:n]) msg, err := ParseUDPMessage(buf[:n])
if err != nil { if err != nil {
continue continue
} }
@@ -274,17 +250,31 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
dest, err := net.ParseDestination("udp:" + dfMsg.Addr) dest, err := net.ParseDestination("udp:" + dfMsg.Addr)
if err != nil { if err != nil {
errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err)
continue continue
} }
buffer := buf.New() if len(p) < len(dfMsg.Data) {
if _, err := buffer.Write(dfMsg.Data); err != nil { continue
}
return copy(p, dfMsg.Data), &dest, nil
}
}
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.firstBuf != nil {
mb := buf.MultiBuffer{r.firstBuf}
r.firstBuf = nil
return mb, nil
}
b := buf.New()
b.Resize(0, buf.Size)
n, addr, err := r.ReadFrom(b.Bytes())
if err != nil {
b.Release()
return nil, err return nil, err
} }
b.Resize(0, int32(n))
buffer.UDP = &dest b.UDP = addr
return buf.MultiBuffer{b}, nil
return buf.MultiBuffer{buffer}, nil
}
} }

View File

@@ -1,10 +1 @@
package hysteria package hysteria
import (
"github.com/xtls/xray-core/transport/internet/hysteria/padding"
)
var (
tcpRequestPadding = padding.Padding{Min: 64, Max: 512}
tcpResponsePadding = padding.Padding{Min: 128, Max: 1024}
)

View File

@@ -1,35 +0,0 @@
package ctx
import (
"context"
"github.com/xtls/xray-core/proxy/hysteria/account"
)
type key int
const (
requireDatagram key = iota
validator
)
func ContextWithRequireDatagram(ctx context.Context, udp bool) context.Context {
if !udp {
return ctx
}
return context.WithValue(ctx, requireDatagram, struct{}{})
}
func RequireDatagramFromContext(ctx context.Context) bool {
_, ok := ctx.Value(requireDatagram).(struct{})
return ok
}
func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context {
return context.WithValue(ctx, validator, v)
}
func ValidatorFromContext(ctx context.Context) *account.Validator {
v, _ := ctx.Value(validator).(*account.Validator)
return v
}

View File

@@ -1,73 +0,0 @@
package hysteria
func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage {
if m.Size() <= maxSize {
return []UDPMessage{*m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
off := 0
fragID := uint8(0)
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
frags := make([]UDPMessage, fragCount)
for off < len(fullPayload) {
payloadSize := len(fullPayload) - off
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := *m
frag.FragID = fragID
frag.FragCount = fragCount
frag.Data = fullPayload[off : off+payloadSize]
frags[fragID] = frag
off += payloadSize
fragID++
}
return frags
}
// Defragger handles the defragmentation of UDP messages.
// The current implementation can only handle one packet ID at a time.
// If another packet arrives before a packet has received all fragments
// in their entirety, any previous state is discarded.
type Defragger struct {
pktID uint16
frags []*UDPMessage
count uint8
size int // data size
}
func (d *Defragger) Feed(m *UDPMessage) *UDPMessage {
if m.FragCount <= 1 {
return m
}
if m.FragID >= m.FragCount {
// wtf is this?
return nil
}
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
// new message, clear previous state
d.pktID = m.PacketID
d.frags = make([]*UDPMessage, m.FragCount)
d.frags[m.FragID] = m
d.count = 1
d.size = len(m.Data)
} else if d.frags[m.FragID] == nil {
d.frags[m.FragID] = m
d.count++
d.size += len(m.Data)
if int(d.count) == len(d.frags) {
// all fragments received, assemble
data := make([]byte, d.size)
off := 0
for _, frag := range d.frags {
off += copy(data[off:], frag.Data)
}
m.Data = data
m.FragID = 0
m.FragCount = 1
return m
}
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/apernet/quic-go/quicvarint" "github.com/apernet/quic-go/quicvarint"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet/hysteria"
) )
const ( const (
@@ -17,8 +18,6 @@ const (
MaxMessageLength = 2048 MaxMessageLength = 2048
MaxPaddingLength = 4096 MaxPaddingLength = 4096
MaxUDPSize = 4096
maxVarInt1 = 63 maxVarInt1 = 63
maxVarInt2 = 16383 maxVarInt2 = 16383
maxVarInt4 = 1073741823 maxVarInt4 = 1073741823
@@ -62,7 +61,7 @@ func ReadTCPRequest(r io.Reader) (string, error) {
} }
func WriteTCPRequest(w io.Writer, addr string) error { func WriteTCPRequest(w io.Writer, addr string) error {
padding := tcpRequestPadding.String() padding := hysteria.TcpRequestPadding.String()
paddingLen := len(padding) paddingLen := len(padding)
addrLen := len(addr) addrLen := len(addr)
sz := int(quicvarint.Len(uint64(addrLen))) + addrLen + sz := int(quicvarint.Len(uint64(addrLen))) + addrLen +
@@ -122,7 +121,7 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) {
} }
func WriteTCPResponse(w io.Writer, ok bool, msg string) error { func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
padding := tcpResponsePadding.String() padding := hysteria.TcpResponsePadding.String()
paddingLen := len(padding) paddingLen := len(padding)
msgLen := len(msg) msgLen := len(msg)
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
@@ -247,3 +246,75 @@ func varintPut(b []byte, i uint64) int {
} }
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
} }
func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage {
if m.Size() <= maxSize {
return []UDPMessage{*m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
off := 0
fragID := uint8(0)
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
frags := make([]UDPMessage, fragCount)
for off < len(fullPayload) {
payloadSize := len(fullPayload) - off
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := *m
frag.FragID = fragID
frag.FragCount = fragCount
frag.Data = fullPayload[off : off+payloadSize]
frags[fragID] = frag
off += payloadSize
fragID++
}
return frags
}
// Defragger handles the defragmentation of UDP messages.
// The current implementation can only handle one packet ID at a time.
// If another packet arrives before a packet has received all fragments
// in their entirety, any previous state is discarded.
type Defragger struct {
pktID uint16
frags []*UDPMessage
count uint8
size int // data size
}
func (d *Defragger) Feed(m *UDPMessage) *UDPMessage {
if m.FragCount <= 1 {
return m
}
if m.FragID >= m.FragCount {
// wtf is this?
return nil
}
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
// new message, clear previous state
d.pktID = m.PacketID
d.frags = make([]*UDPMessage, m.FragCount)
d.frags[m.FragID] = m
d.count = 1
d.size = len(m.Data)
} else if d.frags[m.FragID] == nil {
d.frags[m.FragID] = m
d.count++
d.size += len(m.Data)
if int(d.count) == len(d.frags) {
// all fragments received, assemble
data := make([]byte, d.size)
off := 0
for _, frag := range d.frags {
off += copy(data[off:], frag.Data)
}
m.Data = data
m.FragID = 0
m.FragCount = 1
return m
}
}
return nil
}

View File

@@ -2,7 +2,6 @@ package hysteria
import ( import (
"context" "context"
"io"
"time" "time"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
@@ -91,54 +90,30 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
inbound.User = v.User() inbound.User = v.User()
} }
if _, ok := iConn.(*hysteria.InterUdpConn); ok { if _, ok := iConn.(*hysteria.InterConn); ok {
r := io.Reader(conn) reader := &UDPReader{
b := make([]byte, MaxUDPSize) reader: conn,
df := &Defragger{} df: &Defragger{},
var firstMsg *UDPMessage }
var firstDest net.Destination
for { b := buf.New()
n, err := r.Read(b) b.Resize(0, buf.Size)
n, addr, err := reader.ReadFrom(b.Bytes())
if err != nil { if err != nil {
b.Release()
return err return err
} }
b.Resize(0, int32(n))
b.UDP = addr
msg, err := ParseUDPMessage(b[:n]) reader.firstBuf = b
if err != nil {
continue
}
dfMsg := df.Feed(msg)
if dfMsg == nil {
continue
}
firstMsg = dfMsg
firstDest, err = net.ParseDestination("udp:" + firstMsg.Addr)
if err != nil {
errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err)
continue
}
break
}
reader := &UDPReader{
Reader: r,
buf: b,
df: df,
firstMsg: firstMsg,
firstDest: &firstDest,
}
writer := &UDPWriter{ writer := &UDPWriter{
Writer: conn, writer: conn,
buf: make([]byte, MaxUDPSize), addr: addr.NetAddr(),
addr: firstMsg.Addr,
} }
return dispatcher.DispatchLink(ctx, firstDest, &transport.Link{ return dispatcher.DispatchLink(ctx, *addr, &transport.Link{
Reader: reader, Reader: reader,
Writer: writer, Writer: writer,
}) })

View File

@@ -1,45 +1,82 @@
package hysteria package hysteria
import ( import (
"context"
"math/rand"
"time" "time"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/proxy/hysteria/account"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria/padding"
) )
const ( const (
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
MaxDatagramFrameSize = 1200
URLHost = "hysteria" URLHost = "hysteria"
URLPath = "/auth" URLPath = "/auth"
RequestHeaderAuth = "Hysteria-Auth" RequestHeaderAuth = "Hysteria-Auth"
ResponseHeaderUDPEnabled = "Hysteria-UDP" ResponseHeaderUDPEnabled = "Hysteria-UDP"
CommonHeaderCCRX = "Hysteria-CC-RX" CommonHeaderCCRX = "Hysteria-CC-RX"
CommonHeaderPadding = "Hysteria-Padding" CommonHeaderPadding = "Hysteria-Padding"
StatusAuthOK = 233 StatusAuthOK = 233
udpMessageChanSize = 1024
FrameTypeTCPRequest = 0x401 FrameTypeTCPRequest = 0x401
MaxDatagramFrameSize = 1200
udpMessageChanSize = 1024
idleCleanupInterval = 1 * time.Second idleCleanupInterval = 1 * time.Second
) )
var ( const (
authRequestPadding = padding.Padding{Min: 256, Max: 2048} paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
authResponsePadding = padding.Padding{Min: 256, Max: 2048}
) )
type Status int type padding struct {
Min int
Max int
}
func (p padding) String() string {
n := p.Min + rand.Intn(p.Max-p.Min)
bs := make([]byte, n)
for i := range bs {
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
}
return string(bs)
}
var (
AuthRequestPadding = padding{Min: 256, Max: 2048}
AuthResponsePadding = padding{Min: 256, Max: 2048}
TcpRequestPadding = padding{Min: 64, Max: 512}
TcpResponsePadding = padding{Min: 128, Max: 1024}
)
type datagramKey struct{}
func ContextWithDatagram(ctx context.Context, v bool) context.Context {
return context.WithValue(ctx, datagramKey{}, v)
}
func DatagramFromContext(ctx context.Context) bool {
v, _ := ctx.Value(datagramKey{}).(bool)
return v
}
type validatorKey struct{}
func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context {
return context.WithValue(ctx, validatorKey{}, v)
}
func ValidatorFromContext(ctx context.Context) *account.Validator {
v, _ := ctx.Value(validatorKey{}).(*account.Validator)
return v
}
type status int
const ( const (
StatusUnknown Status = iota StatusNull status = iota
StatusActive StatusActive
StatusInactive StatusInactive
) )

View File

@@ -1,6 +1,7 @@
package hysteria package hysteria
import ( import (
"context"
"encoding/binary" "encoding/binary"
"io" "io"
"sync" "sync"
@@ -8,8 +9,10 @@ import (
"github.com/apernet/quic-go" "github.com/apernet/quic-go"
"github.com/apernet/quic-go/quicvarint" "github.com/apernet/quic-go/quicvarint"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/transport/internet"
) )
type interConn struct { type interConn struct {
@@ -18,144 +21,278 @@ type interConn struct {
remote net.Addr remote net.Addr
client bool client bool
mutex sync.Mutex
user *protocol.MemoryUser user *protocol.MemoryUser
} }
func (i *interConn) User() *protocol.MemoryUser { func (c *interConn) User() *protocol.MemoryUser {
return i.user return c.user
} }
func (i *interConn) Read(b []byte) (int, error) { func (c *interConn) Read(b []byte) (int, error) {
return i.stream.Read(b) return c.stream.Read(b)
} }
func (i *interConn) Write(b []byte) (int, error) { func (c *interConn) Write(b []byte) (int, error) {
if i.client { if c.client {
i.mutex.Lock() c.client = false
defer i.mutex.Unlock() if _, err := c.stream.Write(append(quicvarint.Append(nil, FrameTypeTCPRequest), b...)); err != nil {
if i.client {
buf := make([]byte, 0, quicvarint.Len(FrameTypeTCPRequest)+len(b))
buf = quicvarint.Append(buf, FrameTypeTCPRequest)
buf = append(buf, b...)
_, err := i.stream.Write(buf)
if err != nil {
return 0, err return 0, err
} }
i.client = false
return len(b), nil return len(b), nil
} }
return c.stream.Write(b)
} }
return i.stream.Write(b) func (c *interConn) Close() error {
c.stream.CancelRead(0)
return c.stream.Close()
} }
func (i *interConn) Close() error { func (c *interConn) LocalAddr() net.Addr {
i.stream.CancelRead(0) return c.local
return i.stream.Close()
} }
func (i *interConn) LocalAddr() net.Addr { func (c *interConn) RemoteAddr() net.Addr {
return i.local return c.remote
} }
func (i *interConn) RemoteAddr() net.Addr { func (c *interConn) SetDeadline(t time.Time) error {
return i.remote return c.stream.SetDeadline(t)
} }
func (i *interConn) SetDeadline(t time.Time) error { func (c *interConn) SetReadDeadline(t time.Time) error {
return i.stream.SetDeadline(t) return c.stream.SetReadDeadline(t)
} }
func (i *interConn) SetReadDeadline(t time.Time) error { func (c *interConn) SetWriteDeadline(t time.Time) error {
return i.stream.SetReadDeadline(t) return c.stream.SetWriteDeadline(t)
} }
func (i *interConn) SetWriteDeadline(t time.Time) error { type InterConn struct {
return i.stream.SetWriteDeadline(t)
}
type InterUdpConn struct {
conn *quic.Conn
local net.Addr local net.Addr
remote net.Addr remote net.Addr
id uint32 id uint32
ch chan []byte ch chan []byte
time time.Time
closed bool
closeFunc func()
last time.Time
mutex sync.Mutex mutex sync.Mutex
closed bool
write func(p []byte) error
close func()
user *protocol.MemoryUser user *protocol.MemoryUser
} }
func (i *InterUdpConn) User() *protocol.MemoryUser { func (i *InterConn) User() *protocol.MemoryUser {
return i.user return i.user
} }
func (i *InterUdpConn) SetLast() { func (c *InterConn) Time() time.Time {
i.mutex.Lock() c.mutex.Lock()
defer i.mutex.Unlock() v := c.time
c.mutex.Unlock()
i.last = time.Now() return v
} }
func (i *InterUdpConn) GetLast() time.Time { func (c *InterConn) Update() {
i.mutex.Lock() c.mutex.Lock()
defer i.mutex.Unlock() c.time = time.Now()
c.mutex.Unlock()
return i.last
} }
func (i *InterUdpConn) Read(p []byte) (int, error) { func (c *InterConn) Read(p []byte) (int, error) {
b, ok := <-i.ch b, ok := <-c.ch
if !ok { if !ok {
return 0, io.EOF return 0, io.EOF
} }
n := copy(p, b) if len(p) < len(b) {
if n != len(b) {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
c.Update()
i.SetLast() return copy(p, b), nil
return n, nil
} }
func (i *InterUdpConn) Write(p []byte) (int, error) { func (c *InterConn) Write(p []byte) (int, error) {
i.SetLast() if c.closed {
return 0, io.ErrClosedPipe
binary.BigEndian.PutUint32(p, i.id) }
if err := i.conn.SendDatagram(p); err != nil { binary.BigEndian.PutUint32(p, c.id)
if err := c.write(p); err != nil {
return 0, err return 0, err
} }
c.Update()
return len(p), nil return len(p), nil
} }
func (i *InterUdpConn) Close() error { func (c *InterConn) Close() error {
i.closeFunc() c.close()
return nil return nil
} }
func (i *InterUdpConn) LocalAddr() net.Addr { func (c *InterConn) LocalAddr() net.Addr {
return i.local return c.local
} }
func (i *InterUdpConn) RemoteAddr() net.Addr { func (c *InterConn) RemoteAddr() net.Addr {
return i.remote return c.remote
} }
func (i *InterUdpConn) SetDeadline(t time.Time) error { func (c *InterConn) SetDeadline(t time.Time) error {
return nil return nil
} }
func (i *InterUdpConn) SetReadDeadline(t time.Time) error { func (c *InterConn) SetReadDeadline(t time.Time) error {
return nil return nil
} }
func (i *InterUdpConn) SetWriteDeadline(t time.Time) error { func (c *InterConn) SetWriteDeadline(t time.Time) error {
return nil return nil
} }
type udpSessionManager struct {
sync.RWMutex
conn *quic.Conn
m map[uint32]*InterConn
next uint32
closed bool
addConn internet.ConnHandler
udpIdleTimeout time.Duration
user *protocol.MemoryUser
}
func (m *udpSessionManager) close(udpConn *InterConn) {
if !udpConn.closed {
udpConn.closed = true
close(udpConn.ch)
delete(m.m, udpConn.id)
}
}
func (m *udpSessionManager) clean() {
ticker := time.NewTicker(idleCleanupInterval)
defer ticker.Stop()
for range ticker.C {
if m.closed {
return
}
m.RLock()
now := time.Now()
timeoutConn := make([]*InterConn, 0, len(m.m))
for _, udpConn := range m.m {
if now.Sub(udpConn.Time()) > m.udpIdleTimeout {
timeoutConn = append(timeoutConn, udpConn)
}
}
m.RUnlock()
for _, udpConn := range timeoutConn {
m.Lock()
m.close(udpConn)
m.Unlock()
}
}
}
func (m *udpSessionManager) run() {
for {
d, err := m.conn.ReceiveDatagram(context.Background())
if err != nil {
break
}
if len(d) < 4 {
continue
}
id := binary.BigEndian.Uint32(d[:4])
m.feed(id, d)
}
m.Lock()
defer m.Unlock()
m.closed = true
for _, udpConn := range m.m {
m.close(udpConn)
}
}
func (m *udpSessionManager) udp() (*InterConn, error) {
m.Lock()
defer m.Unlock()
if m.closed {
return nil, errors.New("closed")
}
udpConn := &InterConn{
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
id: m.next,
ch: make(chan []byte, udpMessageChanSize),
}
udpConn.write = m.conn.SendDatagram
udpConn.close = func() {
m.Lock()
m.close(udpConn)
m.Unlock()
}
m.m[m.next] = udpConn
m.next++
return udpConn, nil
}
func (m *udpSessionManager) feed(id uint32, d []byte) {
m.RLock()
udpConn, ok := m.m[id]
if ok {
select {
case udpConn.ch <- d:
default:
}
m.RUnlock()
return
}
m.RUnlock()
if m.addConn == nil {
return
}
m.Lock()
defer m.Unlock()
udpConn, ok = m.m[id]
if !ok {
udpConn = &InterConn{
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
id: id,
ch: make(chan []byte, udpMessageChanSize),
time: time.Now(),
}
udpConn.write = m.conn.SendDatagram
udpConn.close = func() {
m.Lock()
m.close(udpConn)
m.Unlock()
}
udpConn.user = m.user
m.m[id] = udpConn
m.addConn(udpConn)
}
select {
case udpConn.ch <- d:
default:
}
}

View File

@@ -3,11 +3,10 @@ package hysteria
import ( import (
"context" "context"
go_tls "crypto/tls" go_tls "crypto/tls"
"encoding/binary"
"math/rand"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"runtime"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -18,8 +17,6 @@ import (
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/net/cnc"
"github.com/xtls/xray-core/common/task"
hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/finalmask"
"github.com/xtls/xray-core/transport/internet/hysteria/congestion" "github.com/xtls/xray-core/transport/internet/hysteria/congestion"
@@ -29,107 +26,25 @@ import (
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
) )
type udpSessionManagerClient struct {
conn *quic.Conn
m map[uint32]*InterUdpConn
next uint32
closed bool
mutex sync.RWMutex
}
func (m *udpSessionManagerClient) close(udpConn *InterUdpConn) {
if !udpConn.closed {
udpConn.closed = true
close(udpConn.ch)
delete(m.m, udpConn.id)
}
}
func (m *udpSessionManagerClient) run() {
for {
d, err := m.conn.ReceiveDatagram(context.Background())
if err != nil {
break
}
if len(d) < 4 {
continue
}
id := binary.BigEndian.Uint32(d[:4])
m.feed(id, d)
}
m.mutex.Lock()
defer m.mutex.Unlock()
m.closed = true
for _, udpConn := range m.m {
m.close(udpConn)
}
}
func (m *udpSessionManagerClient) udp() (*InterUdpConn, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closed {
return nil, errors.New("closed")
}
udpConn := &InterUdpConn{
conn: m.conn,
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
id: m.next,
ch: make(chan []byte, udpMessageChanSize),
}
udpConn.closeFunc = func() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.close(udpConn)
}
m.m[m.next] = udpConn
m.next++
return udpConn, nil
}
func (m *udpSessionManagerClient) feed(id uint32, d []byte) {
m.mutex.RLock()
defer m.mutex.RUnlock()
udpConn, ok := m.m[id]
if !ok {
return
}
select {
case udpConn.ch <- d:
default:
}
}
type client struct { type client struct {
ctx context.Context sync.Mutex
dest net.Destination dest net.Destination
pktConn net.PacketConn
conn *quic.Conn
config *Config config *Config
tlsConfig *go_tls.Config tlsConfig *go_tls.Config
socketConfig *internet.SocketConfig socketConfig *internet.SocketConfig
udpmaskManager *finalmask.UdpmaskManager udpmaskManager *finalmask.UdpmaskManager
quicParams *internet.QuicParams quicParams *internet.QuicParams
udpSM *udpSessionManagerClient conn *quic.Conn
mutex sync.Mutex tr *quic.Transport
pktConn net.PacketConn
udpSM *udpSessionManager
} }
func (c *client) status() Status { func (c *client) status() status {
if c.conn == nil { if c.conn == nil {
return StatusUnknown return StatusNull
} }
select { select {
case <-c.conn.Context().Done(): case <-c.conn.Context().Done():
@@ -140,10 +55,12 @@ func (c *client) status() Status {
} }
func (c *client) close() { func (c *client) close() {
_ = c.conn.CloseWithError(closeErrCodeOK, "") c.conn.CloseWithError(closeErrCodeOK, "")
_ = c.pktConn.Close() c.tr.Close()
c.pktConn = nil c.pktConn.Close()
c.conn = nil c.conn = nil
c.tr = nil
c.pktConn = nil
c.udpSM = nil c.udpSM = nil
} }
@@ -164,61 +81,6 @@ func (c *client) dial() error {
} }
} }
var index int
if len(quicParams.UdpHop.Ports) > 0 {
index = rand.Intn(len(quicParams.UdpHop.Ports))
c.dest.Port = net.Port(quicParams.UdpHop.Ports[index])
}
raw, err := internet.DialSystem(c.ctx, c.dest, c.socketConfig)
if err != nil {
return errors.New("failed to dial to dest").Base(err)
}
var pktConn net.PacketConn
var remote *net.UDPAddr
switch conn := raw.(type) {
case *internet.PacketConnWrapper:
pktConn = conn.PacketConn
remote = conn.RemoteAddr().(*net.UDPAddr)
case *net.UDPConn:
pktConn = conn
remote = conn.RemoteAddr().(*net.UDPAddr)
case *cnc.Connection:
fakeConn := &internet.FakePacketConn{Conn: conn}
pktConn = fakeConn
remote = fakeConn.RemoteAddr().(*net.UDPAddr)
if len(quicParams.UdpHop.Ports) > 0 {
raw.Close()
return errors.New("udphop requires being at the outermost level")
}
default:
raw.Close()
return errors.New("unknown conn ", reflect.TypeOf(conn))
}
if len(quicParams.UdpHop.Ports) > 0 {
addr := &udphop.UDPHopAddr{
IP: remote.IP,
Ports: quicParams.UdpHop.Ports,
}
pktConn, err = udphop.NewUDPHopPacketConn(addr, index, quicParams.UdpHop.IntervalMin, quicParams.UdpHop.IntervalMax, c.udphopDialer, pktConn)
if err != nil {
raw.Close()
return errors.New("udphop err").Base(err)
}
}
if c.udpmaskManager != nil {
pktConn, err = c.udpmaskManager.WrapPacketConnClient(pktConn)
if err != nil {
raw.Close()
return errors.New("mask err").Base(err)
}
}
quicConfig := &quic.Config{ quicConfig := &quic.Config{
InitialStreamReceiveWindow: quicParams.InitStreamReceiveWindow, InitialStreamReceiveWindow: quicParams.InitStreamReceiveWindow,
MaxStreamReceiveWindow: quicParams.MaxStreamReceiveWindow, MaxStreamReceiveWindow: quicParams.MaxStreamReceiveWindow,
@@ -226,9 +88,10 @@ func (c *client) dial() error {
MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow,
MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second,
KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second, KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second,
DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"),
EnableDatagrams: true, EnableDatagrams: true,
MaxDatagramFrameSize: MaxDatagramFrameSize, MaxDatagramFrameSize: MaxDatagramFrameSize,
OmitMaxDatagramFrameSize: time.Now().After(time.Date(2026, 9, 1, 0, 0, 0, 0, time.UTC)),
DisablePathManager: true, DisablePathManager: true,
} }
if quicParams.InitStreamReceiveWindow == 0 { if quicParams.InitStreamReceiveWindow == 0 {
@@ -250,16 +113,56 @@ func (c *client) dial() error {
// quicConfig.KeepAlivePeriod = 10 * time.Second // quicConfig.KeepAlivePeriod = 10 * time.Second
// } // }
var quicConn *quic.Conn var pktConn net.PacketConn
var udpAddr *net.UDPAddr
var err error
udpAddr, err = net.ResolveUDPAddr("udp", c.dest.NetAddr())
if err != nil {
return err
}
if len(quicParams.UdpHop.Ports) > 0 {
pktConn, err = udphop.NewUDPHopPacketConn(udphop.ToAddrs(udpAddr.IP, quicParams.UdpHop.Ports), time.Duration(quicParams.UdpHop.IntervalMin)*time.Second, time.Duration(quicParams.UdpHop.IntervalMax)*time.Second, c.udpHopDialer)
if err != nil {
return err
}
} else {
conn, err := internet.DialSystem(context.Background(), c.dest, c.socketConfig)
if err != nil {
return err
}
switch c := conn.(type) {
case *internet.PacketConnWrapper:
pktConn = c.PacketConn
case *net.UDPConn:
pktConn = c
case *cnc.Connection:
pktConn = &internet.FakePacketConn{Conn: c}
default:
panic(reflect.TypeOf(c))
}
}
if c.udpmaskManager != nil {
newConn, err := c.udpmaskManager.WrapPacketConnClient(pktConn)
if err != nil {
pktConn.Close()
return errors.New("mask err").Base(err)
}
pktConn = newConn
}
tr := &quic.Transport{Conn: pktConn}
var conn *quic.Conn
rt := &http3.Transport{ rt := &http3.Transport{
TLSClientConfig: c.tlsConfig, TLSClientConfig: c.tlsConfig,
QUICConfig: quicConfig, QUICConfig: quicConfig,
Dial: func(ctx context.Context, _ string, tlsCfg *go_tls.Config, cfg *quic.Config) (*quic.Conn, error) { Dial: func(ctx context.Context, _ string, tlsCfg *go_tls.Config, cfg *quic.Config) (*quic.Conn, error) {
qc, err := quic.DialEarly(ctx, pktConn, remote, tlsCfg, cfg) qc, err := tr.DialEarly(ctx, udpAddr, tlsCfg, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
quicConn = qc conn = qc
return qc, nil return qc, nil
}, },
} }
@@ -273,75 +176,61 @@ func (c *client) dial() error {
Header: http.Header{ Header: http.Header{
RequestHeaderAuth: []string{c.config.Auth}, RequestHeaderAuth: []string{c.config.Auth},
CommonHeaderCCRX: []string{strconv.FormatUint(quicParams.BrutalDown, 10)}, CommonHeaderCCRX: []string{strconv.FormatUint(quicParams.BrutalDown, 10)},
CommonHeaderPadding: []string{authRequestPadding.String()}, CommonHeaderPadding: []string{AuthRequestPadding.String()},
}, },
} }
resp, err := rt.RoundTrip(req) resp, err := rt.RoundTrip(req)
if err != nil { if err != nil {
if quicConn != nil { if conn != nil {
_ = quicConn.CloseWithError(closeErrCodeProtocolError, "") _ = conn.CloseWithError(closeErrCodeProtocolError, "")
} }
_ = tr.Close()
_ = pktConn.Close() _ = pktConn.Close()
return errors.New("RoundTrip err").Base(err) return err
} }
if resp.StatusCode != StatusAuthOK { if resp.StatusCode != StatusAuthOK {
_ = quicConn.CloseWithError(closeErrCodeProtocolError, "") _ = conn.CloseWithError(closeErrCodeProtocolError, "")
_ = tr.Close()
_ = pktConn.Close() _ = pktConn.Close()
return errors.New("auth failed") return errors.New("auth failed code ", resp.StatusCode)
} }
_ = resp.Body.Close() _ = resp.Body.Close()
serverUdp, _ := strconv.ParseBool(resp.Header.Get(ResponseHeaderUDPEnabled)) // udp, _ := strconv.ParseBool(resp.Header.Get(ResponseHeaderUDPEnabled))
serverAuto := resp.Header.Get(CommonHeaderCCRX) down, _ := strconv.ParseUint(resp.Header.Get(CommonHeaderCCRX), 10, 64)
serverDown, _ := strconv.ParseUint(serverAuto, 10, 64)
switch quicParams.Congestion { switch quicParams.Congestion {
case "reno": case "reno":
errors.LogDebug(c.ctx, "congestion reno")
case "bbr": case "bbr":
errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile) congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile)) case "", "brutal":
case "brutal", "": if quicParams.BrutalUp == 0 || down == 0 {
if serverAuto == "auto" || quicParams.BrutalUp == 0 || serverDown == 0 { congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
errors.LogDebug(c.ctx, "congestion bbr ", quicParams.BbrProfile)
congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile))
} else { } else {
errors.LogDebug(c.ctx, "congestion brutal bytes per second ", min(quicParams.BrutalUp, serverDown)) congestion.UseBrutal(conn, min(quicParams.BrutalUp, down))
congestion.UseBrutal(quicConn, min(quicParams.BrutalUp, serverDown))
} }
case "force-brutal": case "force-brutal":
errors.LogDebug(c.ctx, "congestion brutal bytes per second ", quicParams.BrutalUp) congestion.UseBrutal(conn, quicParams.BrutalUp)
congestion.UseBrutal(quicConn, quicParams.BrutalUp)
default: default:
errors.LogDebug(c.ctx, "congestion reno") panic(quicParams.Congestion)
} }
c.pktConn = pktConn c.pktConn = pktConn
c.conn = quicConn c.tr = tr
if serverUdp { c.conn = conn
c.udpSM = &udpSessionManagerClient{ c.udpSM = &udpSessionManager{
conn: quicConn, conn: conn,
m: make(map[uint32]*InterUdpConn), m: make(map[uint32]*InterConn),
next: 1, next: 1,
} }
go c.udpSM.run() go c.udpSM.run()
}
return nil return nil
} }
func (c *client) clean() {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.status() == StatusInactive {
c.close()
}
}
func (c *client) tcp() (stat.Connection, error) { func (c *client) tcp() (stat.Connection, error) {
c.mutex.Lock() c.Lock()
defer c.mutex.Unlock() defer c.Unlock()
err := c.dial() err := c.dial()
if err != nil { if err != nil {
@@ -363,59 +252,43 @@ func (c *client) tcp() (stat.Connection, error) {
} }
func (c *client) udp() (stat.Connection, error) { func (c *client) udp() (stat.Connection, error) {
c.mutex.Lock() c.Lock()
defer c.mutex.Unlock() defer c.Unlock()
err := c.dial() err := c.dial()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if c.udpSM == nil {
return nil, errors.New("server does not support udp")
}
return c.udpSM.udp() return c.udpSM.udp()
} }
func (c *client) setCtx(ctx context.Context) { func (c *client) clean() {
c.mutex.Lock() c.Lock()
defer c.mutex.Unlock() if c.status() == StatusInactive {
c.close()
c.ctx = ctx }
c.Unlock()
} }
func (c *client) udphopDialer(addr *net.UDPAddr) (net.PacketConn, error) { func (c *client) udpHopDialer(addr *net.UDPAddr) (net.PacketConn, error) {
c.mutex.Lock() conn, err := internet.DialSystem(context.Background(), net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), c.socketConfig)
defer c.mutex.Unlock()
if c.status() != StatusActive {
errors.LogDebug(context.Background(), "skip hop: disconnected QUIC")
return nil, errors.New()
}
raw, err := internet.DialSystem(c.ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), c.socketConfig)
if err != nil { if err != nil {
errors.LogDebug(context.Background(), "skip hop: failed to dial to dest") errors.LogInfoInner(context.Background(), err, "skip hop: failed to dial to dest")
raw.Close() return nil, errors.New("failed to dial to dest").Base(err)
return nil, errors.New()
} }
var pktConn net.PacketConn var pktConn net.PacketConn
switch conn := raw.(type) { switch c := conn.(type) {
case *internet.PacketConnWrapper: case *internet.PacketConnWrapper:
pktConn = conn.PacketConn pktConn = c.PacketConn
case *net.UDPConn: case *net.UDPConn:
pktConn = conn pktConn = c
case *cnc.Connection:
errors.LogDebug(context.Background(), "skip hop: udphop requires being at the outermost level")
raw.Close()
return nil, errors.New()
default: default:
errors.LogDebug(context.Background(), "skip hop: unknown conn ", reflect.TypeOf(conn)) errors.LogInfo(context.Background(), "skip hop: invalid conn ", reflect.TypeOf(c))
raw.Close() conn.Close()
return nil, errors.New() return nil, errors.New("invalid conn ", reflect.TypeOf(c))
} }
return pktConn, nil return pktConn, nil
@@ -427,17 +300,19 @@ type dialerConf struct {
} }
type clientManager struct { type clientManager struct {
sync.RWMutex
m map[dialerConf]*client m map[dialerConf]*client
mutex sync.Mutex
} }
func (m *clientManager) clean() { func (m *clientManager) clean() {
m.mutex.Lock() ticker := time.NewTicker(idleCleanupInterval)
defer m.mutex.Unlock() for range ticker.C {
m.RLock()
for _, c := range m.m { for _, c := range m.m {
c.clean() c.clean()
} }
m.RUnlock()
}
} }
var manager *clientManager var manager *clientManager
@@ -449,41 +324,38 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
return nil, errors.New("tls config is nil") return nil, errors.New("tls config is nil")
} }
requireDatagram := hyCtx.RequireDatagramFromContext(ctx) datagram := DatagramFromContext(ctx)
dest.Network = net.Network_UDP dest.Network = net.Network_UDP
config := streamSettings.ProtocolSettings.(*Config)
initmanager.Do(func() { initmanager.Do(func() {
manager = &clientManager{ manager = &clientManager{
m: make(map[dialerConf]*client), m: make(map[dialerConf]*client),
} }
(&task.Periodic{ go manager.clean()
Interval: 30 * time.Second,
Execute: func() error {
manager.clean()
return nil
},
}).Start()
}) })
manager.mutex.Lock() manager.RLock()
c, ok := manager.m[dialerConf{Destination: dest, MemoryStreamConfig: streamSettings}] c := manager.m[dialerConf{dest, streamSettings}]
if !ok { manager.RUnlock()
if c == nil {
manager.Lock()
c = manager.m[dialerConf{dest, streamSettings}]
if c == nil {
c = &client{ c = &client{
ctx: ctx,
dest: dest, dest: dest,
config: config, config: streamSettings.ProtocolSettings.(*Config),
tlsConfig: tlsConfig.GetTLSConfig(), tlsConfig: tlsConfig.GetTLSConfig(),
socketConfig: streamSettings.SocketSettings, socketConfig: streamSettings.SocketSettings,
udpmaskManager: streamSettings.UdpmaskManager, udpmaskManager: streamSettings.UdpmaskManager,
quicParams: streamSettings.QuicParams, quicParams: streamSettings.QuicParams,
} }
manager.m[dialerConf{Destination: dest, MemoryStreamConfig: streamSettings}] = c manager.m[dialerConf{dest, streamSettings}] = c
}
manager.Unlock()
} }
c.setCtx(ctx)
manager.mutex.Unlock()
if requireDatagram { if datagram {
return c.udp() return c.udp()
} }
return c.tcp() return c.tcp()

View File

@@ -3,10 +3,10 @@ package hysteria
import ( import (
"context" "context"
gotls "crypto/tls" gotls "crypto/tls"
"encoding/binary"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -20,158 +20,41 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/proxy/hysteria/account" "github.com/xtls/xray-core/proxy/hysteria/account"
hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria/congestion" "github.com/xtls/xray-core/transport/internet/hysteria/congestion"
"github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr"
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
) )
type udpSessionManagerServer struct {
conn *quic.Conn
m map[uint32]*InterUdpConn
addConn internet.ConnHandler
stopCh chan struct{}
udpIdleTimeout time.Duration
mutex sync.RWMutex
user *protocol.MemoryUser
}
func (m *udpSessionManagerServer) close(udpConn *InterUdpConn) {
if !udpConn.closed {
udpConn.closed = true
close(udpConn.ch)
delete(m.m, udpConn.id)
}
}
func (m *udpSessionManagerServer) clean() {
ticker := time.NewTicker(idleCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.mutex.RLock()
now := time.Now()
timeoutConn := make([]*InterUdpConn, 0, len(m.m))
for _, udpConn := range m.m {
if now.Sub(udpConn.GetLast()) > m.udpIdleTimeout {
timeoutConn = append(timeoutConn, udpConn)
}
}
m.mutex.RUnlock()
for _, udpConn := range timeoutConn {
m.mutex.Lock()
m.close(udpConn)
m.mutex.Unlock()
}
case <-m.stopCh:
return
}
}
}
func (m *udpSessionManagerServer) run() {
for {
d, err := m.conn.ReceiveDatagram(context.Background())
if err != nil {
break
}
if len(d) < 4 {
continue
}
id := binary.BigEndian.Uint32(d[:4])
m.feed(id, d)
}
m.mutex.Lock()
defer m.mutex.Unlock()
close(m.stopCh)
for _, udpConn := range m.m {
m.close(udpConn)
}
}
func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
m.mutex.RLock()
udpConn, ok := m.m[id]
if ok {
select {
case udpConn.ch <- d:
default:
}
m.mutex.RUnlock()
return
}
m.mutex.RUnlock()
m.mutex.Lock()
defer m.mutex.Unlock()
udpConn, ok = m.m[id]
if !ok {
udpConn = &InterUdpConn{
conn: m.conn,
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
id: id,
ch: make(chan []byte, udpMessageChanSize),
last: time.Now(),
user: m.user,
}
udpConn.closeFunc = func() {
m.mutex.Lock()
m.close(udpConn)
m.mutex.Unlock()
}
m.m[id] = udpConn
m.addConn(udpConn)
}
select {
case udpConn.ch <- d:
default:
}
}
type httpHandler struct { type httpHandler struct {
ctx context.Context sync.Mutex
conn *quic.Conn
addConn internet.ConnHandler
config *Config
quicParams *internet.QuicParams
validator *account.Validator validator *account.Validator
config *Config
masqHandler http.Handler masqHandler http.Handler
quicParams *internet.QuicParams
addConn internet.ConnHandler
conn *quic.Conn
auth bool auth bool
mutex sync.Mutex
user *protocol.MemoryUser user *protocol.MemoryUser
} }
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *httpHandler) AuthHTTP(w http.ResponseWriter, r *http.Request) bool {
if r.Method == http.MethodPost && r.Host == URLHost && r.URL.Path == URLPath { if r.Method == http.MethodPost && r.Host == URLHost && r.URL.Path == URLPath {
h.mutex.Lock() h.Lock()
defer h.mutex.Unlock() defer h.Unlock()
if h.auth { if h.auth {
w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(h.validator != nil))
w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10)) w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10))
w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) w.Header().Set(CommonHeaderPadding, AuthResponsePadding.String())
w.WriteHeader(StatusAuthOK) w.WriteHeader(StatusAuthOK)
return return true
} }
auth := r.Header.Get(RequestHeaderAuth) auth := r.Header.Get(RequestHeaderAuth)
clientDown, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64) down, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64)
var user *protocol.MemoryUser var user *protocol.MemoryUser
var ok bool var ok bool
@@ -185,49 +68,51 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.auth = true h.auth = true
h.user = user h.user = user
switch h.quicParams.Congestion { conn := h.conn
quicParams := h.quicParams
switch quicParams.Congestion {
case "reno": case "reno":
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno")
case "bbr": case "bbr":
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile) congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile)) case "", "brutal":
case "brutal", "": if quicParams.BrutalUp == 0 || down == 0 {
if h.quicParams.BrutalUp == 0 || clientDown == 0 { congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr ", h.quicParams.BbrProfile)
congestion.UseBBR(h.conn, bbr.Profile(h.quicParams.BbrProfile))
} else { } else {
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", min(h.quicParams.BrutalUp, clientDown)) congestion.UseBrutal(conn, min(quicParams.BrutalUp, down))
congestion.UseBrutal(h.conn, min(h.quicParams.BrutalUp, clientDown))
} }
case "force-brutal": case "force-brutal":
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", h.quicParams.BrutalUp) congestion.UseBrutal(conn, quicParams.BrutalUp)
congestion.UseBrutal(h.conn, h.quicParams.BrutalUp)
default: default:
errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") panic(quicParams.Congestion)
} }
if hyCtx.RequireDatagramFromContext(h.ctx) { if h.validator != nil {
udpSM := &udpSessionManagerServer{ udpSM := &udpSessionManager{
conn: h.conn, conn: h.conn,
m: make(map[uint32]*InterUdpConn), m: make(map[uint32]*InterConn),
addConn: h.addConn,
stopCh: make(chan struct{}),
udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second,
addConn: h.addConn,
udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second,
user: h.user, user: h.user,
} }
go udpSM.clean() go udpSM.clean()
go udpSM.run() go udpSM.run()
} }
w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(h.validator != nil))
w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10)) w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.quicParams.BrutalDown, 10))
w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) w.Header().Set(CommonHeaderPadding, AuthResponsePadding.String())
w.WriteHeader(StatusAuthOK) w.WriteHeader(StatusAuthOK)
return return true
} }
} }
return false
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.AuthHTTP(w, r) {
return
}
h.masqHandler.ServeHTTP(w, r) h.masqHandler.ServeHTTP(w, r)
} }
@@ -256,42 +141,41 @@ func (h *httpHandler) StreamDispatcher(ft http3.FrameType, stream *quic.Stream,
} }
type Listener struct { type Listener struct {
ctx context.Context validator *account.Validator
pktConn net.PacketConn config *Config
listener *quic.Listener masqHandler http.Handler
quicParams *internet.QuicParams
addConn internet.ConnHandler addConn internet.ConnHandler
config *Config pktConn net.PacketConn
quicParams *internet.QuicParams tr *quic.Transport
validator *account.Validator listener *quic.Listener
masqHandler http.Handler
} }
func (l *Listener) handleClient(conn *quic.Conn) { func (l *Listener) handleClient(conn *quic.Conn) {
handler := &httpHandler{ handler := &httpHandler{
ctx: l.ctx,
conn: conn,
addConn: l.addConn,
config: l.config,
quicParams: l.quicParams,
validator: l.validator, validator: l.validator,
config: l.config,
masqHandler: l.masqHandler, masqHandler: l.masqHandler,
quicParams: l.quicParams,
addConn: l.addConn,
conn: conn,
} }
h3 := http3.Server{ h3s := http3.Server{
Handler: handler, Handler: handler,
StreamDispatcher: handler.StreamDispatcher, StreamDispatcher: handler.StreamDispatcher,
} }
err := h3.ServeQUICConn(conn) _ = h3s.ServeQUICConn(conn)
_ = conn.CloseWithError(closeErrCodeOK, "") _ = conn.CloseWithError(closeErrCodeOK, "")
errors.LogDebug(context.Background(), conn.RemoteAddr(), " disconnected with err ", err)
} }
func (l *Listener) keepAccepting() { func (l *Listener) keepAccepting() {
for { for {
conn, err := l.listener.Accept(context.Background()) conn, err := l.listener.Accept(context.Background())
if err != nil { if err != nil {
errors.LogInfoInner(context.Background(), err, "failed to accept QUIC connection") if err != quic.ErrServerClosed {
errors.LogErrorInner(context.Background(), err, "failed to serve hysteria")
}
break break
} }
go l.handleClient(conn) go l.handleClient(conn)
@@ -303,9 +187,7 @@ func (l *Listener) Addr() net.Addr {
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {
err := l.listener.Close() return errors.Combine(l.listener.Close(), l.tr.Close(), l.pktConn.Close())
_ = l.pktConn.Close()
return err
} }
func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
@@ -318,11 +200,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
return nil, errors.New("tls config is nil") return nil, errors.New("tls config is nil")
} }
validator := ValidatorFromContext(ctx)
config := streamSettings.ProtocolSettings.(*Config) config := streamSettings.ProtocolSettings.(*Config)
validator := hyCtx.ValidatorFromContext(ctx) if validator == nil && config.Auth == "" {
if config.Auth == "" && validator == nil {
return nil, errors.New("validator is nil") return nil, errors.New("validator is nil")
} }
@@ -372,22 +253,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
return nil, errors.New("unknown masq type") return nil, errors.New("unknown masq type")
} }
raw, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings)
if err != nil {
return nil, err
}
var pktConn net.PacketConn
pktConn = raw
if streamSettings.UdpmaskManager != nil {
pktConn, err = streamSettings.UdpmaskManager.WrapPacketConnServer(raw)
if err != nil {
raw.Close()
return nil, errors.New("mask err").Base(err)
}
}
quicParams := streamSettings.QuicParams quicParams := streamSettings.QuicParams
if quicParams == nil { if quicParams == nil {
quicParams = &internet.QuicParams{ quicParams = &internet.QuicParams{
@@ -403,9 +268,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow,
MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second,
MaxIncomingStreams: quicParams.MaxIncomingStreams, MaxIncomingStreams: quicParams.MaxIncomingStreams,
DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"),
EnableDatagrams: true, EnableDatagrams: true,
MaxDatagramFrameSize: MaxDatagramFrameSize, MaxDatagramFrameSize: MaxDatagramFrameSize,
AssumePeerMaxDatagramFrameSize: MaxDatagramFrameSize,
DisablePathManager: true, DisablePathManager: true,
} }
if quicParams.InitStreamReceiveWindow == 0 { if quicParams.InitStreamReceiveWindow == 0 {
@@ -427,27 +293,44 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
quicConfig.MaxIncomingStreams = 1024 quicConfig.MaxIncomingStreams = 1024
} }
qListener, err := quic.Listen(pktConn, tlsConfig.GetTLSConfig(), quicConfig) pktConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, err
}
if streamSettings.UdpmaskManager != nil {
newConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(pktConn)
if err != nil {
pktConn.Close()
return nil, errors.New("mask err").Base(err)
}
pktConn = newConn
}
tr := &quic.Transport{Conn: pktConn}
listener, err := tr.Listen(tlsConfig.GetTLSConfig(), quicConfig)
if err != nil {
_ = tr.Close()
_ = pktConn.Close() _ = pktConn.Close()
return nil, err return nil, err
} }
listener := &Listener{ l := &Listener{
ctx: ctx, validator: validator,
pktConn: pktConn, config: config,
listener: qListener, masqHandler: masqHandler,
quicParams: quicParams,
addConn: handler, addConn: handler,
config: config, pktConn: pktConn,
quicParams: quicParams, tr: tr,
validator: validator, listener: listener,
masqHandler: masqHandler,
} }
go listener.keepAccepting() go l.keepAccepting()
return listener, nil return l, nil
} }
func init() { func init() {

View File

@@ -1,24 +0,0 @@
package padding
import (
"math/rand"
)
const (
paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
// padding specifies a half-open range [Min, Max).
type Padding struct {
Min int
Max int
}
func (p Padding) String() string {
n := p.Min + rand.Intn(p.Max-p.Min)
bs := make([]byte, n)
for i := range bs {
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
}
return string(bs)
}

View File

@@ -1,65 +0,0 @@
package udphop
import (
"fmt"
"net"
)
type InvalidPortError struct {
PortStr string
}
func (e InvalidPortError) Error() string {
return fmt.Sprintf("%s is not a valid port number or range", e.PortStr)
}
// UDPHopAddr contains an IP address and a list of ports.
type UDPHopAddr struct {
IP net.IP
Ports []uint32
PortStr string
}
func (a *UDPHopAddr) Network() string {
return "udphop"
}
func (a *UDPHopAddr) String() string {
return net.JoinHostPort(a.IP.String(), a.PortStr)
}
// addrs returns a list of net.Addr's, one for each port.
func (a *UDPHopAddr) addrs() ([]net.Addr, error) {
var addrs []net.Addr
for _, port := range a.Ports {
addr := &net.UDPAddr{
IP: a.IP,
Port: int(port),
}
addrs = append(addrs, addr)
}
return addrs, nil
}
// func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
// host, portStr, err := net.SplitHostPort(addr)
// if err != nil {
// return nil, err
// }
// ip, err := net.ResolveIPAddr("ip", host)
// if err != nil {
// return nil, err
// }
// result := &UDPHopAddr{
// IP: ip.IP,
// PortStr: portStr,
// }
// pu := utils.ParsePortUnion(portStr)
// if pu == nil {
// return nil, InvalidPortError{portStr}
// }
// result.Ports = pu.Ports()
// return result, nil
// }

View File

@@ -8,7 +8,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/xtls/xray-core/common/crypto"
"github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/finalmask"
) )
@@ -20,19 +19,19 @@ const (
) )
type UdpHopPacketConn struct { type UdpHopPacketConn struct {
Addr net.Addr
Addrs []net.Addr Addrs []net.Addr
HopIntervalMin int64 HopIntervalMin time.Duration
HopIntervalMax int64 HopIntervalMax time.Duration
ListenUDPFunc ListenUDPFunc ListenUDPFunc func(addr *net.UDPAddr) (net.PacketConn, error)
connMutex sync.RWMutex connMutex sync.RWMutex
prevConn net.PacketConn prevConn net.PacketConn
currentConn net.PacketConn currentConn net.PacketConn
addrIndex int addrIndex int
readBufferSize int deadline time.Time
writeBufferSize int readDeadline time.Time
writeDeadline time.Time
recvQueue chan *udpPacket recvQueue chan *udpPacket
closeChan chan struct{} closeChan chan struct{}
@@ -48,41 +47,36 @@ type udpPacket struct {
Err error Err error
} }
type ListenUDPFunc = func(*net.UDPAddr) (net.PacketConn, error) func NewUDPHopPacketConn(addrs []net.Addr, hopIntervalMin time.Duration, hopIntervalMax time.Duration, listenUDPFunc func(addr *net.UDPAddr) (net.PacketConn, error)) (net.PacketConn, error) {
if len(addrs) == 0 {
func NewUDPHopPacketConn(addr *UDPHopAddr, index int, intervalMin int64, intervalMax int64, listenUDPFunc ListenUDPFunc, pktConn net.PacketConn) (net.PacketConn, error) { panic("len(addrs) == 0")
if intervalMin == 0 || intervalMax == 0 {
intervalMin = int64(defaultHopInterval)
intervalMax = int64(defaultHopInterval)
} }
if intervalMin < 5 || intervalMax < 5 { if hopIntervalMin == 0 {
return nil, errors.New("hop interval must be at least 5 seconds") hopIntervalMin = defaultHopInterval
}
if hopIntervalMax == 0 {
hopIntervalMax = defaultHopInterval
}
if hopIntervalMin < 5*time.Second {
panic("hopIntervalMin < 5*time.Second")
}
if hopIntervalMax < 5*time.Second {
panic("hopIntervalMax < 5*time.Second")
}
if hopIntervalMax < hopIntervalMin {
panic("hopIntervalMax < hopIntervalMin")
} }
// if listenUDPFunc == nil {
// listenUDPFunc = func() (net.PacketConn, error) {
// return net.ListenUDP("udp", nil)
// }
// }
if listenUDPFunc == nil { if listenUDPFunc == nil {
return nil, errors.New("nil listenUDPFunc") panic("listenUDPFunc is nil")
} }
addrs, err := addr.addrs()
if err != nil {
return nil, err
}
// curConn, err := listenUDPFunc()
// if err != nil {
// return nil, err
// }
hConn := &UdpHopPacketConn{ hConn := &UdpHopPacketConn{
Addr: addr,
Addrs: addrs, Addrs: addrs,
HopIntervalMin: intervalMin, HopIntervalMin: hopIntervalMin,
HopIntervalMax: intervalMax, HopIntervalMax: hopIntervalMax,
ListenUDPFunc: listenUDPFunc, ListenUDPFunc: listenUDPFunc,
prevConn: nil, prevConn: nil,
currentConn: pktConn, currentConn: nil,
addrIndex: index, addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize), recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}), closeChan: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
@@ -91,7 +85,12 @@ func NewUDPHopPacketConn(addr *UDPHopAddr, index int, intervalMin int64, interva
}, },
}, },
} }
go hConn.recvLoop(pktConn) var err error
hConn.currentConn, err = listenUDPFunc(hConn.Addrs[hConn.addrIndex].(*net.UDPAddr))
if err != nil {
return nil, err
}
go hConn.recvLoop(hConn.currentConn)
go hConn.hopLoop() go hConn.hopLoop()
return hConn, nil return hConn, nil
} }
@@ -104,69 +103,64 @@ func (u *UdpHopPacketConn) recvLoop(conn net.PacketConn) {
u.bufPool.Put(buf) u.bufPool.Put(buf)
var netErr net.Error var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() { if errors.As(err, &netErr) && netErr.Timeout() {
// Only pass through timeout errors here, not permanent errors
// like connection closed. Connection close is normal as we close
// the old connection to exit this loop every time we hop.
u.recvQueue <- &udpPacket{nil, 0, nil, netErr} u.recvQueue <- &udpPacket{nil, 0, nil, netErr}
continue
} }
return return
} }
select { select {
case u.recvQueue <- &udpPacket{buf, n, addr, nil}: case u.recvQueue <- &udpPacket{buf, n, addr, nil}:
// Packet successfully queued
default: default:
// Queue is full, drop the packet
u.bufPool.Put(buf) u.bufPool.Put(buf)
} }
} }
} }
func (u *UdpHopPacketConn) hopLoop() { func (u *UdpHopPacketConn) hopLoop() {
ticker := time.NewTicker(time.Duration(crypto.RandBetween(u.HopIntervalMin, u.HopIntervalMax)) * time.Second) timer := time.NewTimer(u.nextHopInterval())
defer ticker.Stop() defer timer.Stop()
for { for {
select { select {
case <-ticker.C: case <-timer.C:
u.hop() u.hop()
ticker.Reset(time.Duration(crypto.RandBetween(u.HopIntervalMin, u.HopIntervalMax)) * time.Second) timer.Reset(u.nextHopInterval())
case <-u.closeChan: case <-u.closeChan:
return return
} }
} }
} }
func (u *UdpHopPacketConn) nextHopInterval() time.Duration {
if u.HopIntervalMin == u.HopIntervalMax {
return u.HopIntervalMin
}
return u.HopIntervalMin + time.Duration(rand.Int63n(int64(u.HopIntervalMax-u.HopIntervalMin)+1))
}
func (u *UdpHopPacketConn) hop() { func (u *UdpHopPacketConn) hop() {
u.connMutex.Lock() u.connMutex.Lock()
defer u.connMutex.Unlock() defer u.connMutex.Unlock()
if u.closed { if u.closed {
return return
} }
// Update addrIndex to a new random value
u.addrIndex = rand.Intn(len(u.Addrs)) u.addrIndex = rand.Intn(len(u.Addrs))
newConn, err := u.ListenUDPFunc(u.Addrs[u.addrIndex].(*net.UDPAddr)) newConn, err := u.ListenUDPFunc(u.Addrs[u.addrIndex].(*net.UDPAddr))
if err != nil { if err != nil {
// Could be temporary, just skip this hop
return return
} }
// We need to keep receiving packets from the previous connection,
// because otherwise there will be packet loss due to the time gap
// between we hop to a new port and the server acknowledges this change.
// So we do the following:
// Close prevConn,
// move currentConn to prevConn,
// set newConn as currentConn,
// start recvLoop on newConn.
if u.prevConn != nil { if u.prevConn != nil {
_ = u.prevConn.Close() // recvLoop for this conn will exit _ = u.prevConn.Close()
} }
u.prevConn = u.currentConn u.prevConn = u.currentConn
u.currentConn = newConn u.currentConn = newConn
// Set buffer sizes if previously set if !u.deadline.IsZero() {
if u.readBufferSize > 0 { _ = u.currentConn.SetDeadline(u.deadline)
_ = trySetReadBuffer(u.currentConn, u.readBufferSize)
} }
if u.writeBufferSize > 0 { if !u.readDeadline.IsZero() {
_ = trySetWriteBuffer(u.currentConn, u.writeBufferSize) _ = u.currentConn.SetReadDeadline(u.readDeadline)
}
if !u.writeDeadline.IsZero() {
_ = u.currentConn.SetWriteDeadline(u.writeDeadline)
} }
go u.recvLoop(newConn) go u.recvLoop(newConn)
} }
@@ -178,11 +172,9 @@ func (u *UdpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error)
if p.Err != nil { if p.Err != nil {
return 0, nil, p.Err return 0, nil, p.Err
} }
// Currently we do not check whether the packet is from
// the server or not due to performance reasons.
n := copy(b, p.Buf[:p.N]) n := copy(b, p.Buf[:p.N])
u.bufPool.Put(p.Buf) u.bufPool.Put(p.Buf)
return n, u.Addr, nil return n, p.Addr, nil
case <-u.closeChan: case <-u.closeChan:
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
} }
@@ -195,8 +187,6 @@ func (u *UdpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
if u.closed { if u.closed {
return 0, net.ErrClosed return 0, net.ErrClosed
} }
// Skip the check for now, always write to the server,
// for the same reason as in ReadFrom.
return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex]) return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex])
} }
@@ -206,16 +196,13 @@ func (u *UdpHopPacketConn) Close() error {
if u.closed { if u.closed {
return nil return nil
} }
// Close prevConn and currentConn
// Close closeChan to unblock ReadFrom & hopLoop
// Set closed flag to true to prevent double close
if u.prevConn != nil { if u.prevConn != nil {
_ = u.prevConn.Close() _ = u.prevConn.Close()
} }
err := u.currentConn.Close() err := u.currentConn.Close()
close(u.closeChan) close(u.closeChan)
u.closed = true u.closed = true
u.Addrs = nil // For GC u.Addrs = nil
return err return err
} }
@@ -226,8 +213,11 @@ func (u *UdpHopPacketConn) LocalAddr() net.Addr {
} }
func (u *UdpHopPacketConn) SetDeadline(t time.Time) error { func (u *UdpHopPacketConn) SetDeadline(t time.Time) error {
u.connMutex.RLock() u.connMutex.Lock()
defer u.connMutex.RUnlock() defer u.connMutex.Unlock()
u.deadline = t
u.readDeadline = t
u.writeDeadline = t
if u.prevConn != nil { if u.prevConn != nil {
_ = u.prevConn.SetDeadline(t) _ = u.prevConn.SetDeadline(t)
} }
@@ -235,8 +225,10 @@ func (u *UdpHopPacketConn) SetDeadline(t time.Time) error {
} }
func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error { func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error {
u.connMutex.RLock() u.connMutex.Lock()
defer u.connMutex.RUnlock() defer u.connMutex.Unlock()
u.deadline = time.Time{}
u.readDeadline = t
if u.prevConn != nil { if u.prevConn != nil {
_ = u.prevConn.SetReadDeadline(t) _ = u.prevConn.SetReadDeadline(t)
} }
@@ -244,36 +236,16 @@ func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error {
} }
func (u *UdpHopPacketConn) SetWriteDeadline(t time.Time) error { func (u *UdpHopPacketConn) SetWriteDeadline(t time.Time) error {
u.connMutex.RLock() u.connMutex.Lock()
defer u.connMutex.RUnlock() defer u.connMutex.Unlock()
u.deadline = time.Time{}
u.writeDeadline = t
if u.prevConn != nil { if u.prevConn != nil {
_ = u.prevConn.SetWriteDeadline(t) _ = u.prevConn.SetWriteDeadline(t)
} }
return u.currentConn.SetWriteDeadline(t) return u.currentConn.SetWriteDeadline(t)
} }
// UDP-specific methods below
func (u *UdpHopPacketConn) SetReadBuffer(bytes int) error {
u.connMutex.Lock()
defer u.connMutex.Unlock()
u.readBufferSize = bytes
if u.prevConn != nil {
_ = trySetReadBuffer(u.prevConn, bytes)
}
return trySetReadBuffer(u.currentConn, bytes)
}
func (u *UdpHopPacketConn) SetWriteBuffer(bytes int) error {
u.connMutex.Lock()
defer u.connMutex.Unlock()
u.writeBufferSize = bytes
if u.prevConn != nil {
_ = trySetWriteBuffer(u.prevConn, bytes)
}
return trySetWriteBuffer(u.currentConn, bytes)
}
func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) { func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
u.connMutex.RLock() u.connMutex.RLock()
defer u.connMutex.RUnlock() defer u.connMutex.RUnlock()
@@ -284,22 +256,14 @@ func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
return sc.SyscallConn() return sc.SyscallConn()
} }
func trySetReadBuffer(pc net.PacketConn, bytes int) error { func ToAddrs(ip net.IP, ports []uint32) []net.Addr {
sc, ok := pc.(interface { var addrs []net.Addr
SetReadBuffer(bytes int) error for _, port := range ports {
}) addr := &net.UDPAddr{
if ok { IP: ip,
return sc.SetReadBuffer(bytes) Port: int(port),
} }
return nil addrs = append(addrs, addr)
} }
return addrs
func trySetWriteBuffer(pc net.PacketConn, bytes int) error {
sc, ok := pc.(interface {
SetWriteBuffer(bytes int) error
})
if ok {
return sc.SetWriteBuffer(bytes)
}
return nil
} }

View File

@@ -57,41 +57,27 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet
} }
if streamSettings.UdpmaskManager != nil { if streamSettings.UdpmaskManager != nil {
var pktConn net.PacketConn
var udpAddr = conn.RemoteAddr().(*net.UDPAddr)
switch c := conn.(type) { switch c := conn.(type) {
case *internet.PacketConnWrapper: case *internet.PacketConnWrapper:
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) pktConn = c.PacketConn
if err != nil {
conn.Close()
return nil, errors.New("mask err").Base(err)
}
c.PacketConn = pktConn
case *net.UDPConn: case *net.UDPConn:
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) pktConn = c
if err != nil {
conn.Close()
return nil, errors.New("mask err").Base(err)
}
conn = &internet.PacketConnWrapper{
PacketConn: pktConn,
Dest: c.RemoteAddr().(*net.UDPAddr),
}
case *cnc.Connection: case *cnc.Connection:
fakeConn := &internet.FakePacketConn{Conn: c} pktConn = &internet.FakePacketConn{Conn: c}
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) default:
panic(reflect.TypeOf(c))
}
newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn)
if err != nil { if err != nil {
conn.Close() pktConn.Close()
return nil, errors.New("mask err").Base(err) return nil, errors.New("mask err").Base(err)
} }
pktConn = newConn
conn = &internet.PacketConnWrapper{ conn = &internet.PacketConnWrapper{
PacketConn: pktConn, PacketConn: pktConn,
Dest: &net.UDPAddr{ Dest: udpAddr,
IP: []byte{0, 0, 0, 0},
Port: 0,
},
}
default:
conn.Close()
return nil, errors.New("unknown conn ", reflect.TypeOf(c))
} }
} }

View File

@@ -5,11 +5,11 @@ import (
gotls "crypto/tls" gotls "crypto/tls"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/url" "net/url"
reflect "reflect" reflect "reflect"
"runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -21,6 +21,7 @@ import (
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/net/cnc"
"github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/common/uuid"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
@@ -173,7 +174,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second,
KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second, KeepAlivePeriod: time.Duration(quicParams.KeepAlivePeriod) * time.Second,
MaxIncomingStreams: quicParams.MaxIncomingStreams, MaxIncomingStreams: quicParams.MaxIncomingStreams,
DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"),
} }
if quicParams.MaxIdleTimeout == 0 { if quicParams.MaxIdleTimeout == 0 {
quicConfig.MaxIdleTimeout = net.ConnIdleTimeout quicConfig.MaxIdleTimeout = net.ConnIdleTimeout
@@ -194,110 +195,83 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
QUICConfig: quicConfig, QUICConfig: quicConfig,
TLSClientConfig: gotlsConfig, TLSClientConfig: gotlsConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (*quic.Conn, error) { Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (*quic.Conn, error) {
udphopDialer := func(addr *net.UDPAddr) (net.PacketConn, error) { udpHopDialer := func(addr *net.UDPAddr) (net.PacketConn, error) {
conn, err := internet.DialSystem(ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), streamSettings.SocketSettings) conn, err := internet.DialSystem(ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), streamSettings.SocketSettings)
if err != nil { if err != nil {
errors.LogDebug(context.Background(), "skip hop: failed to dial to dest") errors.LogInfoInner(context.Background(), err, "skip hop: failed to dial to dest")
conn.Close() return nil, errors.New("failed to dial to dest").Base(err)
return nil, errors.New()
} }
var udpConn net.PacketConn var pktConn net.PacketConn
switch c := conn.(type) { switch c := conn.(type) {
case *internet.PacketConnWrapper: case *internet.PacketConnWrapper:
udpConn = c.PacketConn pktConn = c.PacketConn
case *net.UDPConn: case *net.UDPConn:
udpConn = c pktConn = c
default: default:
errors.LogDebug(context.Background(), "skip hop: udphop requires being at the outermost level ", reflect.TypeOf(c)) errors.LogInfo(context.Background(), "skip hop: invalid conn ", reflect.TypeOf(c))
conn.Close() conn.Close()
return nil, errors.New() return nil, errors.New("invalid conn ", reflect.TypeOf(c))
} }
return udpConn, nil return pktConn, nil
} }
var index int var pktConn net.PacketConn
var udpAddr *net.UDPAddr
var err error
udpAddr, err = net.ResolveUDPAddr("udp", dest.NetAddr())
if err != nil {
return nil, err
}
if len(quicParams.UdpHop.Ports) > 0 { if len(quicParams.UdpHop.Ports) > 0 {
index = rand.Intn(len(quicParams.UdpHop.Ports)) pktConn, err = udphop.NewUDPHopPacketConn(udphop.ToAddrs(udpAddr.IP, quicParams.UdpHop.Ports), time.Duration(quicParams.UdpHop.IntervalMin)*time.Second, time.Duration(quicParams.UdpHop.IntervalMax)*time.Second, udpHopDialer)
dest.Port = net.Port(quicParams.UdpHop.Ports[index]) if err != nil {
return nil, err
} }
} else {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var udpConn net.PacketConn
var udpAddr *net.UDPAddr
switch c := conn.(type) { switch c := conn.(type) {
case *internet.PacketConnWrapper: case *internet.PacketConnWrapper:
udpConn = c.PacketConn pktConn = c.PacketConn
udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
if err != nil {
conn.Close()
return nil, err
}
case *net.UDPConn: case *net.UDPConn:
udpConn = c pktConn = c
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) case *cnc.Connection:
if err != nil { pktConn = &internet.FakePacketConn{Conn: c}
conn.Close()
return nil, err
}
default: default:
udpConn = &internet.FakePacketConn{Conn: c} panic(reflect.TypeOf(c))
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
conn.Close()
return nil, err
}
if len(quicParams.UdpHop.Ports) > 0 {
conn.Close()
return nil, errors.New("udphop requires being at the outermost level ", reflect.TypeOf(c))
}
}
if len(quicParams.UdpHop.Ports) > 0 {
addr := &udphop.UDPHopAddr{
IP: udpAddr.IP,
Ports: quicParams.UdpHop.Ports,
}
udpConn, err = udphop.NewUDPHopPacketConn(addr, index, quicParams.UdpHop.IntervalMin, quicParams.UdpHop.IntervalMax, udphopDialer, udpConn)
if err != nil {
conn.Close()
return nil, errors.New("udphop err").Base(err)
} }
} }
if streamSettings.UdpmaskManager != nil { if streamSettings.UdpmaskManager != nil {
udpConn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(udpConn) newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn)
if err != nil { if err != nil {
conn.Close() pktConn.Close()
return nil, errors.New("mask err").Base(err) return nil, errors.New("mask err").Base(err)
} }
pktConn = newConn
} }
quicConn, err := quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) conn, err := quic.DialEarly(ctx, pktConn, udpAddr, tlsCfg, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch quicParams.Congestion { switch quicParams.Congestion {
case "force-brutal":
errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion brutal bytes per second ", quicParams.BrutalUp)
congestion.UseBrutal(quicConn, quicParams.BrutalUp)
case "reno": case "reno":
errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion reno") case "", "bbr":
congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
case "force-brutal":
congestion.UseBrutal(conn, quicParams.BrutalUp)
default: default:
errors.LogDebug(context.Background(), quicConn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) panic(quicParams.Congestion)
congestion.UseBBR(quicConn, bbr.Profile(quicParams.BbrProfile))
} }
return quicConn, nil return conn, nil
}, },
} }
} else if httpVersion == "2" { } else if httpVersion == "2" {

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -440,7 +441,7 @@ type Listener struct {
server http.Server server http.Server
h3server *http3.Server h3server *http3.Server
listener net.Listener listener net.Listener
h3listener *quic.EarlyListener h3listener Qface
config *Config config *Config
addConn internet.ConnHandler addConn internet.ConnHandler
isH3 bool isH3 bool
@@ -487,12 +488,12 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet
return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
} }
if streamSettings.UdpmaskManager != nil { if streamSettings.UdpmaskManager != nil {
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn) newConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn)
if err != nil { if err != nil {
Conn.Close() Conn.Close()
return nil, errors.New("mask err").Base(err) return nil, errors.New("mask err").Base(err)
} }
Conn = pktConn Conn = newConn
} }
quicParams := streamSettings.QuicParams quicParams := streamSettings.QuicParams
@@ -510,13 +511,17 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet
MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow,
MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second,
MaxIncomingStreams: quicParams.MaxIncomingStreams, MaxIncomingStreams: quicParams.MaxIncomingStreams,
DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery || (runtime.GOOS != "linux" && runtime.GOOS != "windows" && runtime.GOOS != "darwin"),
} }
l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, quicConfig) l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, quicConfig)
if err != nil { if err != nil {
return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
} }
l.h3listener = &QListener{
Qface: l.h3listener,
quicParams: quicParams,
}
errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
handler.localAddr = l.h3listener.Addr() handler.localAddr = l.h3listener.Addr()
@@ -525,30 +530,8 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet
Handler: handler, Handler: handler,
} }
go func() { go func() {
for { if err := l.h3server.ServeListener(l.h3listener); err != nil {
conn, err := l.h3listener.Accept(context.Background()) errors.LogErrorInner(ctx, err, "failed to serve HTTP/3 for XHTTP/3")
if err != nil {
errors.LogInfoInner(ctx, err, "XHTTP/3 listener closed")
return
}
switch quicParams.Congestion {
case "force-brutal":
errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion brutal bytes per second ", quicParams.BrutalUp)
congestion.UseBrutal(conn, quicParams.BrutalUp)
case "reno":
errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion reno")
default:
errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile)
congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile))
}
go func() {
if err := l.h3server.ServeQUICConn(conn); err != nil {
errors.LogDebugInner(ctx, err, "XHTTP/3 connection ended")
}
_ = conn.CloseWithError(0, "")
}()
} }
}() }()
} else { // tcp } else { // tcp
@@ -614,10 +597,8 @@ func (ln *Listener) Addr() net.Addr {
func (ln *Listener) Close() error { func (ln *Listener) Close() error {
if ln.h3server != nil { if ln.h3server != nil {
if err := ln.h3server.Close(); err != nil { if err := ln.h3server.Close(); err != nil {
_ = ln.h3listener.Close()
return err return err
} }
return ln.h3listener.Close()
} else if ln.listener != nil { } else if ln.listener != nil {
return ln.listener.Close() return ln.listener.Close()
} }
@@ -633,3 +614,33 @@ func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
func init() { func init() {
common.Must(internet.RegisterTransportListener(protocolName, ListenXH)) common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
} }
type Qface interface {
Accept(ctx context.Context) (*quic.Conn, error)
Addr() net.Addr
Close() error
}
var _ Qface = (*quic.EarlyListener)(nil)
type QListener struct {
Qface
quicParams *internet.QuicParams
}
func (l *QListener) Accept(ctx context.Context) (*quic.Conn, error) {
conn, err := l.Qface.Accept(ctx)
if err != nil {
return nil, err
}
switch l.quicParams.Congestion {
case "reno":
case "", "bbr":
congestion.UseBBR(conn, bbr.Profile(l.quicParams.BbrProfile))
case "force-brutal":
congestion.UseBrutal(conn, l.quicParams.BrutalUp)
default:
panic(l.quicParams.Congestion)
}
return conn, nil
}

View File

@@ -25,48 +25,30 @@ func init() {
} }
if streamSettings != nil && streamSettings.UdpmaskManager != nil { if streamSettings != nil && streamSettings.UdpmaskManager != nil {
var pktConn net.PacketConn
var udpAddr = conn.RemoteAddr().(*net.UDPAddr)
switch c := conn.(type) { switch c := conn.(type) {
case *internet.PacketConnWrapper: case *internet.PacketConnWrapper:
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) pktConn = c.PacketConn
if err != nil {
conn.Close()
return nil, errors.New("mask err").Base(err)
}
c.PacketConn = pktConn
errors.LogInfo(ctx, "finalmask udp dialer: wrapped existing PacketConnWrapper with ", reflect.TypeOf(pktConn))
case *net.UDPConn: case *net.UDPConn:
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) pktConn = c
if err != nil {
conn.Close()
return nil, errors.New("mask err").Base(err)
}
conn = &internet.PacketConnWrapper{
PacketConn: pktConn,
Dest: c.RemoteAddr().(*net.UDPAddr),
}
errors.LogInfo(ctx, "finalmask udp dialer: wrapped UDPConn with ", reflect.TypeOf(pktConn))
case *cnc.Connection: case *cnc.Connection:
fakeConn := &internet.FakePacketConn{Conn: c} pktConn = &internet.FakePacketConn{Conn: c}
pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) default:
panic(reflect.TypeOf(c))
}
newConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn)
if err != nil { if err != nil {
conn.Close() pktConn.Close()
return nil, errors.New("mask err").Base(err) return nil, errors.New("mask err").Base(err)
} }
pktConn = newConn
conn = &internet.PacketConnWrapper{ conn = &internet.PacketConnWrapper{
PacketConn: pktConn, PacketConn: pktConn,
Dest: &net.UDPAddr{ Dest: udpAddr,
IP: []byte{0, 0, 0, 0},
Port: 0,
},
}
errors.LogInfo(ctx, "finalmask udp dialer: wrapped cnc.Connection with ", reflect.TypeOf(pktConn))
default:
conn.Close()
return nil, errors.New("unknown conn ", reflect.TypeOf(c))
} }
} }
// TODO: handle dialer options
return conn, nil return conn, nil
})) }))
} }