package hysteria import ( "bytes" "encoding/binary" "fmt" "io" "github.com/apernet/quic-go/quicvarint" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/transport/internet/hysteria" ) const ( // Max length values are for preventing DoS attacks MaxAddressLength = 2048 MaxMessageLength = 2048 MaxPaddingLength = 4096 maxVarInt1 = 63 maxVarInt2 = 16383 maxVarInt4 = 1073741823 maxVarInt8 = 4611686018427387903 ) // TCPRequest format: // Address length (QUIC varint) // Address (bytes) // Padding length (QUIC varint) // Padding (bytes) func ReadTCPRequest(r io.Reader) (string, error) { bReader := quicvarint.NewReader(r) addrLen, err := quicvarint.Read(bReader) if err != nil { return "", err } if addrLen == 0 || addrLen > MaxAddressLength { return "", errors.New("invalid address length") } addrBuf := make([]byte, addrLen) _, err = io.ReadFull(r, addrBuf) if err != nil { return "", err } paddingLen, err := quicvarint.Read(bReader) if err != nil { return "", err } if paddingLen > MaxPaddingLength { return "", errors.New("invalid padding length") } if paddingLen > 0 { _, err = io.CopyN(io.Discard, r, int64(paddingLen)) if err != nil { return "", err } } return string(addrBuf), nil } func WriteTCPRequest(w io.Writer, addr string) error { padding := hysteria.TcpRequestPadding.String() paddingLen := len(padding) addrLen := len(addr) sz := int(quicvarint.Len(uint64(addrLen))) + addrLen + int(quicvarint.Len(uint64(paddingLen))) + paddingLen buf := make([]byte, sz) i := varintPut(buf, uint64(addrLen)) i += copy(buf[i:], addr) i += varintPut(buf[i:], uint64(paddingLen)) copy(buf[i:], padding) _, err := w.Write(buf) return err } // TCPResponse format: // Status (byte, 0=ok, 1=error) // Message length (QUIC varint) // Message (bytes) // Padding length (QUIC varint) // Padding (bytes) func ReadTCPResponse(r io.Reader) (bool, string, error) { var status [1]byte if _, err := io.ReadFull(r, status[:]); err != nil { return false, "", err } bReader := quicvarint.NewReader(r) msgLen, err := quicvarint.Read(bReader) if err != nil { return false, "", err } if msgLen > MaxMessageLength { return false, "", errors.New("invalid message length") } var msgBuf []byte // No message is fine if msgLen > 0 { msgBuf = make([]byte, msgLen) _, err = io.ReadFull(r, msgBuf) if err != nil { return false, "", err } } paddingLen, err := quicvarint.Read(bReader) if err != nil { return false, "", err } if paddingLen > MaxPaddingLength { return false, "", errors.New("invalid padding length") } if paddingLen > 0 { _, err = io.CopyN(io.Discard, r, int64(paddingLen)) if err != nil { return false, "", err } } return status[0] == 0, string(msgBuf), nil } func WriteTCPResponse(w io.Writer, ok bool, msg string) error { padding := hysteria.TcpResponsePadding.String() paddingLen := len(padding) msgLen := len(msg) sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + int(quicvarint.Len(uint64(paddingLen))) + paddingLen buf := make([]byte, sz) if ok { buf[0] = 0 } else { buf[0] = 1 } i := varintPut(buf[1:], uint64(msgLen)) i += copy(buf[1+i:], msg) i += varintPut(buf[1+i:], uint64(paddingLen)) copy(buf[1+i:], padding) _, err := w.Write(buf) return err } // UDPMessage format: // Session ID (uint32 BE) // Packet ID (uint16 BE) // Fragment ID (uint8) // Fragment count (uint8) // Address length (QUIC varint) // Address (bytes) // Data... type UDPMessage struct { SessionID uint32 // 4 PacketID uint16 // 2 FragID uint8 // 1 FragCount uint8 // 1 Addr string // varint + bytes Data []byte } func (m *UDPMessage) HeaderSize() int { lAddr := len(m.Addr) return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr } func (m *UDPMessage) Size() int { return m.HeaderSize() + len(m.Data) } func (m *UDPMessage) Serialize(buf []byte) int { // Make sure the buffer is big enough if len(buf) < m.Size() { return -1 } // binary.BigEndian.PutUint32(buf, m.SessionID) binary.BigEndian.PutUint16(buf[4:], m.PacketID) buf[6] = m.FragID buf[7] = m.FragCount i := varintPut(buf[8:], uint64(len(m.Addr))) i += copy(buf[8+i:], m.Addr) i += copy(buf[8+i:], m.Data) return 8 + i } func ParseUDPMessage(msg []byte) (*UDPMessage, error) { m := &UDPMessage{} buf := bytes.NewBuffer(msg) if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { return nil, err } if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { return nil, err } if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { return nil, err } if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { return nil, err } lAddr, err := quicvarint.Read(buf) if err != nil { return nil, err } if lAddr == 0 || lAddr > MaxMessageLength { return nil, errors.New("invalid address length") } bs := buf.Bytes() if len(bs) <= int(lAddr) { // We use <= instead of < here as we expect at least one byte of data after the address return nil, errors.New("invalid message length") } m.Addr = string(bs[:lAddr]) m.Data = bs[lAddr:] return m, nil } // varintPut is like quicvarint.Append, but instead of appending to a slice, // it writes to a fixed-size buffer. Returns the number of bytes written. func varintPut(b []byte, i uint64) int { if i <= maxVarInt1 { b[0] = uint8(i) return 1 } if i <= maxVarInt2 { b[0] = uint8(i>>8) | 0x40 b[1] = uint8(i) return 2 } if i <= maxVarInt4 { b[0] = uint8(i>>24) | 0x80 b[1] = uint8(i >> 16) b[2] = uint8(i >> 8) b[3] = uint8(i) return 4 } if i <= maxVarInt8 { b[0] = uint8(i>>56) | 0xc0 b[1] = uint8(i >> 48) b[2] = uint8(i >> 40) b[3] = uint8(i >> 32) b[4] = uint8(i >> 24) b[5] = uint8(i >> 16) b[6] = uint8(i >> 8) b[7] = uint8(i) return 8 } panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) } func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage { if m.Size() <= maxSize { return []UDPMessage{*m} } fullPayload := m.Data maxPayloadSize := maxSize - m.HeaderSize() off := 0 fragID := uint8(0) fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up frags := make([]UDPMessage, fragCount) for off < len(fullPayload) { payloadSize := len(fullPayload) - off if payloadSize > maxPayloadSize { payloadSize = maxPayloadSize } frag := *m frag.FragID = fragID frag.FragCount = fragCount frag.Data = fullPayload[off : off+payloadSize] frags[fragID] = frag off += payloadSize fragID++ } return frags } // Defragger handles the defragmentation of UDP messages. // The current implementation can only handle one packet ID at a time. // If another packet arrives before a packet has received all fragments // in their entirety, any previous state is discarded. type Defragger struct { pktID uint16 frags []*UDPMessage count uint8 size int // data size } func (d *Defragger) Feed(m *UDPMessage) *UDPMessage { if m.FragCount <= 1 { return m } if m.FragID >= m.FragCount { // wtf is this? return nil } if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { // new message, clear previous state d.pktID = m.PacketID d.frags = make([]*UDPMessage, m.FragCount) d.frags[m.FragID] = m d.count = 1 d.size = len(m.Data) } else if d.frags[m.FragID] == nil { d.frags[m.FragID] = m d.count++ d.size += len(m.Data) if int(d.count) == len(d.frags) { // all fragments received, assemble data := make([]byte, d.size) off := 0 for _, frag := range d.frags { off += copy(data[off:], frag.Data) } m.Data = data m.FragID = 0 m.FragCount = 1 return m } } return nil }