diff --git a/transport/internet/splithttp/browser_client.go b/transport/internet/splithttp/browser_client.go index 1ae3ae95..a70447f2 100644 --- a/transport/internet/splithttp/browser_client.go +++ b/transport/internet/splithttp/browser_client.go @@ -5,6 +5,7 @@ import ( "io" "net/http" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/transport/internet/browser_dialer" @@ -41,21 +42,20 @@ func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, sessio return websocket.NewConnection(conn, dummyAddr, nil, 0), conn.RemoteAddr(), conn.LocalAddr(), nil } -func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error { +func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error { method := c.transportConfig.GetNormalizedUplinkHTTPMethod() - request, err := http.NewRequest(method, url, body) + request, err := http.NewRequest(method, url, nil) if err != nil { return err } - request.ContentLength = contentLength - err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr) + err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr, payload) if err != nil { return err } var bytes []byte - if (request.Body != nil) { + if request.Body != nil { bytes, err = io.ReadAll(request.Body) if err != nil { return err diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index 38990c1a..c156509a 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/signal/done" @@ -23,7 +24,7 @@ type DialerClient interface { OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error) // ctx, url, sessionId, seqStr, body, contentLength - PostPacket(context.Context, string, string, string, io.Reader, int64) error + PostPacket(context.Context, string, string, string, buf.MultiBuffer) error } // implements splithttp.DialerClient in terms of direct network connections @@ -89,14 +90,13 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessio return } -func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error { +func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error { method := c.transportConfig.GetNormalizedUplinkHTTPMethod() - req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body) + req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, nil) if err != nil { return err } - req.ContentLength = contentLength - c.transportConfig.FillPacketRequest(req, sessionId, seqStr) + c.transportConfig.FillPacketRequest(req, sessionId, seqStr, payload) if c.httpVersion != "1.1" { resp, err := c.client.Do(req) @@ -117,6 +117,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio // times, the body is already drained after the first // request requestBuff := new(bytes.Buffer) + requestBuff.Grow(512 + int(req.ContentLength)) common.Must(req.Write(requestBuff)) var uploadConn any diff --git a/transport/internet/splithttp/config.go b/transport/internet/splithttp/config.go index 03ed591c..61f861a3 100644 --- a/transport/internet/splithttp/config.go +++ b/transport/internet/splithttp/config.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/utils" "github.com/xtls/xray-core/transport/internet" @@ -55,7 +56,6 @@ func (c *Config) GetRequestHeader() http.Header { return header } - func (c *Config) GetRequestHeaderWithPayload(payload []byte) http.Header { header := c.GetRequestHeader() @@ -100,9 +100,9 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter, requestMethod s } if c.GetNormalizedSessionPlacement() == PlacementCookie || - c.GetNormalizedSeqPlacement() == PlacementCookie || - c.XPaddingPlacement == PlacementCookie || - c.GetNormalizedUplinkDataPlacement() == PlacementCookie { + c.GetNormalizedSeqPlacement() == PlacementCookie || + c.XPaddingPlacement == PlacementCookie || + c.GetNormalizedUplinkDataPlacement() == PlacementCookie { writer.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -322,22 +322,17 @@ func (c *Config) FillStreamRequest(request *http.Request, sessionId string, seqS } } -func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string) error { +func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string, payload buf.MultiBuffer) error { dataPlacement := c.GetNormalizedUplinkDataPlacement() if dataPlacement == PlacementBody || dataPlacement == PlacementAuto { request.Header = c.GetRequestHeader() + request.Body = io.NopCloser(&buf.MultiBufferContainer{MultiBuffer: payload}) + request.ContentLength = int64(payload.Len()) } else { - var data []byte - var err error - if request.Body != nil { - data, err = io.ReadAll(request.Body) - if err != nil { - return err - } - } - request.Body = nil - request.ContentLength = 0 + data := make([]byte, payload.Len()) + payload.Copy(data) + buf.ReleaseMulti(payload) switch dataPlacement { case PlacementHeader: request.Header = c.GetRequestHeaderWithPayload(data) diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 71888a73..0c351a5a 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -562,8 +562,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me requestURL.String(), sessionId, seqStr, - &buf.MultiBufferContainer{MultiBuffer: chunk}, - int64(chunk.Len()), + chunk, ) wroteRequest.Close() if err != nil { diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index c2ab8fbf..1ffdf6f2 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -18,6 +18,7 @@ import ( "github.com/apernet/quic-go/http3" goreality "github.com/xtls/reality" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" http_proto "github.com/xtls/xray-core/common/protocol/http" @@ -293,15 +294,36 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req var bodyPayload []byte if dataPlacement == PlacementAuto || dataPlacement == PlacementBody { - bodyPayload, err = io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) - if err != nil { - errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)") - writer.WriteHeader(http.StatusInternalServerError) + var readErr error + if request.ContentLength > int64(scMaxEachPostBytes) { + errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") + writer.WriteHeader(http.StatusRequestEntityTooLarge) + return + } + if request.ContentLength > 0 { + bodyPayload = make([]byte, request.ContentLength) + _, readErr = io.ReadFull(request.Body, bodyPayload) + } else { + bodyPayload, readErr = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) + } + if readErr != nil { + errors.LogInfoInner(context.Background(), readErr, "failed to read body payload") + writer.WriteHeader(http.StatusBadRequest) return } } - payload := slices.Concat(headerPayload, cookiePayload, bodyPayload) + var payload []byte + switch dataPlacement { + case PlacementHeader: + payload = headerPayload + case PlacementCookie: + payload = cookiePayload + case PlacementBody: + payload = bodyPayload + case PlacementAuto: + payload = slices.Concat(headerPayload, cookiePayload, bodyPayload) + } if len(payload) > scMaxEachPostBytes { errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")