package splithttp import ( "bytes" "context" gotls "crypto/tls" "encoding/base64" "fmt" "io" "net/http" "slices" "strconv" "strings" "sync" "time" "github.com/apernet/quic-go" "github.com/apernet/quic-go/http3" goreality "github.com/xtls/reality" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" http_proto "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria/congestion" "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" ) type requestHandler struct { config *Config host string path string ln *Listener sessionMu *sync.Mutex sessions sync.Map localAddr net.Addr socketSettings *internet.SocketConfig } type httpSession struct { uploadQueue *uploadQueue // for as long as the GET request is not opened by the client, this will be // open ("undone"), and the session may be expired within a certain TTL. // after the client connects, this becomes "done" and the session lives as // long as the GET request. isFullyConnected *done.Instance } func (h *requestHandler) upsertSession(sessionId string) *httpSession { // fast path currentSessionAny, ok := h.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } // slow path h.sessionMu.Lock() defer h.sessionMu.Unlock() currentSessionAny, ok = h.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } s := &httpSession{ uploadQueue: NewUploadQueue(h.ln.config.GetNormalizedScMaxBufferedPosts()), isFullyConnected: done.New(), } h.sessions.Store(sessionId, s) shouldReap := done.New() go func() { time.Sleep(30 * time.Second) shouldReap.Close() }() go func() { select { case <-shouldReap.Wait(): h.sessions.Delete(sessionId) s.uploadQueue.Close() case <-s.isFullyConnected.Wait(): } }() return s } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) { errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host) writer.WriteHeader(http.StatusNotFound) return } if !strings.HasPrefix(request.URL.Path, h.path) { errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path) writer.WriteHeader(http.StatusNotFound) return } h.config.WriteResponseHeader(writer, request.Method, request.Header) length := int(h.config.GetNormalizedXPaddingBytes().rand()) config := XPaddingConfig{Length: length} if h.config.XPaddingObfsMode { config.Placement = XPaddingPlacement{ Placement: h.config.XPaddingPlacement, Key: h.config.XPaddingKey, Header: h.config.XPaddingHeader, } config.Method = PaddingMethod(h.config.XPaddingMethod) } else { config.Placement = XPaddingPlacement{ Placement: PlacementHeader, Header: "X-Padding", } } h.config.ApplyXPaddingToResponse(writer, config) if request.Method == "OPTIONS" { writer.WriteHeader(http.StatusOK) return } /* clientVer := []int{0, 0, 0} x_version := strings.Split(request.URL.Query().Get("x_version"), ".") for j := 0; j < 3 && len(x_version) > j; j++ { clientVer[j], _ = strconv.Atoi(x_version[j]) } */ validRange := h.config.GetNormalizedXPaddingBytes() paddingValue, paddingPlacement := h.config.ExtractXPaddingFromRequest(request, h.config.XPaddingObfsMode) if !h.config.IsPaddingValid(paddingValue, validRange.From, validRange.To, PaddingMethod(h.config.XPaddingMethod)) { errors.LogInfo(context.Background(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue))) writer.WriteHeader(http.StatusBadRequest) return } sessionId, seqStr := h.config.ExtractMetaFromRequest(request, h.path) if sessionId == "" && h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-one" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-one mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } var forwardedAddrs []net.Address if h.socketSettings != nil && len(h.socketSettings.TrustedXForwardedFor) > 0 { for _, key := range h.socketSettings.TrustedXForwardedFor { if len(request.Header.Values(key)) > 0 { forwardedAddrs = http_proto.ParseXForwardedFor(request.Header) break } } } else { forwardedAddrs = http_proto.ParseXForwardedFor(request.Header) } var remoteAddr net.Addr var err error remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr) if err != nil { remoteAddr = &net.TCPAddr{ IP: []byte{0, 0, 0, 0}, Port: 0, } } if request.ProtoMajor == 3 { remoteAddr = &net.UDPAddr{ IP: remoteAddr.(*net.TCPAddr).IP, Port: remoteAddr.(*net.TCPAddr).Port, } } if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { remoteAddr = &net.TCPAddr{ IP: forwardedAddrs[0].IP(), Port: 0, } } var currentSession *httpSession if sessionId != "" { currentSession = h.upsertSession(sessionId) } scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To) isUplinkRequest := false switch request.Method { case "GET": isUplinkRequest = seqStr != "" default: isUplinkRequest = true } uplinkDataKey := h.config.UplinkDataKey if isUplinkRequest && sessionId != "" { // stream-up, packet-up if seqStr == "" { if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-up mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } httpSC := &httpServerConn{ Instance: done.New(), Reader: request.Body, ResponseWriter: writer, } err = currentSession.uploadQueue.Push(Packet{ Reader: httpSC, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)") writer.WriteHeader(http.StatusConflict) } else { writer.Header().Set("X-Accel-Buffering", "no") writer.Header().Set("Cache-Control", "no-store") writer.WriteHeader(http.StatusOK) scStreamUpServerSecs := h.config.GetNormalizedScStreamUpServerSecs() referrer := request.Header.Get("Referer") if referrer != "" && scStreamUpServerSecs.To > 0 { go func() { for { _, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand()))) if err != nil { break } time.Sleep(time.Duration(scStreamUpServerSecs.rand()) * time.Second) } }() } select { case <-request.Context().Done(): case <-httpSC.Wait(): } } httpSC.Close() return } if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "packet-up" { errors.LogInfo(context.Background(), "packet-up mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } dataPlacement := h.config.GetNormalizedUplinkDataPlacement() var headerPayload []byte if dataPlacement == PlacementAuto || dataPlacement == PlacementHeader { var headerPayloadChunks []string for i := 0; true; i++ { chunk := request.Header.Get(fmt.Sprintf("%s-%d", uplinkDataKey, i)) if chunk == "" { break } headerPayloadChunks = append(headerPayloadChunks, chunk) } headerPayloadEncoded := strings.Join(headerPayloadChunks, "") headerPayload, err = base64.RawURLEncoding.DecodeString(headerPayloadEncoded) if err != nil { errors.LogInfo(context.Background(), "Invalid base64 in header's payload: ", err.Error()) writer.WriteHeader(http.StatusBadRequest) return } } var cookiePayload []byte if dataPlacement == PlacementAuto || dataPlacement == PlacementCookie { var cookiePayloadChunks []string for i := 0; true; i++ { cookieName := fmt.Sprintf("%s_%d", uplinkDataKey, i) if c, _ := request.Cookie(cookieName); c != nil { cookiePayloadChunks = append(cookiePayloadChunks, c.Value) } else { break } } cookiePayloadEncoded := strings.Join(cookiePayloadChunks, "") cookiePayload, err = base64.RawURLEncoding.DecodeString(cookiePayloadEncoded) if err != nil { errors.LogInfo(context.Background(), "Invalid base64 in cookies' payload: ", err.Error()) writer.WriteHeader(http.StatusBadRequest) return } } var bodyPayload []byte if dataPlacement == PlacementAuto || dataPlacement == PlacementBody { var readErr error if request.ContentLength > int64(scMaxEachPostBytes) { errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") writer.WriteHeader(http.StatusRequestEntityTooLarge) return } if request.ContentLength > 0 { bodyPayload = make([]byte, request.ContentLength) _, readErr = io.ReadFull(request.Body, bodyPayload) } else { bodyPayload, readErr = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) } if readErr != nil { errors.LogInfoInner(context.Background(), readErr, "failed to read body payload") writer.WriteHeader(http.StatusBadRequest) return } } var payload []byte switch dataPlacement { case PlacementHeader: payload = headerPayload case PlacementCookie: payload = cookiePayload case PlacementBody: payload = bodyPayload case PlacementAuto: payload = slices.Concat(headerPayload, cookiePayload, bodyPayload) } if len(payload) > scMaxEachPostBytes { errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") writer.WriteHeader(http.StatusRequestEntityTooLarge) return } seq, err := strconv.ParseUint(seqStr, 10, 64) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (ParseUint)") writer.WriteHeader(http.StatusInternalServerError) return } err = currentSession.uploadQueue.Push(Packet{ Payload: payload, Seq: seq, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushPayload)") writer.WriteHeader(http.StatusInternalServerError) return } if len(bodyPayload) == 0 { // Methods without a body are usually cached by default. writer.Header().Set("Cache-Control", "no-store") } writer.WriteHeader(http.StatusOK) } else if request.Method == "GET" || sessionId == "" { // stream-down, stream-one if sessionId != "" { // after GET is done, the connection is finished. disable automatic // session reaping, and handle it in defer currentSession.isFullyConnected.Close() defer h.sessions.Delete(sessionId) } // magic header instructs nginx + apache to not buffer response body writer.Header().Set("X-Accel-Buffering", "no") // A web-compliant header telling all middleboxes to disable caching. // Should be able to prevent overloading the cache, or stop CDNs from // teeing the response stream into their cache, causing slowdowns. writer.Header().Set("Cache-Control", "no-store") if !h.config.NoSSEHeader { // magic header to make the HTTP middle box consider this as SSE to disable buffer writer.Header().Set("Content-Type", "text/event-stream") } writer.WriteHeader(http.StatusOK) writer.(http.Flusher).Flush() httpSC := &httpServerConn{ Instance: done.New(), Reader: request.Body, ResponseWriter: writer, } conn := splitConn{ writer: httpSC, reader: httpSC, remoteAddr: remoteAddr, localAddr: h.localAddr, } if sessionId != "" { // if not stream-one conn.reader = currentSession.uploadQueue } h.ln.addConn(stat.Connection(&conn)) // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned." select { case <-request.Context().Done(): case <-httpSC.Wait(): } conn.Close() } else { errors.LogInfo(context.Background(), "unsupported method: ", request.Method) writer.WriteHeader(http.StatusMethodNotAllowed) } } type httpServerConn struct { sync.Mutex *done.Instance io.Reader // no need to Close request.Body http.ResponseWriter } func (c *httpServerConn) Write(b []byte) (int, error) { c.Lock() defer c.Unlock() if c.Done() { return 0, io.ErrClosedPipe } n, err := c.ResponseWriter.Write(b) if err == nil { c.ResponseWriter.(http.Flusher).Flush() } return n, err } func (c *httpServerConn) Close() error { c.Lock() defer c.Unlock() return c.Instance.Close() } type Listener struct { sync.Mutex server http.Server h3server *http3.Server listener net.Listener h3listener *quic.EarlyListener config *Config addConn internet.ConnHandler isH3 bool } func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { l := &Listener{ addConn: addConn, } l.config = streamSettings.ProtocolSettings.(*Config) if l.config != nil { if streamSettings.SocketSettings == nil { streamSettings.SocketSettings = &internet.SocketConfig{} } } handler := &requestHandler{ config: l.config, host: l.config.Host, path: l.config.GetNormalizedPath(), ln: l, sessionMu: &sync.Mutex{}, sessions: sync.Map{}, socketSettings: streamSettings.SocketSettings, } tlsConfig := getTLSConfig(streamSettings) l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" var err error if port == net.Port(0) { // unix l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ Name: address.Domain(), Net: "unix", }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen UNIX domain socket for XHTTP on ", address).Base(err) } errors.LogInfo(ctx, "listening UNIX domain socket for XHTTP on ", address) } else if l.isH3 { // quic Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } if streamSettings.UdpmaskManager != nil { pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn) if err != nil { Conn.Close() return nil, errors.New("mask err").Base(err) } Conn = pktConn } quicParams := streamSettings.QuicParams if quicParams == nil { quicParams = &internet.QuicParams{ BbrProfile: string(bbr.ProfileStandard), UdpHop: &internet.UdpHop{}, } } quicConfig := &quic.Config{ InitialStreamReceiveWindow: quicParams.InitStreamReceiveWindow, MaxStreamReceiveWindow: quicParams.MaxStreamReceiveWindow, InitialConnectionReceiveWindow: quicParams.InitConnReceiveWindow, MaxConnectionReceiveWindow: quicParams.MaxConnReceiveWindow, MaxIdleTimeout: time.Duration(quicParams.MaxIdleTimeout) * time.Second, MaxIncomingStreams: quicParams.MaxIncomingStreams, DisablePathMTUDiscovery: quicParams.DisablePathMtuDiscovery, } l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, quicConfig) if err != nil { return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) } errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) handler.localAddr = l.h3listener.Addr() l.h3server = &http3.Server{ Handler: handler, } go func() { for { conn, err := l.h3listener.Accept(context.Background()) if err != nil { errors.LogInfoInner(ctx, err, "XHTTP/3 listener closed") return } switch quicParams.Congestion { case "force-brutal": errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion brutal bytes per second ", quicParams.BrutalUp) congestion.UseBrutal(conn, quicParams.BrutalUp) case "reno": errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion reno") default: errors.LogDebug(context.Background(), conn.RemoteAddr(), " ", "congestion bbr ", quicParams.BbrProfile) congestion.UseBBR(conn, bbr.Profile(quicParams.BbrProfile)) } go func() { if err := l.h3server.ServeQUICConn(conn); err != nil { errors.LogDebugInner(ctx, err, "XHTTP/3 connection ended") } _ = conn.CloseWithError(0, "") }() } }() } else { // tcp l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen TCP for XHTTP on ", address, ":", port).Base(err) } errors.LogInfo(ctx, "listening TCP for XHTTP on ", address, ":", port) } if !l.isH3 && streamSettings.TcpmaskManager != nil { l.listener, _ = streamSettings.TcpmaskManager.WrapListener(l.listener) } // tcp/unix (h1/h2) if l.listener != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { l.listener = gotls.NewListener(l.listener, tlsConfig) } } if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig()) } handler.localAddr = l.listener.Addr() // server can handle both plaintext HTTP/1.1 and h2c protocols := new(http.Protocols) protocols.SetHTTP1(true) protocols.SetUnencryptedHTTP2(true) l.server = http.Server{ Handler: handler, ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: l.config.GetNormalizedServerMaxHeaderBytes(), Protocols: protocols, } go func() { if err := l.server.Serve(l.listener); err != nil { errors.LogErrorInner(ctx, err, "failed to serve HTTP for XHTTP") } }() } return l, err } // Addr implements net.Listener.Addr(). func (ln *Listener) Addr() net.Addr { if ln.h3listener != nil { return ln.h3listener.Addr() } if ln.listener != nil { return ln.listener.Addr() } return nil } // Close implements net.Listener.Close(). func (ln *Listener) Close() error { if ln.h3server != nil { if err := ln.h3server.Close(); err != nil { _ = ln.h3listener.Close() return err } return ln.h3listener.Close() } else if ln.listener != nil { return ln.listener.Close() } return errors.New("listener does not have an HTTP/3 server or a net.listener") } func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config { config := tls.ConfigFromStreamSettings(streamSettings) if config == nil { return &gotls.Config{} } return config.GetTLSConfig() } func init() { common.Must(internet.RegisterTransportListener(protocolName, ListenXH)) }