diff --git a/transport/internet/finalmask/header/custom/evaluator.go b/transport/internet/finalmask/header/custom/evaluator.go index 46655700..728be543 100644 --- a/transport/internet/finalmask/header/custom/evaluator.go +++ b/transport/internet/finalmask/header/custom/evaluator.go @@ -9,8 +9,9 @@ import ( ) type evalValue struct { - bytes []byte - u64 *uint64 + bytes []byte + u64 *uint64 + isBytes bool } type evalContext struct { @@ -175,7 +176,7 @@ func evaluateExpr(expr *Expr, ctx *evalContext) (evalValue, error) { } out = append(out, bytesValue...) } - return evalValue{bytes: out}, nil + return evalValue{bytes: out, isBytes: true}, nil case "slice": if len(expr.GetArgs()) != 3 { return evalValue{}, errors.New("slice expects 3 args") @@ -208,52 +209,236 @@ func evaluateExpr(expr *Expr, ctx *evalContext) (evalValue, error) { if end > uint64(len(sourceBytes)) { return evalValue{}, errors.New("slice out of bounds") } - return evalValue{bytes: append([]byte(nil), sourceBytes[offsetU64:end]...)}, nil + return evalValue{bytes: append([]byte(nil), sourceBytes[offsetU64:end]...), isBytes: true}, nil case "xor16": return evaluateXor(expr.GetArgs(), 0xFFFF, 2, ctx) case "xor32": return evaluateXor(expr.GetArgs(), 0xFFFFFFFF, 4, ctx) case "be16": - if len(expr.GetArgs()) != 1 { - return evalValue{}, errors.New("be16 expects 1 arg") - } - value, err := evaluateExprArg(expr.GetArgs()[0], ctx) - if err != nil { - return evalValue{}, err - } - u64Value, err := value.asU64() - if err != nil { - return evalValue{}, err - } - if u64Value > 0xFFFF { - return evalValue{}, errors.New("be16 overflow") - } - out := make([]byte, 2) - binary.BigEndian.PutUint16(out, uint16(u64Value)) - return evalValue{bytes: out}, nil + return evaluatePack(expr.GetArgs(), "be16", 2, binary.BigEndian, ctx) case "be32": - if len(expr.GetArgs()) != 1 { - return evalValue{}, errors.New("be32 expects 1 arg") - } - value, err := evaluateExprArg(expr.GetArgs()[0], ctx) - if err != nil { - return evalValue{}, err - } - u64Value, err := value.asU64() - if err != nil { - return evalValue{}, err - } - if u64Value > 0xFFFFFFFF { - return evalValue{}, errors.New("be32 overflow") - } - out := make([]byte, 4) - binary.BigEndian.PutUint32(out, uint32(u64Value)) - return evalValue{bytes: out}, nil + return evaluatePack(expr.GetArgs(), "be32", 4, binary.BigEndian, ctx) + case "le16": + return evaluatePack(expr.GetArgs(), "le16", 2, binary.LittleEndian, ctx) + case "le32": + return evaluatePack(expr.GetArgs(), "le32", 4, binary.LittleEndian, ctx) + case "le64": + return evaluatePack(expr.GetArgs(), "le64", 8, binary.LittleEndian, ctx) + case "pad": + return evaluatePad(expr.GetArgs(), ctx) + case "truncate": + return evaluateTruncate(expr.GetArgs(), ctx) + case "add": + return evaluateBinaryU64Op(expr.GetArgs(), "add", ctx, func(left, right uint64) (uint64, error) { + if left > ^uint64(0)-right { + return 0, errors.New("add overflow") + } + return left + right, nil + }) + case "sub": + return evaluateBinaryU64Op(expr.GetArgs(), "sub", ctx, func(left, right uint64) (uint64, error) { + if left < right { + return 0, errors.New("sub underflow") + } + return left - right, nil + }) + case "and": + return evaluateBinaryU64Op(expr.GetArgs(), "and", ctx, func(left, right uint64) (uint64, error) { + return left & right, nil + }) + case "or": + return evaluateBinaryU64Op(expr.GetArgs(), "or", ctx, func(left, right uint64) (uint64, error) { + return left | right, nil + }) + case "shl": + return evaluateShift(expr.GetArgs(), "shl", ctx, func(value uint64, shift uint) (uint64, error) { + if shift >= 64 { + return 0, errors.New("shift out of range") + } + if value > (^uint64(0) >> shift) { + return 0, errors.New("shl overflow") + } + return value << shift, nil + }) + case "shr": + return evaluateShift(expr.GetArgs(), "shr", ctx, func(value uint64, shift uint) (uint64, error) { + if shift >= 64 { + return 0, errors.New("shift out of range") + } + return value >> shift, nil + }) default: return evalValue{}, errors.New("unsupported expr op: ", expr.GetOp()) } } +func evaluatePack(args []*ExprArg, name string, width int, order binary.ByteOrder, ctx *evalContext) (evalValue, error) { + if len(args) != 1 { + return evalValue{}, errors.New(name, " expects 1 arg") + } + value, err := evaluateExprArg(args[0], ctx) + if err != nil { + return evalValue{}, err + } + u64Value, err := value.asU64() + if err != nil { + return evalValue{}, err + } + + switch width { + case 2: + if u64Value > 0xFFFF { + return evalValue{}, errors.New(name, " overflow") + } + out := make([]byte, 2) + order.PutUint16(out, uint16(u64Value)) + return evalValue{bytes: out, isBytes: true}, nil + case 4: + if u64Value > 0xFFFFFFFF { + return evalValue{}, errors.New(name, " overflow") + } + out := make([]byte, 4) + order.PutUint32(out, uint32(u64Value)) + return evalValue{bytes: out, isBytes: true}, nil + case 8: + out := make([]byte, 8) + order.PutUint64(out, u64Value) + return evalValue{bytes: out, isBytes: true}, nil + default: + return evalValue{}, errors.New("unsupported pack width") + } +} + +func evaluatePad(args []*ExprArg, ctx *evalContext) (evalValue, error) { + if len(args) != 3 { + return evalValue{}, errors.New("pad expects 3 args") + } + source, err := evaluateExprArg(args[0], ctx) + if err != nil { + return evalValue{}, err + } + target, err := evaluateExprArg(args[1], ctx) + if err != nil { + return evalValue{}, err + } + fill, err := evaluateExprArg(args[2], ctx) + if err != nil { + return evalValue{}, err + } + sourceBytes, err := source.asBytes() + if err != nil { + return evalValue{}, err + } + targetU64, err := target.asU64() + if err != nil { + return evalValue{}, err + } + fillBytes, err := fill.asBytes() + if err != nil { + return evalValue{}, err + } + if len(fillBytes) == 0 { + return evalValue{}, errors.New("pad fill must not be empty") + } + if targetU64 < uint64(len(sourceBytes)) { + return evalValue{}, errors.New("pad target shorter than source") + } + + out := append([]byte(nil), sourceBytes...) + for uint64(len(out)) < targetU64 { + remaining := int(targetU64) - len(out) + if remaining >= len(fillBytes) { + out = append(out, fillBytes...) + continue + } + out = append(out, fillBytes[:remaining]...) + } + return evalValue{bytes: out, isBytes: true}, nil +} + +func evaluateTruncate(args []*ExprArg, ctx *evalContext) (evalValue, error) { + if len(args) != 2 { + return evalValue{}, errors.New("truncate expects 2 args") + } + source, err := evaluateExprArg(args[0], ctx) + if err != nil { + return evalValue{}, err + } + length, err := evaluateExprArg(args[1], ctx) + if err != nil { + return evalValue{}, err + } + sourceBytes, err := source.asBytes() + if err != nil { + return evalValue{}, err + } + lengthU64, err := length.asU64() + if err != nil { + return evalValue{}, err + } + if lengthU64 > uint64(len(sourceBytes)) { + return evalValue{}, errors.New("truncate out of bounds") + } + return evalValue{bytes: append([]byte(nil), sourceBytes[:lengthU64]...), isBytes: true}, nil +} + +func evaluateBinaryU64Op(args []*ExprArg, name string, ctx *evalContext, op func(left, right uint64) (uint64, error)) (evalValue, error) { + if len(args) != 2 { + return evalValue{}, errors.New(name, " expects 2 args") + } + left, err := evaluateExprArg(args[0], ctx) + if err != nil { + return evalValue{}, err + } + right, err := evaluateExprArg(args[1], ctx) + if err != nil { + return evalValue{}, err + } + leftU64, err := left.asU64() + if err != nil { + return evalValue{}, err + } + rightU64, err := right.asU64() + if err != nil { + return evalValue{}, err + } + result, err := op(leftU64, rightU64) + if err != nil { + return evalValue{}, err + } + return evalValue{u64: &result}, nil +} + +func evaluateShift(args []*ExprArg, name string, ctx *evalContext, op func(value uint64, shift uint) (uint64, error)) (evalValue, error) { + if len(args) != 2 { + return evalValue{}, errors.New(name, " expects 2 args") + } + value, err := evaluateExprArg(args[0], ctx) + if err != nil { + return evalValue{}, err + } + shift, err := evaluateExprArg(args[1], ctx) + if err != nil { + return evalValue{}, err + } + valueU64, err := value.asU64() + if err != nil { + return evalValue{}, err + } + shiftU64, err := shift.asU64() + if err != nil { + return evalValue{}, err + } + if shiftU64 >= 64 { + return evalValue{}, errors.New("shift out of range") + } + result, err := op(valueU64, uint(shiftU64)) + if err != nil { + return evalValue{}, err + } + return evalValue{u64: &result}, nil +} + func evaluateXor(args []*ExprArg, mask uint64, width int, ctx *evalContext) (evalValue, error) { if len(args) != 2 { return evalValue{}, errors.New("xor expects 2 args") @@ -309,6 +494,30 @@ func measureExpr(expr *Expr, sizeCtx map[string]int) (int, error) { return 2, nil case "be32": return 4, nil + case "le16": + return 2, nil + case "le32": + return 4, nil + case "le64": + return 8, nil + case "pad": + if len(expr.GetArgs()) != 3 { + return 0, errors.New("pad expects 3 args") + } + lengthArg := expr.GetArgs()[1] + if value, ok := lengthArg.GetValue().(*ExprArg_U64); ok { + return int(value.U64), nil + } + return 0, errors.New("pad length must be u64") + case "truncate": + if len(expr.GetArgs()) != 2 { + return 0, errors.New("truncate expects 2 args") + } + lengthArg := expr.GetArgs()[1] + if value, ok := lengthArg.GetValue().(*ExprArg_U64); ok { + return int(value.U64), nil + } + return 0, errors.New("truncate length must be u64") default: return 0, errors.New("expr size is not bytes for op: ", expr.GetOp()) } @@ -317,7 +526,7 @@ func measureExpr(expr *Expr, sizeCtx map[string]int) (int, error) { func evaluateExprArg(arg *ExprArg, ctx *evalContext) (evalValue, error) { switch value := arg.GetValue().(type) { case *ExprArg_Bytes: - return evalValue{bytes: append([]byte(nil), value.Bytes...)}, nil + return evalValue{bytes: append([]byte(nil), value.Bytes...), isBytes: true}, nil case *ExprArg_U64: return evalValue{u64: &value.U64}, nil case *ExprArg_Var: @@ -325,7 +534,7 @@ func evaluateExprArg(arg *ExprArg, ctx *evalContext) (evalValue, error) { if !ok { return evalValue{}, errors.New("unknown variable: ", value.Var) } - return evalValue{bytes: append([]byte(nil), saved...)}, nil + return evalValue{bytes: append([]byte(nil), saved...), isBytes: true}, nil case *ExprArg_Metadata: metadata, ok := ctx.metadata[value.Metadata] if !ok { @@ -361,7 +570,7 @@ func measureExprArg(arg *ExprArg, sizeCtx map[string]int) (int, error) { } func (v evalValue) asBytes() ([]byte, error) { - if v.bytes != nil { + if v.isBytes { return append([]byte(nil), v.bytes...), nil } return nil, errors.New("expr value is not bytes") diff --git a/transport/internet/finalmask/header/custom/evaluator_test.go b/transport/internet/finalmask/header/custom/evaluator_test.go index 6bd34264..51a9d277 100644 --- a/transport/internet/finalmask/header/custom/evaluator_test.go +++ b/transport/internet/finalmask/header/custom/evaluator_test.go @@ -128,3 +128,364 @@ func TestEvaluatorRejectsInvalidArgType(t *testing.T) { t.Fatal("expected evaluator error") } } + +func TestEvaluatorLittleEndianProducesExpectedBytes(t *testing.T) { + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "concat", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "le16", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0x1234}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "le32", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0xA1B2C3D4}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "le64", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0x0102030405060708}}, + }, + }, + }, + }, + }, + }, + }, + } + + got, err := evaluateUDPItems(items) + if err != nil { + t.Fatal(err) + } + + want := []byte{ + 0x34, 0x12, + 0xD4, 0xC3, 0xB2, 0xA1, + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected output: %x", got) + } +} + +func TestEvaluatorPadAndTruncateShapeBytes(t *testing.T) { + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "concat", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "pad", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{0xAA, 0xBB}}}, + {Value: &ExprArg_U64{U64: 5}}, + {Value: &ExprArg_Bytes{Bytes: []byte{0xCC, 0xDD}}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "truncate", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}}, + {Value: &ExprArg_U64{U64: 2}}, + }, + }, + }, + }, + }, + }, + }, + } + + got, err := evaluateUDPItems(items) + if err != nil { + t.Fatal(err) + } + + want := []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xCC, 0x01, 0x02} + if !bytes.Equal(got, want) { + t.Fatalf("unexpected output: %x", got) + } +} + +func TestMeasureUDPItemsSupportsPadAndTruncate(t *testing.T) { + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "pad", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}}, + {Value: &ExprArg_U64{U64: 4}}, + {Value: &ExprArg_Bytes{Bytes: []byte{0x00}}}, + }, + }, + }, + { + Expr: &Expr{ + Op: "truncate", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}}, + {Value: &ExprArg_U64{U64: 3}}, + }, + }, + }, + } + + got, err := measureUDPItems(items) + if err != nil { + t.Fatal(err) + } + + if got != 7 { + t.Fatalf("unexpected size: %d", got) + } +} + +func TestEvaluatorArithmeticAndBitwiseProduceExpectedBytes(t *testing.T) { + items := []*UDPItem{ + { + Expr: &Expr{ + Op: "concat", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "add", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 2}}, + }, + }, + }, + }, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "sub", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 10}}, + {Value: &ExprArg_U64{U64: 3}}, + }, + }, + }, + }, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "and", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0xF0F0}}, + {Value: &ExprArg_U64{U64: 0x0FF0}}, + }, + }, + }, + }, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "or", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "shl", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 8}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "shr", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0x80}}, + {Value: &ExprArg_U64{U64: 7}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + got, err := evaluateUDPItems(items) + if err != nil { + t.Fatal(err) + } + + want := []byte{ + 0x00, 0x03, + 0x00, 0x07, + 0x00, 0xF0, + 0x01, 0x01, + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected output: %x", got) + } +} + +func TestEvaluatorRejectsInvalidShapingAndArithmetic(t *testing.T) { + tests := []struct { + name string + items []*UDPItem + match string + }{ + { + name: "pad with empty fill", + items: []*UDPItem{ + { + Expr: &Expr{ + Op: "pad", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}}, + {Value: &ExprArg_U64{U64: 4}}, + {Value: &ExprArg_Bytes{Bytes: []byte{}}}, + }, + }, + }, + }, + match: "pad fill", + }, + { + name: "truncate beyond source", + items: []*UDPItem{ + { + Expr: &Expr{ + Op: "truncate", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{1, 2}}}, + {Value: &ExprArg_U64{U64: 3}}, + }, + }, + }, + }, + match: "truncate", + }, + { + name: "sub underflow", + items: []*UDPItem{ + { + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "sub", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 2}}, + }, + }, + }, + }, + }, + }, + }, + }, + match: "underflow", + }, + { + name: "shift too large", + items: []*UDPItem{ + { + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "shl", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 64}}, + }, + }, + }, + }, + }, + }, + }, + }, + match: "shift", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := evaluateUDPItems(tt.items) + if err == nil { + t.Fatal("expected evaluator error") + } + if !bytes.Contains([]byte(err.Error()), []byte(tt.match)) { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/transport/internet/finalmask/header/custom/udp_runtime_test.go b/transport/internet/finalmask/header/custom/udp_runtime_test.go index c66d4c41..e0a091f4 100644 --- a/transport/internet/finalmask/header/custom/udp_runtime_test.go +++ b/transport/internet/finalmask/header/custom/udp_runtime_test.go @@ -1,6 +1,10 @@ package custom -import "testing" +import ( + "bytes" + "net" + "testing" +) func TestDSLUDPClientSizeTracksEvaluatedItems(t *testing.T) { conn, err := NewConnClientUDP(&UDPConfig{ @@ -81,3 +85,104 @@ func TestDSLUDPServerRejectsMalformedVarReference(t *testing.T) { t.Fatal("expected packet mismatch") } } + +func TestDSLUDPClientWriteSupportsExtendedExprOps(t *testing.T) { + conn, err := NewConnClientUDP(&UDPConfig{ + Client: []*UDPItem{ + { + Expr: &Expr{ + Op: "le16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "add", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 2}}, + }, + }, + }, + }, + }, + }, + }, + { + Expr: &Expr{ + Op: "pad", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{0xAA}}}, + {Value: &ExprArg_U64{U64: 3}}, + {Value: &ExprArg_Bytes{Bytes: []byte{0xBB}}}, + }, + }, + }, + { + Expr: &Expr{ + Op: "truncate", + Args: []*ExprArg{ + {Value: &ExprArg_Bytes{Bytes: []byte{1, 2, 3, 4}}}, + {Value: &ExprArg_U64{U64: 2}}, + }, + }, + }, + { + Expr: &Expr{ + Op: "be16", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "or", + Args: []*ExprArg{ + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "shl", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 1}}, + {Value: &ExprArg_U64{U64: 8}}, + }, + }, + }, + }, + { + Value: &ExprArg_Expr{ + Expr: &Expr{ + Op: "shr", + Args: []*ExprArg{ + {Value: &ExprArg_U64{U64: 0x80}}, + {Value: &ExprArg_U64{U64: 7}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, nil) + if err != nil { + t.Fatal(err) + } + + client := conn.(*udpCustomClientConn) + buf := make([]byte, client.Size()) + if _, err := client.WriteTo(buf, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}); err != nil { + t.Fatal(err) + } + + want := []byte{ + 0x03, 0x00, + 0xAA, 0xBB, 0xBB, + 0x01, 0x02, + 0x01, 0x01, + } + if !bytes.Equal(buf, want) { + t.Fatalf("unexpected encoded header: %x", buf) + } +}