diff --git a/common/buf/copy.go b/common/buf/copy.go index 4cc3be88..72ee3ed8 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -56,6 +56,10 @@ type readError struct { error } +func NewReadError(err error) error { + return readError{err} +} + func (e readError) Error() string { return e.error.Error() } @@ -74,6 +78,10 @@ type writeError struct { error } +func NewWriteError(err error) error { + return writeError{err} +} + func (e writeError) Error() string { return e.error.Error() } diff --git a/common/mux/client.go b/common/mux/client.go index 28380331..47943a2b 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -332,14 +332,14 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error { if meta.Option.Has(OptionData) { - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) } return nil } func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error { if meta.Option.Has(OptionData) { - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) } return nil } @@ -355,7 +355,19 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream) closingWriter.Close() - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) + } + + if s.transferType == protocol.TransferTypeStream { + err := CopyChunk(reader, s.output) + if err != nil && buf.IsWriteError(err) { + errors.LogInfoInner(context.Background(), err, "failed to write to downstream. closing session ", s.ID) + s.Close(false) + // down stream can have a write err but don't return the err to terminate the whole mux connection + // because it's still available for other sessions + return nil + } + return err } rr := s.NewReader(reader, &meta.Target) @@ -374,7 +386,7 @@ func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.Buffered s.Close(false) } if meta.Option.Has(OptionData) { - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) } return nil } diff --git a/common/mux/reader.go b/common/mux/reader.go index b9714cdf..d27a4f21 100644 --- a/common/mux/reader.go +++ b/common/mux/reader.go @@ -57,3 +57,32 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func NewStreamReader(reader *buf.BufferedReader) buf.Reader { return crypto.NewChunkStreamReaderWithChunkCount(crypto.PlainChunkSizeParser{}, reader, 1) } + +func CopyChunk(reader *buf.BufferedReader, writer buf.Writer) error { + size, err := serial.ReadUint16(reader) + if err != nil { + return err + } + var writeErr error + for size > 0 { + mb, readErr := reader.ReadAtMost(int32(size)) + if !mb.IsEmpty() { + size -= uint16(mb.Len()) + if writeErr == nil { + if err := writer.WriteMultiBuffer(mb); err != nil { + writeErr = err + } + } else { + buf.ReleaseMulti(mb) + } + continue + } + if readErr != nil { + return buf.NewReadError(readErr) + } + } + if writeErr != nil { + return buf.NewWriteError(writeErr) + } + return nil +} diff --git a/common/mux/server.go b/common/mux/server.go index d1cdac11..87aaf451 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -157,7 +157,7 @@ func (w *ServerWorker) Close() error { func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error { if meta.Option.Has(OptionData) { - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) } return nil } @@ -264,7 +264,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, link, err := w.dispatcher.Dispatch(ctx, meta.Target) if err != nil { if meta.Option.Has(OptionData) { - buf.Copy(NewStreamReader(reader), buf.Discard) + CopyChunk(reader, buf.Discard) } return errors.New("failed to dispatch request.").Base(err) } @@ -287,6 +287,15 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, return nil } + if s.transferType == protocol.TransferTypeStream { + err = CopyChunk(reader, s.output) + if err != nil && buf.IsWriteError(err) { + s.Close(false) + return err + } + return err + } + rr := s.NewReader(reader, &meta.Target) err = buf.Copy(rr, s.output) @@ -308,7 +317,19 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream) closingWriter.Close() - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) + } + + if s.transferType == protocol.TransferTypeStream { + err := CopyChunk(reader, s.output) + if err != nil && buf.IsWriteError(err) { + errors.LogInfoInner(context.Background(), err, "failed to write to downstream writer. closing session ", s.ID) + s.Close(false) + // down stream can have a write err but don't return the err to terminate the whole mux connection + // because it's still available for other sessions + return nil + } + return err } rr := s.NewReader(reader, &meta.Target) @@ -328,7 +349,7 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.Buffered s.Close(false) } if meta.Option.Has(OptionData) { - return buf.Copy(NewStreamReader(reader), buf.Discard) + return CopyChunk(reader, buf.Discard) } return nil }