diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index f4ec6f25..ddbc2178 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -10,6 +10,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" @@ -17,7 +18,7 @@ import ( ) type netReadInfo struct { - buff []byte + buff *buf.Buffer endpoint conn.Endpoint } @@ -82,7 +83,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case r := <-bind.readQueue: - sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint + sizes[0], eps[0] = copy(bufs[0], r.buff.Bytes()), r.endpoint + r.buff.Release() return 1, nil case <-bind.closedCh: errors.LogDebug(context.Background(), "recv func closed") @@ -130,27 +132,30 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { go func() { for { - buff := make([]byte, device.MaxMessageSize) - n, err := c.Read(buff) + buff := buf.NewWithSize(device.MaxMessageSize) + n, err := buff.ReadFrom(c) if err != nil { + buff.Release() endpoint.conn = nil c.Close() return } + rawBytes := buff.Bytes() if n > 3 { - buff[1] = 0 - buff[2] = 0 - buff[3] = 0 + rawBytes[1] = 0 + rawBytes[2] = 0 + rawBytes[3] = 0 } select { case bind.readQueue <- &netReadInfo{ - buff: buff[:n], + buff: buff, endpoint: endpoint, }: case <-bind.closedCh: + buff.Release() endpoint.conn = nil c.Close() return diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 9e82ed39..1f358a38 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -101,17 +101,17 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con } for i, b := range mb { - buff := b.Bytes() + rawBytes := b.Bytes() if b.Len() > 3 { - buff[1] = 0 - buff[2] = 0 - buff[3] = 0 + rawBytes[1] = 0 + rawBytes[2] = 0 + rawBytes[3] = 0 } select { case s.bindServer.readQueue <- &netReadInfo{ - buff: buff, + buff: b, endpoint: nep, }: case <-s.bindServer.closedCh: