mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-08 14:13:22 +00:00
https://github.com/XTLS/Xray-core/pull/6041#issuecomment-4357417742 And fixes https://github.com/XTLS/Xray-core/issues/6039
321 lines
7.5 KiB
Go
321 lines
7.5 KiB
Go
package hysteria
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/apernet/quic-go/quicvarint"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/transport/internet/hysteria"
|
|
)
|
|
|
|
// Protocol length limits and QUIC variable-length integer boundaries.
const (
	// Upper bounds for length fields read from the peer.
	// They exist to prevent DoS attacks.
	MaxAddressLength = 2048
	MaxMessageLength = 2048
	MaxPaddingLength = 4096

	// Largest values that fit into a 1-, 2-, 4- and 8-byte QUIC varint,
	// i.e. 6, 14, 30 and 62 usable bits respectively.
	maxVarInt1 = 1<<6 - 1
	maxVarInt2 = 1<<14 - 1
	maxVarInt4 = 1<<30 - 1
	maxVarInt8 = 1<<62 - 1
)
|
|
|
|
// TCPRequest format:
|
|
// Address length (QUIC varint)
|
|
// Address (bytes)
|
|
// Padding length (QUIC varint)
|
|
// Padding (bytes)
|
|
|
|
func ReadTCPRequest(r io.Reader) (string, error) {
|
|
bReader := quicvarint.NewReader(r)
|
|
addrLen, err := quicvarint.Read(bReader)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if addrLen == 0 || addrLen > MaxAddressLength {
|
|
return "", errors.New("invalid address length")
|
|
}
|
|
addrBuf := make([]byte, addrLen)
|
|
_, err = io.ReadFull(r, addrBuf)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
paddingLen, err := quicvarint.Read(bReader)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if paddingLen > MaxPaddingLength {
|
|
return "", errors.New("invalid padding length")
|
|
}
|
|
if paddingLen > 0 {
|
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
return string(addrBuf), nil
|
|
}
|
|
|
|
func WriteTCPRequest(w io.Writer, addr string) error {
|
|
padding := hysteria.TcpRequestPadding.String()
|
|
paddingLen := len(padding)
|
|
addrLen := len(addr)
|
|
sz := int(quicvarint.Len(uint64(addrLen))) + addrLen +
|
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
buf := make([]byte, sz)
|
|
i := varintPut(buf, uint64(addrLen))
|
|
i += copy(buf[i:], addr)
|
|
i += varintPut(buf[i:], uint64(paddingLen))
|
|
copy(buf[i:], padding)
|
|
_, err := w.Write(buf)
|
|
return err
|
|
}
|
|
|
|
// TCPResponse format:
|
|
// Status (byte, 0=ok, 1=error)
|
|
// Message length (QUIC varint)
|
|
// Message (bytes)
|
|
// Padding length (QUIC varint)
|
|
// Padding (bytes)
|
|
|
|
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
|
var status [1]byte
|
|
if _, err := io.ReadFull(r, status[:]); err != nil {
|
|
return false, "", err
|
|
}
|
|
bReader := quicvarint.NewReader(r)
|
|
msgLen, err := quicvarint.Read(bReader)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
if msgLen > MaxMessageLength {
|
|
return false, "", errors.New("invalid message length")
|
|
}
|
|
var msgBuf []byte
|
|
// No message is fine
|
|
if msgLen > 0 {
|
|
msgBuf = make([]byte, msgLen)
|
|
_, err = io.ReadFull(r, msgBuf)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
}
|
|
paddingLen, err := quicvarint.Read(bReader)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
if paddingLen > MaxPaddingLength {
|
|
return false, "", errors.New("invalid padding length")
|
|
}
|
|
if paddingLen > 0 {
|
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
}
|
|
return status[0] == 0, string(msgBuf), nil
|
|
}
|
|
|
|
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
|
padding := hysteria.TcpResponsePadding.String()
|
|
paddingLen := len(padding)
|
|
msgLen := len(msg)
|
|
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
buf := make([]byte, sz)
|
|
if ok {
|
|
buf[0] = 0
|
|
} else {
|
|
buf[0] = 1
|
|
}
|
|
i := varintPut(buf[1:], uint64(msgLen))
|
|
i += copy(buf[1+i:], msg)
|
|
i += varintPut(buf[1+i:], uint64(paddingLen))
|
|
copy(buf[1+i:], padding)
|
|
_, err := w.Write(buf)
|
|
return err
|
|
}
|
|
|
|
// UDPMessage format:
|
|
// Session ID (uint32 BE)
|
|
// Packet ID (uint16 BE)
|
|
// Fragment ID (uint8)
|
|
// Fragment count (uint8)
|
|
// Address length (QUIC varint)
|
|
// Address (bytes)
|
|
// Data...
|
|
|
|
// UDPMessage is the in-memory form of one UDP message, which may be a
// complete datagram or a single fragment of one.
type UDPMessage struct {
	SessionID uint32 // 4 bytes on the wire
	PacketID  uint16 // 2 bytes on the wire

	// Position of this fragment and the total number of fragments,
	// 1 byte each on the wire.
	FragID, FragCount uint8

	Addr string // encoded as varint length + bytes
	Data []byte // payload; everything after the address
}
|
|
|
|
func (m *UDPMessage) HeaderSize() int {
|
|
lAddr := len(m.Addr)
|
|
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
|
|
}
|
|
|
|
func (m *UDPMessage) Size() int {
|
|
return m.HeaderSize() + len(m.Data)
|
|
}
|
|
|
|
func (m *UDPMessage) Serialize(buf []byte) int {
|
|
// Make sure the buffer is big enough
|
|
if len(buf) < m.Size() {
|
|
return -1
|
|
}
|
|
// binary.BigEndian.PutUint32(buf, m.SessionID)
|
|
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
|
|
buf[6] = m.FragID
|
|
buf[7] = m.FragCount
|
|
i := varintPut(buf[8:], uint64(len(m.Addr)))
|
|
i += copy(buf[8+i:], m.Addr)
|
|
i += copy(buf[8+i:], m.Data)
|
|
return 8 + i
|
|
}
|
|
|
|
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
|
|
m := &UDPMessage{}
|
|
buf := bytes.NewBuffer(msg)
|
|
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
|
|
return nil, err
|
|
}
|
|
lAddr, err := quicvarint.Read(buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if lAddr == 0 || lAddr > MaxMessageLength {
|
|
return nil, errors.New("invalid address length")
|
|
}
|
|
bs := buf.Bytes()
|
|
if len(bs) <= int(lAddr) {
|
|
// We use <= instead of < here as we expect at least one byte of data after the address
|
|
return nil, errors.New("invalid message length")
|
|
}
|
|
m.Addr = string(bs[:lAddr])
|
|
m.Data = bs[lAddr:]
|
|
return m, nil
|
|
}
|
|
|
|
// varintPut is like quicvarint.Append, but instead of appending to a slice,
|
|
// it writes to a fixed-size buffer. Returns the number of bytes written.
|
|
func varintPut(b []byte, i uint64) int {
|
|
if i <= maxVarInt1 {
|
|
b[0] = uint8(i)
|
|
return 1
|
|
}
|
|
if i <= maxVarInt2 {
|
|
b[0] = uint8(i>>8) | 0x40
|
|
b[1] = uint8(i)
|
|
return 2
|
|
}
|
|
if i <= maxVarInt4 {
|
|
b[0] = uint8(i>>24) | 0x80
|
|
b[1] = uint8(i >> 16)
|
|
b[2] = uint8(i >> 8)
|
|
b[3] = uint8(i)
|
|
return 4
|
|
}
|
|
if i <= maxVarInt8 {
|
|
b[0] = uint8(i>>56) | 0xc0
|
|
b[1] = uint8(i >> 48)
|
|
b[2] = uint8(i >> 40)
|
|
b[3] = uint8(i >> 32)
|
|
b[4] = uint8(i >> 24)
|
|
b[5] = uint8(i >> 16)
|
|
b[6] = uint8(i >> 8)
|
|
b[7] = uint8(i)
|
|
return 8
|
|
}
|
|
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
|
|
}
|
|
|
|
func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage {
|
|
if m.Size() <= maxSize {
|
|
return []UDPMessage{*m}
|
|
}
|
|
fullPayload := m.Data
|
|
maxPayloadSize := maxSize - m.HeaderSize()
|
|
off := 0
|
|
fragID := uint8(0)
|
|
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
|
frags := make([]UDPMessage, fragCount)
|
|
for off < len(fullPayload) {
|
|
payloadSize := len(fullPayload) - off
|
|
if payloadSize > maxPayloadSize {
|
|
payloadSize = maxPayloadSize
|
|
}
|
|
frag := *m
|
|
frag.FragID = fragID
|
|
frag.FragCount = fragCount
|
|
frag.Data = fullPayload[off : off+payloadSize]
|
|
frags[fragID] = frag
|
|
off += payloadSize
|
|
fragID++
|
|
}
|
|
return frags
|
|
}
|
|
|
|
// Defragger handles the defragmentation of UDP messages.
|
|
// The current implementation can only handle one packet ID at a time.
|
|
// If another packet arrives before a packet has received all fragments
|
|
// in their entirety, any previous state is discarded.
|
|
type Defragger struct {
|
|
pktID uint16
|
|
frags []*UDPMessage
|
|
count uint8
|
|
size int // data size
|
|
}
|
|
|
|
func (d *Defragger) Feed(m *UDPMessage) *UDPMessage {
|
|
if m.FragCount <= 1 {
|
|
return m
|
|
}
|
|
if m.FragID >= m.FragCount {
|
|
// wtf is this?
|
|
return nil
|
|
}
|
|
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
|
|
// new message, clear previous state
|
|
d.pktID = m.PacketID
|
|
d.frags = make([]*UDPMessage, m.FragCount)
|
|
d.frags[m.FragID] = m
|
|
d.count = 1
|
|
d.size = len(m.Data)
|
|
} else if d.frags[m.FragID] == nil {
|
|
d.frags[m.FragID] = m
|
|
d.count++
|
|
d.size += len(m.Data)
|
|
if int(d.count) == len(d.frags) {
|
|
// all fragments received, assemble
|
|
data := make([]byte, d.size)
|
|
off := 0
|
|
for _, frag := range d.frags {
|
|
off += copy(data[off:], frag.Data)
|
|
}
|
|
m.Data = data
|
|
m.FragID = 0
|
|
m.FragCount = 1
|
|
return m
|
|
}
|
|
}
|
|
return nil
|
|
}
|