header-custom finalmask: Extend expression primitives for 1:1 handshakes (#5949)

https://github.com/XTLS/Xray-core/pull/5945
https://github.com/XTLS/Xray-core/pull/5920
This commit is contained in:
Иван
2026-04-18 05:01:54 +07:00
committed by GitHub
parent df4b97097c
commit cb1106c2fb
3 changed files with 717 additions and 42 deletions

View File

@@ -9,8 +9,9 @@ import (
)
type evalValue struct {
bytes []byte
u64 *uint64
bytes []byte
u64 *uint64
isBytes bool
}
type evalContext struct {
@@ -175,7 +176,7 @@ func evaluateExpr(expr *Expr, ctx *evalContext) (evalValue, error) {
}
out = append(out, bytesValue...)
}
return evalValue{bytes: out}, nil
return evalValue{bytes: out, isBytes: true}, nil
case "slice":
if len(expr.GetArgs()) != 3 {
return evalValue{}, errors.New("slice expects 3 args")
@@ -208,52 +209,236 @@ func evaluateExpr(expr *Expr, ctx *evalContext) (evalValue, error) {
if end > uint64(len(sourceBytes)) {
return evalValue{}, errors.New("slice out of bounds")
}
return evalValue{bytes: append([]byte(nil), sourceBytes[offsetU64:end]...)}, nil
return evalValue{bytes: append([]byte(nil), sourceBytes[offsetU64:end]...), isBytes: true}, nil
case "xor16":
return evaluateXor(expr.GetArgs(), 0xFFFF, 2, ctx)
case "xor32":
return evaluateXor(expr.GetArgs(), 0xFFFFFFFF, 4, ctx)
case "be16":
if len(expr.GetArgs()) != 1 {
return evalValue{}, errors.New("be16 expects 1 arg")
}
value, err := evaluateExprArg(expr.GetArgs()[0], ctx)
if err != nil {
return evalValue{}, err
}
u64Value, err := value.asU64()
if err != nil {
return evalValue{}, err
}
if u64Value > 0xFFFF {
return evalValue{}, errors.New("be16 overflow")
}
out := make([]byte, 2)
binary.BigEndian.PutUint16(out, uint16(u64Value))
return evalValue{bytes: out}, nil
return evaluatePack(expr.GetArgs(), "be16", 2, binary.BigEndian, ctx)
case "be32":
if len(expr.GetArgs()) != 1 {
return evalValue{}, errors.New("be32 expects 1 arg")
}
value, err := evaluateExprArg(expr.GetArgs()[0], ctx)
if err != nil {
return evalValue{}, err
}
u64Value, err := value.asU64()
if err != nil {
return evalValue{}, err
}
if u64Value > 0xFFFFFFFF {
return evalValue{}, errors.New("be32 overflow")
}
out := make([]byte, 4)
binary.BigEndian.PutUint32(out, uint32(u64Value))
return evalValue{bytes: out}, nil
return evaluatePack(expr.GetArgs(), "be32", 4, binary.BigEndian, ctx)
case "le16":
return evaluatePack(expr.GetArgs(), "le16", 2, binary.LittleEndian, ctx)
case "le32":
return evaluatePack(expr.GetArgs(), "le32", 4, binary.LittleEndian, ctx)
case "le64":
return evaluatePack(expr.GetArgs(), "le64", 8, binary.LittleEndian, ctx)
case "pad":
return evaluatePad(expr.GetArgs(), ctx)
case "truncate":
return evaluateTruncate(expr.GetArgs(), ctx)
case "add":
return evaluateBinaryU64Op(expr.GetArgs(), "add", ctx, func(left, right uint64) (uint64, error) {
if left > ^uint64(0)-right {
return 0, errors.New("add overflow")
}
return left + right, nil
})
case "sub":
return evaluateBinaryU64Op(expr.GetArgs(), "sub", ctx, func(left, right uint64) (uint64, error) {
if left < right {
return 0, errors.New("sub underflow")
}
return left - right, nil
})
case "and":
return evaluateBinaryU64Op(expr.GetArgs(), "and", ctx, func(left, right uint64) (uint64, error) {
return left & right, nil
})
case "or":
return evaluateBinaryU64Op(expr.GetArgs(), "or", ctx, func(left, right uint64) (uint64, error) {
return left | right, nil
})
case "shl":
return evaluateShift(expr.GetArgs(), "shl", ctx, func(value uint64, shift uint) (uint64, error) {
if shift >= 64 {
return 0, errors.New("shift out of range")
}
if value > (^uint64(0) >> shift) {
return 0, errors.New("shl overflow")
}
return value << shift, nil
})
case "shr":
return evaluateShift(expr.GetArgs(), "shr", ctx, func(value uint64, shift uint) (uint64, error) {
if shift >= 64 {
return 0, errors.New("shift out of range")
}
return value >> shift, nil
})
default:
return evalValue{}, errors.New("unsupported expr op: ", expr.GetOp())
}
}
func evaluatePack(args []*ExprArg, name string, width int, order binary.ByteOrder, ctx *evalContext) (evalValue, error) {
if len(args) != 1 {
return evalValue{}, errors.New(name, " expects 1 arg")
}
value, err := evaluateExprArg(args[0], ctx)
if err != nil {
return evalValue{}, err
}
u64Value, err := value.asU64()
if err != nil {
return evalValue{}, err
}
switch width {
case 2:
if u64Value > 0xFFFF {
return evalValue{}, errors.New(name, " overflow")
}
out := make([]byte, 2)
order.PutUint16(out, uint16(u64Value))
return evalValue{bytes: out, isBytes: true}, nil
case 4:
if u64Value > 0xFFFFFFFF {
return evalValue{}, errors.New(name, " overflow")
}
out := make([]byte, 4)
order.PutUint32(out, uint32(u64Value))
return evalValue{bytes: out, isBytes: true}, nil
case 8:
out := make([]byte, 8)
order.PutUint64(out, u64Value)
return evalValue{bytes: out, isBytes: true}, nil
default:
return evalValue{}, errors.New("unsupported pack width")
}
}
func evaluatePad(args []*ExprArg, ctx *evalContext) (evalValue, error) {
if len(args) != 3 {
return evalValue{}, errors.New("pad expects 3 args")
}
source, err := evaluateExprArg(args[0], ctx)
if err != nil {
return evalValue{}, err
}
target, err := evaluateExprArg(args[1], ctx)
if err != nil {
return evalValue{}, err
}
fill, err := evaluateExprArg(args[2], ctx)
if err != nil {
return evalValue{}, err
}
sourceBytes, err := source.asBytes()
if err != nil {
return evalValue{}, err
}
targetU64, err := target.asU64()
if err != nil {
return evalValue{}, err
}
fillBytes, err := fill.asBytes()
if err != nil {
return evalValue{}, err
}
if len(fillBytes) == 0 {
return evalValue{}, errors.New("pad fill must not be empty")
}
if targetU64 < uint64(len(sourceBytes)) {
return evalValue{}, errors.New("pad target shorter than source")
}
out := append([]byte(nil), sourceBytes...)
for uint64(len(out)) < targetU64 {
remaining := int(targetU64) - len(out)
if remaining >= len(fillBytes) {
out = append(out, fillBytes...)
continue
}
out = append(out, fillBytes[:remaining]...)
}
return evalValue{bytes: out, isBytes: true}, nil
}
func evaluateTruncate(args []*ExprArg, ctx *evalContext) (evalValue, error) {
if len(args) != 2 {
return evalValue{}, errors.New("truncate expects 2 args")
}
source, err := evaluateExprArg(args[0], ctx)
if err != nil {
return evalValue{}, err
}
length, err := evaluateExprArg(args[1], ctx)
if err != nil {
return evalValue{}, err
}
sourceBytes, err := source.asBytes()
if err != nil {
return evalValue{}, err
}
lengthU64, err := length.asU64()
if err != nil {
return evalValue{}, err
}
if lengthU64 > uint64(len(sourceBytes)) {
return evalValue{}, errors.New("truncate out of bounds")
}
return evalValue{bytes: append([]byte(nil), sourceBytes[:lengthU64]...), isBytes: true}, nil
}
func evaluateBinaryU64Op(args []*ExprArg, name string, ctx *evalContext, op func(left, right uint64) (uint64, error)) (evalValue, error) {
if len(args) != 2 {
return evalValue{}, errors.New(name, " expects 2 args")
}
left, err := evaluateExprArg(args[0], ctx)
if err != nil {
return evalValue{}, err
}
right, err := evaluateExprArg(args[1], ctx)
if err != nil {
return evalValue{}, err
}
leftU64, err := left.asU64()
if err != nil {
return evalValue{}, err
}
rightU64, err := right.asU64()
if err != nil {
return evalValue{}, err
}
result, err := op(leftU64, rightU64)
if err != nil {
return evalValue{}, err
}
return evalValue{u64: &result}, nil
}
func evaluateShift(args []*ExprArg, name string, ctx *evalContext, op func(value uint64, shift uint) (uint64, error)) (evalValue, error) {
if len(args) != 2 {
return evalValue{}, errors.New(name, " expects 2 args")
}
value, err := evaluateExprArg(args[0], ctx)
if err != nil {
return evalValue{}, err
}
shift, err := evaluateExprArg(args[1], ctx)
if err != nil {
return evalValue{}, err
}
valueU64, err := value.asU64()
if err != nil {
return evalValue{}, err
}
shiftU64, err := shift.asU64()
if err != nil {
return evalValue{}, err
}
if shiftU64 >= 64 {
return evalValue{}, errors.New("shift out of range")
}
result, err := op(valueU64, uint(shiftU64))
if err != nil {
return evalValue{}, err
}
return evalValue{u64: &result}, nil
}
func evaluateXor(args []*ExprArg, mask uint64, width int, ctx *evalContext) (evalValue, error) {
if len(args) != 2 {
return evalValue{}, errors.New("xor expects 2 args")
@@ -309,6 +494,30 @@ func measureExpr(expr *Expr, sizeCtx map[string]int) (int, error) {
return 2, nil
case "be32":
return 4, nil
case "le16":
return 2, nil
case "le32":
return 4, nil
case "le64":
return 8, nil
case "pad":
if len(expr.GetArgs()) != 3 {
return 0, errors.New("pad expects 3 args")
}
lengthArg := expr.GetArgs()[1]
if value, ok := lengthArg.GetValue().(*ExprArg_U64); ok {
return int(value.U64), nil
}
return 0, errors.New("pad length must be u64")
case "truncate":
if len(expr.GetArgs()) != 2 {
return 0, errors.New("truncate expects 2 args")
}
lengthArg := expr.GetArgs()[1]
if value, ok := lengthArg.GetValue().(*ExprArg_U64); ok {
return int(value.U64), nil
}
return 0, errors.New("truncate length must be u64")
default:
return 0, errors.New("expr size is not bytes for op: ", expr.GetOp())
}
@@ -317,7 +526,7 @@ func measureExpr(expr *Expr, sizeCtx map[string]int) (int, error) {
func evaluateExprArg(arg *ExprArg, ctx *evalContext) (evalValue, error) {
switch value := arg.GetValue().(type) {
case *ExprArg_Bytes:
return evalValue{bytes: append([]byte(nil), value.Bytes...)}, nil
return evalValue{bytes: append([]byte(nil), value.Bytes...), isBytes: true}, nil
case *ExprArg_U64:
return evalValue{u64: &value.U64}, nil
case *ExprArg_Var:
@@ -325,7 +534,7 @@ func evaluateExprArg(arg *ExprArg, ctx *evalContext) (evalValue, error) {
if !ok {
return evalValue{}, errors.New("unknown variable: ", value.Var)
}
return evalValue{bytes: append([]byte(nil), saved...)}, nil
return evalValue{bytes: append([]byte(nil), saved...), isBytes: true}, nil
case *ExprArg_Metadata:
metadata, ok := ctx.metadata[value.Metadata]
if !ok {
@@ -361,7 +570,7 @@ func measureExprArg(arg *ExprArg, sizeCtx map[string]int) (int, error) {
}
func (v evalValue) asBytes() ([]byte, error) {
if v.bytes != nil {
if v.isBytes {
return append([]byte(nil), v.bytes...), nil
}
return nil, errors.New("expr value is not bytes")

View File

@@ -128,3 +128,364 @@ func TestEvaluatorRejectsInvalidArgType(t *testing.T) {
t.Fatal("expected evaluator error")
}
}
func TestEvaluatorLittleEndianProducesExpectedBytes(t *testing.T) {
items := []*UDPItem{
{
Expr: &Expr{
Op: "concat",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "le16",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0x1234}},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "le32",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0xA1B2C3D4}},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "le64",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0x0102030405060708}},
},
},
},
},
},
},
},
}
got, err := evaluateUDPItems(items)
if err != nil {
t.Fatal(err)
}
want := []byte{
0x34, 0x12,
0xD4, 0xC3, 0xB2, 0xA1,
0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01,
}
if !bytes.Equal(got, want) {
t.Fatalf("unexpected output: %x", got)
}
}
func TestEvaluatorPadAndTruncateShapeBytes(t *testing.T) {
items := []*UDPItem{
{
Expr: &Expr{
Op: "concat",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "pad",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{0xAA, 0xBB}}},
{Value: &ExprArg_U64{U64: 5}},
{Value: &ExprArg_Bytes{Bytes: []byte{0xCC, 0xDD}}},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "truncate",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}},
{Value: &ExprArg_U64{U64: 2}},
},
},
},
},
},
},
},
}
got, err := evaluateUDPItems(items)
if err != nil {
t.Fatal(err)
}
want := []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xCC, 0x01, 0x02}
if !bytes.Equal(got, want) {
t.Fatalf("unexpected output: %x", got)
}
}
func TestMeasureUDPItemsSupportsPadAndTruncate(t *testing.T) {
items := []*UDPItem{
{
Expr: &Expr{
Op: "pad",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}},
{Value: &ExprArg_U64{U64: 4}},
{Value: &ExprArg_Bytes{Bytes: []byte{0x00}}},
},
},
},
{
Expr: &Expr{
Op: "truncate",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}},
{Value: &ExprArg_U64{U64: 3}},
},
},
},
}
got, err := measureUDPItems(items)
if err != nil {
t.Fatal(err)
}
if got != 7 {
t.Fatalf("unexpected size: %d", got)
}
}
func TestEvaluatorArithmeticAndBitwiseProduceExpectedBytes(t *testing.T) {
items := []*UDPItem{
{
Expr: &Expr{
Op: "concat",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "add",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 2}},
},
},
},
},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "sub",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 10}},
{Value: &ExprArg_U64{U64: 3}},
},
},
},
},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "and",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0xF0F0}},
{Value: &ExprArg_U64{U64: 0x0FF0}},
},
},
},
},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "or",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "shl",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 8}},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "shr",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0x80}},
{Value: &ExprArg_U64{U64: 7}},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
}
got, err := evaluateUDPItems(items)
if err != nil {
t.Fatal(err)
}
want := []byte{
0x00, 0x03,
0x00, 0x07,
0x00, 0xF0,
0x01, 0x01,
}
if !bytes.Equal(got, want) {
t.Fatalf("unexpected output: %x", got)
}
}
func TestEvaluatorRejectsInvalidShapingAndArithmetic(t *testing.T) {
tests := []struct {
name string
items []*UDPItem
match string
}{
{
name: "pad with empty fill",
items: []*UDPItem{
{
Expr: &Expr{
Op: "pad",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}},
{Value: &ExprArg_U64{U64: 4}},
{Value: &ExprArg_Bytes{Bytes: []byte{}}},
},
},
},
},
match: "pad fill",
},
{
name: "truncate beyond source",
items: []*UDPItem{
{
Expr: &Expr{
Op: "truncate",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{1, 2}}},
{Value: &ExprArg_U64{U64: 3}},
},
},
},
},
match: "truncate",
},
{
name: "sub underflow",
items: []*UDPItem{
{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "sub",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 2}},
},
},
},
},
},
},
},
},
match: "underflow",
},
{
name: "shift too large",
items: []*UDPItem{
{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "shl",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 64}},
},
},
},
},
},
},
},
},
match: "shift",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := evaluateUDPItems(tt.items)
if err == nil {
t.Fatal("expected evaluator error")
}
if !bytes.Contains([]byte(err.Error()), []byte(tt.match)) {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

View File

@@ -1,6 +1,10 @@
package custom
import "testing"
import (
"bytes"
"net"
"testing"
)
func TestDSLUDPClientSizeTracksEvaluatedItems(t *testing.T) {
conn, err := NewConnClientUDP(&UDPConfig{
@@ -81,3 +85,104 @@ func TestDSLUDPServerRejectsMalformedVarReference(t *testing.T) {
t.Fatal("expected packet mismatch")
}
}
func TestDSLUDPClientWriteSupportsExtendedExprOps(t *testing.T) {
conn, err := NewConnClientUDP(&UDPConfig{
Client: []*UDPItem{
{
Expr: &Expr{
Op: "le16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "add",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 2}},
},
},
},
},
},
},
},
{
Expr: &Expr{
Op: "pad",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}},
{Value: &ExprArg_U64{U64: 3}},
{Value: &ExprArg_Bytes{Bytes: []byte{0xBB}}},
},
},
},
{
Expr: &Expr{
Op: "truncate",
Args: []*ExprArg{
{Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}},
{Value: &ExprArg_U64{U64: 2}},
},
},
},
{
Expr: &Expr{
Op: "be16",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "or",
Args: []*ExprArg{
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "shl",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 1}},
{Value: &ExprArg_U64{U64: 8}},
},
},
},
},
{
Value: &ExprArg_Expr{
Expr: &Expr{
Op: "shr",
Args: []*ExprArg{
{Value: &ExprArg_U64{U64: 0x80}},
{Value: &ExprArg_U64{U64: 7}},
},
},
},
},
},
},
},
},
},
},
},
},
}, nil)
if err != nil {
t.Fatal(err)
}
client := conn.(*udpCustomClientConn)
buf := make([]byte, client.Size())
if _, err := client.WriteTo(buf, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}); err != nil {
t.Fatal(err)
}
want := []byte{
0x03, 0x00,
0xAA, 0xBB, 0xBB,
0x01, 0x02,
0x01, 0x01,
}
if !bytes.Equal(buf, want) {
t.Fatalf("unexpected encoded header: %x", buf)
}
}