diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index ef454d21..be671b02 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -19,6 +19,8 @@ import ( "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy" + "github.com/xtls/xray-core/proxy/hysteria/account" + hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tcp" @@ -138,6 +140,13 @@ func (w *tcpWorker) Proxy() proxy.Inbound { func (w *tcpWorker) Start() error { ctx := context.Background() + + type HysteriaInboundValidator interface{ HysteriaInboundValidator() *account.Validator } + if v, ok := w.proxy.(HysteriaInboundValidator); ok { + ctx = hyCtx.ContextWithRequireDatagram(ctx, true) + ctx = hyCtx.ContextWithValidator(ctx, v.HysteriaInboundValidator()) + } + hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { go w.callback(conn) }) diff --git a/infra/conf/hysteria.go b/infra/conf/hysteria.go index f690c363..3574811c 100644 --- a/infra/conf/hysteria.go +++ b/infra/conf/hysteria.go @@ -3,7 +3,9 @@ package conf import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/proxy/hysteria" + "github.com/xtls/xray-core/proxy/hysteria/account" "google.golang.org/protobuf/proto" ) @@ -27,3 +29,33 @@ func (c *HysteriaClientConfig) Build() (proto.Message, error) { return config, nil } + +type HysteriaUserConfig struct { + Auth string `json:"auth"` + Level uint32 `json:"level"` + Email string `json:"email"` +} + +type HysteriaServerConfig struct { + Version int32 `json:"version"` + Users []*HysteriaUserConfig `json:"clients"` +} + +func (c *HysteriaServerConfig) Build() (proto.Message, error) { + config := new(hysteria.ServerConfig) + + if c.Users != nil { + for _, user := range c.Users { + account := &account.Account{ + Auth: user.Auth, + } + config.Users = append(config.Users, &protocol.User{ + Email: user.Email, + Level: user.Level, + Account: serial.ToTypedMessage(account), + }) + } + } + + return config, nil +} diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 965a05a0..b4458096 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -508,6 +508,20 @@ type UdpHop struct { Interval *Int32Range `json:"interval"` } +type Masquerade struct { + Type string `json:"type"` + + Dir string `json:"dir"` + + Url string `json:"url"` + RewriteHost bool `json:"rewriteHost"` + Insecure bool `json:"insecure"` + + Content string `json:"content"` + Headers map[string]string `json:"headers"` + StatusCode int32 `json:"statusCode"` +} + type HysteriaConfig struct { Version int32 `json:"version"` Auth string `json:"auth"` @@ -523,6 +537,10 @@ type HysteriaConfig struct { MaxIdleTimeout int64 `json:"maxIdleTimeout"` KeepAlivePeriod int64 `json:"keepAlivePeriod"` DisablePathMTUDiscovery bool `json:"disablePathMTUDiscovery"` + MaxIncomingStreams int64 `json:"maxIncomingStreams"` + + UdpIdleTimeout int64 `json:"udpIdleTimeout"` + Masquerade Masquerade `json:"masquerade"` } func (c *HysteriaConfig) Build() (proto.Message, error) { @@ -556,10 +574,10 @@ func (c *HysteriaConfig) Build() (proto.Message, error) { } if up > 0 && up < 65536 { - return nil, errors.New("Up must be at least 65536 Bps") + return nil, errors.New("Up must be at least 65536 bytes per second") } if down > 0 && down < 65536 { - return nil, errors.New("Down must be at least 65536 Bps") + return nil, errors.New("Down must be at least 65536 bytes per second") } if (inertvalMin != 0 && inertvalMin < 5) || (inertvalMax != 0 && inertvalMax < 5) { return nil, errors.New("Interval must be at least 5") @@ -583,6 +601,12 @@ func (c *HysteriaConfig) Build() (proto.Message, error) { if c.KeepAlivePeriod != 0 && (c.KeepAlivePeriod < 2 || c.KeepAlivePeriod > 60) { return nil, errors.New("KeepAlivePeriod must be between 2 and 60") } + if c.MaxIncomingStreams != 0 && c.MaxIncomingStreams < 8 { + return nil, errors.New("MaxIncomingStreams must be at least 8") + } + if c.UdpIdleTimeout != 0 && (c.UdpIdleTimeout < 2 || c.UdpIdleTimeout > 600) { + return nil, errors.New("UdpIdleTimeout must be between 2 and 600") + } config := &hysteria.Config{} config.Version = c.Version @@ -600,6 +624,16 @@ func (c *HysteriaConfig) Build() (proto.Message, error) { config.MaxIdleTimeout = c.MaxIdleTimeout config.KeepAlivePeriod = c.KeepAlivePeriod config.DisablePathMtuDiscovery = c.DisablePathMTUDiscovery + config.MaxIncomingStreams = c.MaxIncomingStreams + config.UdpIdleTimeout = c.UdpIdleTimeout + config.MasqType = c.Masquerade.Type + config.MasqFile = c.Masquerade.Dir + config.MasqUrl = c.Masquerade.Url + config.MasqUrlRewriteHost = c.Masquerade.RewriteHost + config.MasqUrlInsecure = c.Masquerade.Insecure + config.MasqString = c.Masquerade.Content + config.MasqStringHeaders = c.Masquerade.Headers + config.MasqStringStatusCode = c.Masquerade.StatusCode if config.InitStreamReceiveWindow == 0 { config.InitStreamReceiveWindow = 8388608 @@ -619,6 +653,12 @@ func (c *HysteriaConfig) Build() (proto.Message, error) { // if config.KeepAlivePeriod == 0 { // config.KeepAlivePeriod = 10 // } + if config.MaxIncomingStreams == 0 { + config.MaxIncomingStreams = 1024 + } + if config.UdpIdleTimeout == 0 { + config.UdpIdleTimeout = 60 + } return config, nil } diff --git a/infra/conf/xray.go b/infra/conf/xray.go index 39a1f763..15e4a191 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -33,6 +33,7 @@ var ( "vmess": func() interface{} { return new(VMessInboundConfig) }, "trojan": func() interface{} { return new(TrojanServerConfig) }, "wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} }, + "hysteria": func() interface{} { return new(HysteriaServerConfig) }, "tun": func() interface{} { return new(TunConfig) }, }, "protocol", "settings") diff --git a/proxy/hysteria/account/config.go b/proxy/hysteria/account/config.go new file mode 100644 index 00000000..0e50dcc9 --- /dev/null +++ b/proxy/hysteria/account/config.go @@ -0,0 +1,129 @@ +package account + +import ( + "sync" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/protocol" + + "google.golang.org/protobuf/proto" +) + +func (a *Account) AsAccount() (protocol.Account, error) { + return &MemoryAccount{ + Auth: a.Auth, + }, nil +} + +type MemoryAccount struct { + Auth string +} + +func (a *MemoryAccount) Equals(another protocol.Account) bool { + if account, ok := another.(*MemoryAccount); ok { + return a.Auth == account.Auth + } + return false +} + +func (a *MemoryAccount) ToProto() proto.Message { + return &Account{ + Auth: a.Auth, + } +} + +type Validator struct { + emails map[string]struct{} + users map[string]*protocol.MemoryUser + + mutex sync.Mutex +} + +func NewValidator() *Validator { + return &Validator{ + emails: make(map[string]struct{}), + users: make(map[string]*protocol.MemoryUser), + } +} + +func (v *Validator) Add(u *protocol.MemoryUser) error { + v.mutex.Lock() + defer v.mutex.Unlock() + + if u.Email != "" { + if _, ok := v.emails[u.Email]; ok { + return errors.New("User ", u.Email, " already exists.") + } + v.emails[u.Email] = struct{}{} + } + v.users[u.Account.(*MemoryAccount).Auth] = u + + return nil +} + +func (v *Validator) Del(email string) error { + if email == "" { + return errors.New("Email must not be empty.") + } + + v.mutex.Lock() + defer v.mutex.Unlock() + + if _, ok := v.emails[email]; !ok { + return errors.New("User ", email, " not found.") + } + delete(v.emails, email) + for key, user := range v.users { + if user.Email == email { + delete(v.users, key) + break + } + } + + return nil +} + +func (v *Validator) Get(auth string) *protocol.MemoryUser { + v.mutex.Lock() + defer v.mutex.Unlock() + + return v.users[auth] +} + +func (v *Validator) GetByEmail(email string) *protocol.MemoryUser { + if email == "" { + return nil + } + + v.mutex.Lock() + defer v.mutex.Unlock() + + if _, ok := v.emails[email]; ok { + for _, user := range v.users { + if user.Email == email { + return user + } + } + } + + return nil +} + +func (v *Validator) GetAll() []*protocol.MemoryUser { + v.mutex.Lock() + defer v.mutex.Unlock() + + var users = make([]*protocol.MemoryUser, 0, len(v.users)) + for _, user := range v.users { + users = append(users, user) + } + + return users +} + +func (v *Validator) GetCount() int64 { + v.mutex.Lock() + defer v.mutex.Unlock() + + return int64(len(v.users)) +} diff --git a/proxy/hysteria/account/config.pb.go b/proxy/hysteria/account/config.pb.go new file mode 100644 index 00000000..f48dca32 --- /dev/null +++ b/proxy/hysteria/account/config.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: proxy/hysteria/account/config.proto + +package account + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Account struct { + state protoimpl.MessageState `protogen:"open.v1"` + Auth string `protobuf:"bytes,1,opt,name=auth,proto3" json:"auth,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Account) Reset() { + *x = Account{} + mi := &file_proxy_hysteria_account_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Account) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Account) ProtoMessage() {} + +func (x *Account) ProtoReflect() protoreflect.Message { + mi := &file_proxy_hysteria_account_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Account.ProtoReflect.Descriptor instead. +func (*Account) Descriptor() ([]byte, []int) { + return file_proxy_hysteria_account_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Account) GetAuth() string { + if x != nil { + return x.Auth + } + return "" +} + +var File_proxy_hysteria_account_config_proto protoreflect.FileDescriptor + +const file_proxy_hysteria_account_config_proto_rawDesc = "" + + "\n" + + "#proxy/hysteria/account/config.proto\x12\x1bxray.proxy.hysteria.account\"\x1d\n" + + "\aAccount\x12\x12\n" + + "\x04auth\x18\x01 \x01(\tR\x04authBs\n" + + "\x1fcom.xray.proxy.hysteria.accountP\x01Z0github.com/xtls/xray-core/proxy/hysteria/account\xaa\x02\x1bXray.Proxy.Hysteria.Accountb\x06proto3" + +var ( + file_proxy_hysteria_account_config_proto_rawDescOnce sync.Once + file_proxy_hysteria_account_config_proto_rawDescData []byte +) + +func file_proxy_hysteria_account_config_proto_rawDescGZIP() []byte { + file_proxy_hysteria_account_config_proto_rawDescOnce.Do(func() { + file_proxy_hysteria_account_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_hysteria_account_config_proto_rawDesc), len(file_proxy_hysteria_account_config_proto_rawDesc))) + }) + return file_proxy_hysteria_account_config_proto_rawDescData +} + +var file_proxy_hysteria_account_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proxy_hysteria_account_config_proto_goTypes = []any{ + (*Account)(nil), // 0: xray.proxy.hysteria.account.Account +} +var file_proxy_hysteria_account_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_proxy_hysteria_account_config_proto_init() } +func file_proxy_hysteria_account_config_proto_init() { + if File_proxy_hysteria_account_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_account_config_proto_rawDesc), len(file_proxy_hysteria_account_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proxy_hysteria_account_config_proto_goTypes, + DependencyIndexes: file_proxy_hysteria_account_config_proto_depIdxs, + MessageInfos: file_proxy_hysteria_account_config_proto_msgTypes, + }.Build() + File_proxy_hysteria_account_config_proto = out.File + file_proxy_hysteria_account_config_proto_goTypes = nil + file_proxy_hysteria_account_config_proto_depIdxs = nil +} diff --git a/proxy/hysteria/account/config.proto b/proxy/hysteria/account/config.proto new file mode 100644 index 00000000..48f64e64 --- /dev/null +++ b/proxy/hysteria/account/config.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package xray.proxy.hysteria.account; +option csharp_namespace = "Xray.Proxy.Hysteria.Account"; +option go_package = "github.com/xtls/xray-core/proxy/hysteria/account"; +option java_package = "com.xray.proxy.hysteria.account"; +option java_multiple_files = true; + +message Account { + string auth = 1; +} \ No newline at end of file diff --git a/proxy/hysteria/client.go b/proxy/hysteria/client.go index 2299fe34..1dcb5cf9 100644 --- a/proxy/hysteria/client.go +++ b/proxy/hysteria/client.go @@ -135,6 +135,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { return errors.New("failed to transport all UDP request").Base(err) } + return nil } @@ -143,12 +144,14 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter reader := &UDPReader{ Reader: conn, + buf: make([]byte, MaxUDPSize), df: &Defragger{}, } if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil { return errors.New("failed to transport all UDP response").Base(err) } + return nil } @@ -178,7 +181,6 @@ type UDPWriter struct { func (w *UDPWriter) sendMsg(msg *UDPMessage) error { msgN := msg.Serialize(w.buf) if msgN < 0 { - // Message larger than buffer, silent drop return nil } _, err := w.Writer.Write(w.buf[:msgN]) @@ -192,10 +194,12 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if b == nil { break } + addr := w.addr if b.UDP != nil { addr = b.UDP.NetAddr() } + msg := &UDPMessage{ SessionID: 0, PacketID: 0, @@ -204,47 +208,58 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { Addr: addr, Data: b.Bytes(), } - if err := w.sendMsg(msg); err != nil { - var errTooLarge *quic.DatagramTooLargeError - if go_errors.As(err, &errTooLarge) { - msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 - fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize)) - for _, fMsg := range fMsgs { - err := w.sendMsg(&fMsg) - if err != nil { - b.Release() - buf.ReleaseMulti(mb) - return err - } + + err := w.sendMsg(msg) + var errTooLarge *quic.DatagramTooLargeError + if go_errors.As(err, &errTooLarge) { + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize)) + for _, fMsg := range fMsgs { + err := w.sendMsg(&fMsg) + if err != nil { + b.Release() + buf.ReleaseMulti(mb) + return err } - } else { - b.Release() - buf.ReleaseMulti(mb) - return err } + } else if err != nil { + b.Release() + buf.ReleaseMulti(mb) + return err } + b.Release() } + return nil } type UDPReader struct { - Reader io.Reader - df *Defragger + Reader io.Reader + buf []byte + df *Defragger + firstMsg *UDPMessage + firstDest *net.Destination } func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + if r.firstMsg != nil { + buffer := buf.New() + buffer.Write(r.firstMsg.Data) + buffer.UDP = r.firstDest + + r.firstMsg = nil + + return buf.MultiBuffer{buffer}, nil + } for { - b := buf.New() - _, err := b.ReadFrom(r.Reader) + n, err := r.Reader.Read(r.buf) if err != nil { - b.Release() return nil, err } - msg, err := ParseUDPMessage(b.Bytes()) + msg, err := ParseUDPMessage(r.buf[:n]) if err != nil { - b.Release() continue } @@ -253,7 +268,11 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { continue } - dest, _ := net.ParseDestination("udp:" + dfMsg.Addr) + dest, err := net.ParseDestination("udp:" + dfMsg.Addr) + if err != nil { + errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err) + continue + } buffer := buf.New() buffer.Write(dfMsg.Data) diff --git a/proxy/hysteria/config.go b/proxy/hysteria/config.go index 2650d856..1daedf03 100644 --- a/proxy/hysteria/config.go +++ b/proxy/hysteria/config.go @@ -5,6 +5,6 @@ import ( ) var ( - tcpRequestPadding = padding.Padding{Min: 64, Max: 512} - // tcpResponsePadding = padding.Padding{Min: 128, Max: 1024} + tcpRequestPadding = padding.Padding{Min: 64, Max: 512} + tcpResponsePadding = padding.Padding{Min: 128, Max: 1024} ) diff --git a/proxy/hysteria/config.pb.go b/proxy/hysteria/config.pb.go index 25438ddf..5764b78c 100644 --- a/proxy/hysteria/config.pb.go +++ b/proxy/hysteria/config.pb.go @@ -74,14 +74,60 @@ func (x *ClientConfig) GetServer() *protocol.ServerEndpoint { return nil } +type ServerConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Users []*protocol.User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ServerConfig) Reset() { + *x = ServerConfig{} + mi := &file_proxy_hysteria_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ServerConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerConfig) ProtoMessage() {} + +func (x *ServerConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_hysteria_config_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerConfig.ProtoReflect.Descriptor instead. +func (*ServerConfig) Descriptor() ([]byte, []int) { + return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{1} +} + +func (x *ServerConfig) GetUsers() []*protocol.User { + if x != nil { + return x.Users + } + return nil +} + var File_proxy_hysteria_config_proto protoreflect.FileDescriptor const file_proxy_hysteria_config_proto_rawDesc = "" + "\n" + - "\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"f\n" + + "\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"f\n" + "\fClientConfig\x12\x18\n" + "\aversion\x18\x01 \x01(\x05R\aversion\x12<\n" + - "\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" + + "\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" + + "\fServerConfig\x120\n" + + "\x05users\x18\x01 \x03(\v2\x1a.xray.common.protocol.UserR\x05usersB[\n" + "\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3" var ( @@ -96,18 +142,21 @@ func file_proxy_hysteria_config_proto_rawDescGZIP() []byte { return file_proxy_hysteria_config_proto_rawDescData } -var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_proxy_hysteria_config_proto_goTypes = []any{ (*ClientConfig)(nil), // 0: xray.proxy.hysteria.ClientConfig - (*protocol.ServerEndpoint)(nil), // 1: xray.common.protocol.ServerEndpoint + (*ServerConfig)(nil), // 1: xray.proxy.hysteria.ServerConfig + (*protocol.ServerEndpoint)(nil), // 2: xray.common.protocol.ServerEndpoint + (*protocol.User)(nil), // 3: xray.common.protocol.User } var file_proxy_hysteria_config_proto_depIdxs = []int32{ - 1, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 2, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint + 3, // 1: xray.proxy.hysteria.ServerConfig.users:type_name -> xray.common.protocol.User + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_proxy_hysteria_config_proto_init() } @@ -121,7 +170,7 @@ func file_proxy_hysteria_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc)), NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/proxy/hysteria/config.proto b/proxy/hysteria/config.proto index 2a71ac53..03fdf4e6 100644 --- a/proxy/hysteria/config.proto +++ b/proxy/hysteria/config.proto @@ -7,8 +7,13 @@ option java_package = "com.xray.proxy.hysteria"; option java_multiple_files = true; import "common/protocol/server_spec.proto"; +import "common/protocol/user.proto"; message ClientConfig { int32 version = 1; xray.common.protocol.ServerEndpoint server = 2; } + +message ServerConfig { + repeated xray.common.protocol.User users = 1; +} \ No newline at end of file diff --git a/proxy/hysteria/ctx/ctx.go b/proxy/hysteria/ctx/ctx.go index 610fa065..4e1b290c 100644 --- a/proxy/hysteria/ctx/ctx.go +++ b/proxy/hysteria/ctx/ctx.go @@ -2,12 +2,15 @@ package ctx import ( "context" + + "github.com/xtls/xray-core/proxy/hysteria/account" ) type key int const ( requireDatagram key = iota + validator ) func ContextWithRequireDatagram(ctx context.Context, udp bool) context.Context { @@ -21,3 +24,12 @@ func RequireDatagramFromContext(ctx context.Context) bool { _, ok := ctx.Value(requireDatagram).(struct{}) return ok } + +func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context { + return context.WithValue(ctx, validator, v) +} + +func ValidatorFromContext(ctx context.Context) *account.Validator { + v, _ := ctx.Value(validator).(*account.Validator) + return v +} diff --git a/proxy/hysteria/protocol.go b/proxy/hysteria/protocol.go index ee4834a0..b838d15a 100644 --- a/proxy/hysteria/protocol.go +++ b/proxy/hysteria/protocol.go @@ -11,8 +11,6 @@ import ( ) const ( - FrameTypeTCPRequest = 0x401 - // Max length values are for preventing DoS attacks MaxAddressLength = 2048 @@ -28,22 +26,49 @@ const ( ) // TCPRequest format: -// 0x401 (QUIC varint) // Address length (QUIC varint) // Address (bytes) // Padding length (QUIC varint) // Padding (bytes) +func ReadTCPRequest(r io.Reader) (string, error) { + bReader := quicvarint.NewReader(r) + addrLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if addrLen == 0 || addrLen > MaxAddressLength { + return "", errors.New("invalid address length") + } + addrBuf := make([]byte, addrLen) + _, err = io.ReadFull(r, addrBuf) + if err != nil { + return "", err + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if paddingLen > MaxPaddingLength { + return "", errors.New("invalid padding length") + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return "", err + } + } + return string(addrBuf), nil +} + func WriteTCPRequest(w io.Writer, addr string) error { padding := tcpRequestPadding.String() paddingLen := len(padding) addrLen := len(addr) - sz := int(quicvarint.Len(FrameTypeTCPRequest)) + - int(quicvarint.Len(uint64(addrLen))) + addrLen + + sz := int(quicvarint.Len(uint64(addrLen))) + addrLen + int(quicvarint.Len(uint64(paddingLen))) + paddingLen buf := make([]byte, sz) - i := varintPut(buf, FrameTypeTCPRequest) - i += varintPut(buf[i:], uint64(addrLen)) + i := varintPut(buf, uint64(addrLen)) i += copy(buf[i:], addr) i += varintPut(buf[i:], uint64(paddingLen)) copy(buf[i:], padding) @@ -96,6 +121,26 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) { return status[0] == 0, string(msgBuf), nil } +func WriteTCPResponse(w io.Writer, ok bool, msg string) error { + padding := tcpResponsePadding.String() + paddingLen := len(padding) + msgLen := len(msg) + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + if ok { + buf[0] = 0 + } else { + buf[0] = 1 + } + i := varintPut(buf[1:], uint64(msgLen)) + i += copy(buf[1+i:], msg) + i += varintPut(buf[1+i:], uint64(paddingLen)) + copy(buf[1+i:], padding) + _, err := w.Write(buf) + return err +} + // UDPMessage format: // Session ID (uint32 BE) // Packet ID (uint16 BE) diff --git a/proxy/hysteria/server.go b/proxy/hysteria/server.go new file mode 100644 index 00000000..dcb99e92 --- /dev/null +++ b/proxy/hysteria/server.go @@ -0,0 +1,198 @@ +package hysteria + +import ( + "context" + "io" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/proxy/hysteria/account" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet/hysteria" + "github.com/xtls/xray-core/transport/internet/stat" +) + +type Server struct { + config *ServerConfig + validator *account.Validator + policyManager policy.Manager +} + +func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { + validator := account.NewValidator() + for _, user := range config.Users { + u, err := user.ToMemoryUser() + if err != nil { + return nil, errors.New("failed to get hysteria user").Base(err).AtError() + } + + if err := validator.Add(u); err != nil { + return nil, errors.New("failed to add user").Base(err).AtError() + } + } + + v := core.MustFromContext(ctx) + s := &Server{ + config: config, + validator: validator, + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + } + + return s, nil +} + +func (s *Server) HysteriaInboundValidator() *account.Validator { + return s.validator +} + +func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { + return s.validator.Add(u) +} + +func (s *Server) RemoveUser(ctx context.Context, e string) error { + return s.validator.Del(e) +} + +func (s *Server) GetUser(ctx context.Context, email string) *protocol.MemoryUser { + return s.validator.GetByEmail(email) +} + +func (s *Server) GetUsers(ctx context.Context) []*protocol.MemoryUser { + return s.validator.GetAll() +} + +func (s *Server) GetUsersCount(context.Context) int64 { + return s.validator.GetCount() +} + +func (s *Server) Network() []net.Network { + return []net.Network{net.Network_TCP} +} + +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { + inbound := session.InboundFromContext(ctx) + inbound.Name = "hysteria" + inbound.CanSpliceCopy = 3 + + var useremail string + var userlevel uint32 + type User interface{ User() *protocol.MemoryUser } + if v, ok := conn.(User); ok { + inbound.User = v.User() + if inbound.User != nil { + useremail = inbound.User.Email + userlevel = inbound.User.Level + } + } + + iConn := stat.TryUnwrapStatsConn(conn) + if _, ok := iConn.(*hysteria.InterUdpConn); ok { + r := io.Reader(conn) + b := make([]byte, MaxUDPSize) + df := &Defragger{} + var firstMsg *UDPMessage + var firstDest net.Destination + + for { + n, err := r.Read(b) + if err != nil { + return err + } + + msg, err := ParseUDPMessage(b[:n]) + if err != nil { + continue + } + + dfMsg := df.Feed(msg) + if dfMsg == nil { + continue + } + + firstMsg = dfMsg + firstDest, err = net.ParseDestination("udp:" + firstMsg.Addr) + if err != nil { + errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err) + continue + } + + break + } + + reader := &UDPReader{ + Reader: r, + buf: b, + df: df, + firstMsg: firstMsg, + firstDest: &firstDest, + } + + writer := &UDPWriter{ + Writer: conn, + buf: make([]byte, MaxUDPSize), + addr: firstMsg.Addr, + } + + return dispatcher.DispatchLink(ctx, firstDest, &transport.Link{ + Reader: reader, + Writer: writer, + }) + } else { + sessionPolicy := s.policyManager.ForLevel(userlevel) + + common.Must(conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake))) + addr, err := ReadTCPRequest(conn) + if err != nil { + log.Record(&log.AccessMessage{ + From: conn.RemoteAddr(), + To: "", + Status: log.AccessRejected, + Reason: err, + }) + return errors.New("failed to create request from: ", conn.RemoteAddr()).Base(err) + } + common.Must(conn.SetReadDeadline(time.Time{})) + + dest, err := net.ParseDestination("tcp:" + addr) + if err != nil { + return err + } + ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: conn.RemoteAddr(), + To: dest, + Status: log.AccessAccepted, + Reason: "", + Email: useremail, + }) + errors.LogInfo(ctx, "tunnelling request to ", dest) + + bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) + err = WriteTCPResponse(bufferedWriter, true, "") + if err != nil { + return errors.New("failed to write response").Base(err) + } + if err := bufferedWriter.SetBuffered(false); err != nil { + return err + } + + return dispatcher.DispatchLink(ctx, dest, &transport.Link{ + Reader: buf.NewReader(conn), + Writer: bufferedWriter, + }) + } +} + +func init() { + common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + return NewServer(ctx, config.(*ServerConfig)) + })) +} diff --git a/transport/internet/hysteria/config.go b/transport/internet/hysteria/config.go index fd7a4bb4..7636983f 100644 --- a/transport/internet/hysteria/config.go +++ b/transport/internet/hysteria/config.go @@ -1,6 +1,8 @@ package hysteria import ( + "time" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria/padding" @@ -23,11 +25,15 @@ const ( StatusAuthOK = 233 udpMessageChanSize = 1024 + + FrameTypeTCPRequest = 0x401 + + idleCleanupInterval = 1 * time.Second ) var ( - authRequestPadding = padding.Padding{Min: 256, Max: 2048} - // authResponsePadding = padding.Padding{Min: 256, Max: 2048} + authRequestPadding = padding.Padding{Min: 256, Max: 2048} + authResponsePadding = padding.Padding{Min: 256, Max: 2048} ) type Status int diff --git a/transport/internet/hysteria/config.pb.go b/transport/internet/hysteria/config.pb.go index af23ec1b..913c26ff 100644 --- a/transport/internet/hysteria/config.pb.go +++ b/transport/internet/hysteria/config.pb.go @@ -38,6 +38,16 @@ type Config struct { MaxIdleTimeout int64 `protobuf:"varint,13,opt,name=max_idle_timeout,json=maxIdleTimeout,proto3" json:"max_idle_timeout,omitempty"` KeepAlivePeriod int64 `protobuf:"varint,14,opt,name=keep_alive_period,json=keepAlivePeriod,proto3" json:"keep_alive_period,omitempty"` DisablePathMtuDiscovery bool `protobuf:"varint,15,opt,name=disable_path_mtu_discovery,json=disablePathMtuDiscovery,proto3" json:"disable_path_mtu_discovery,omitempty"` + MaxIncomingStreams int64 `protobuf:"varint,16,opt,name=max_incoming_streams,json=maxIncomingStreams,proto3" json:"max_incoming_streams,omitempty"` + UdpIdleTimeout int64 `protobuf:"varint,17,opt,name=udp_idle_timeout,json=udpIdleTimeout,proto3" json:"udp_idle_timeout,omitempty"` + MasqType string `protobuf:"bytes,18,opt,name=masq_type,json=masqType,proto3" json:"masq_type,omitempty"` + MasqFile string `protobuf:"bytes,19,opt,name=masq_file,json=masqFile,proto3" json:"masq_file,omitempty"` + MasqUrl string `protobuf:"bytes,20,opt,name=masq_url,json=masqUrl,proto3" json:"masq_url,omitempty"` + MasqUrlRewriteHost bool `protobuf:"varint,21,opt,name=masq_url_rewrite_host,json=masqUrlRewriteHost,proto3" json:"masq_url_rewrite_host,omitempty"` + MasqUrlInsecure bool `protobuf:"varint,22,opt,name=masq_url_insecure,json=masqUrlInsecure,proto3" json:"masq_url_insecure,omitempty"` + MasqString string `protobuf:"bytes,23,opt,name=masq_string,json=masqString,proto3" json:"masq_string,omitempty"` + MasqStringHeaders map[string]string `protobuf:"bytes,24,rep,name=masq_string_headers,json=masqStringHeaders,proto3" json:"masq_string_headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + MasqStringStatusCode int32 `protobuf:"varint,25,opt,name=masq_string_status_code,json=masqStringStatusCode,proto3" json:"masq_string_status_code,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -177,11 +187,81 @@ func (x *Config) GetDisablePathMtuDiscovery() bool { return false } +func (x *Config) GetMaxIncomingStreams() int64 { + if x != nil { + return x.MaxIncomingStreams + } + return 0 +} + +func (x *Config) GetUdpIdleTimeout() int64 { + if x != nil { + return x.UdpIdleTimeout + } + return 0 +} + +func (x *Config) GetMasqType() string { + if x != nil { + return x.MasqType + } + return "" +} + +func (x *Config) GetMasqFile() string { + if x != nil { + return x.MasqFile + } + return "" +} + +func (x *Config) GetMasqUrl() string { + if x != nil { + return x.MasqUrl + } + return "" +} + +func (x *Config) GetMasqUrlRewriteHost() bool { + if x != nil { + return x.MasqUrlRewriteHost + } + return false +} + +func (x *Config) GetMasqUrlInsecure() bool { + if x != nil { + return x.MasqUrlInsecure + } + return false +} + +func (x *Config) GetMasqString() string { + if x != nil { + return x.MasqString + } + return "" +} + +func (x *Config) GetMasqStringHeaders() map[string]string { + if x != nil { + return x.MasqStringHeaders + } + return nil +} + +func (x *Config) GetMasqStringStatusCode() int32 { + if x != nil { + return x.MasqStringStatusCode + } + return 0 +} + var File_transport_internet_hysteria_config_proto protoreflect.FileDescriptor const file_transport_internet_hysteria_config_proto_rawDesc = "" + "\n" + - "(transport/internet/hysteria/config.proto\x12 xray.transport.internet.hysteria\"\xd1\x04\n" + + "(transport/internet/hysteria/config.proto\x12 xray.transport.internet.hysteria\"\xf0\b\n" + "\x06Config\x12\x18\n" + "\aversion\x18\x01 \x01(\x05R\aversion\x12\x12\n" + "\x04auth\x18\x02 \x01(\tR\x04auth\x12\x1e\n" + @@ -200,7 +280,21 @@ const file_transport_internet_hysteria_config_proto_rawDesc = "" + "\x17max_conn_receive_window\x18\f \x01(\x04R\x14maxConnReceiveWindow\x12(\n" + "\x10max_idle_timeout\x18\r \x01(\x03R\x0emaxIdleTimeout\x12*\n" + "\x11keep_alive_period\x18\x0e \x01(\x03R\x0fkeepAlivePeriod\x12;\n" + - "\x1adisable_path_mtu_discovery\x18\x0f \x01(\bR\x17disablePathMtuDiscoveryB\x82\x01\n" + + "\x1adisable_path_mtu_discovery\x18\x0f \x01(\bR\x17disablePathMtuDiscovery\x120\n" + + "\x14max_incoming_streams\x18\x10 \x01(\x03R\x12maxIncomingStreams\x12(\n" + + "\x10udp_idle_timeout\x18\x11 \x01(\x03R\x0eudpIdleTimeout\x12\x1b\n" + + "\tmasq_type\x18\x12 \x01(\tR\bmasqType\x12\x1b\n" + + "\tmasq_file\x18\x13 \x01(\tR\bmasqFile\x12\x19\n" + + "\bmasq_url\x18\x14 \x01(\tR\amasqUrl\x121\n" + + "\x15masq_url_rewrite_host\x18\x15 \x01(\bR\x12masqUrlRewriteHost\x12*\n" + + "\x11masq_url_insecure\x18\x16 \x01(\bR\x0fmasqUrlInsecure\x12\x1f\n" + + "\vmasq_string\x18\x17 \x01(\tR\n" + + "masqString\x12o\n" + + "\x13masq_string_headers\x18\x18 \x03(\v2?.xray.transport.internet.hysteria.Config.MasqStringHeadersEntryR\x11masqStringHeaders\x125\n" + + "\x17masq_string_status_code\x18\x19 \x01(\x05R\x14masqStringStatusCode\x1aD\n" + + "\x16MasqStringHeadersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01B\x82\x01\n" + "$com.xray.transport.internet.hysteriaP\x01Z5github.com/xtls/xray-core/transport/internet/hysteria\xaa\x02 Xray.Transport.Internet.Hysteriab\x06proto3" var ( @@ -215,16 +309,18 @@ func file_transport_internet_hysteria_config_proto_rawDescGZIP() []byte { return file_transport_internet_hysteria_config_proto_rawDescData } -var file_transport_internet_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_transport_internet_hysteria_config_proto_goTypes = []any{ (*Config)(nil), // 0: xray.transport.internet.hysteria.Config + nil, // 1: xray.transport.internet.hysteria.Config.MasqStringHeadersEntry } var file_transport_internet_hysteria_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 1, // 0: xray.transport.internet.hysteria.Config.masq_string_headers:type_name -> xray.transport.internet.hysteria.Config.MasqStringHeadersEntry + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_transport_internet_hysteria_config_proto_init() } @@ -238,7 +334,7 @@ func file_transport_internet_hysteria_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_hysteria_config_proto_rawDesc), len(file_transport_internet_hysteria_config_proto_rawDesc)), NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/transport/internet/hysteria/config.proto b/transport/internet/hysteria/config.proto index 09f97298..fbc919cc 100644 --- a/transport/internet/hysteria/config.proto +++ b/transport/internet/hysteria/config.proto @@ -23,5 +23,15 @@ message Config { int64 max_idle_timeout = 13; int64 keep_alive_period = 14; bool disable_path_mtu_discovery = 15; -} + int64 max_incoming_streams = 16; + int64 udp_idle_timeout = 17; + string masq_type = 18; + string masq_file = 19; + string masq_url = 20; + bool masq_url_rewrite_host = 21; + bool masq_url_insecure = 22; + string masq_string = 23; + map masq_string_headers = 24; + int32 masq_string_status_code = 25; +} \ No newline at end of file diff --git a/transport/internet/hysteria/conn.go b/transport/internet/hysteria/conn.go index ffc41a54..be4b0f59 100644 --- a/transport/internet/hysteria/conn.go +++ b/transport/internet/hysteria/conn.go @@ -3,16 +3,28 @@ package hysteria import ( "encoding/binary" "io" + "sync" "time" "github.com/apernet/quic-go" + "github.com/apernet/quic-go/quicvarint" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" ) type interConn struct { stream *quic.Stream local net.Addr remote net.Addr + + client bool + mutex sync.Mutex + + user *protocol.MemoryUser +} + +func (i *interConn) User() *protocol.MemoryUser { + return i.user } func (i *interConn) Read(b []byte) (int, error) { @@ -20,6 +32,22 @@ func (i *interConn) Read(b []byte) (int, error) { } func (i *interConn) Write(b []byte) (int, error) { + if i.client { + i.mutex.Lock() + if i.client { + buf := make([]byte, 0, quicvarint.Len(FrameTypeTCPRequest)+len(b)) + buf = quicvarint.Append(buf, FrameTypeTCPRequest) + buf = append(buf, b...) + _, err := i.stream.Write(buf) + if err != nil { + return 0, err + } + i.client = false + return len(b), nil + } + i.mutex.Unlock() + } + return i.stream.Write(b) } @@ -53,10 +81,34 @@ type InterUdpConn struct { local net.Addr remote net.Addr - id uint32 - ch chan []byte + id uint32 + ch chan []byte + closed bool closeFunc func() + + last time.Time + mutex sync.Mutex + + user *protocol.MemoryUser +} + +func (i *InterUdpConn) User() *protocol.MemoryUser { + return i.user +} + +func (i *InterUdpConn) SetLast() { + i.mutex.Lock() + defer i.mutex.Unlock() + + i.last = time.Now() +} + +func (i *InterUdpConn) GetLast() time.Time { + i.mutex.Lock() + defer i.mutex.Unlock() + + return i.last } func (i *InterUdpConn) Read(p []byte) (int, error) { @@ -68,10 +120,14 @@ func (i *InterUdpConn) Read(p []byte) (int, error) { if n != len(b) { return 0, io.ErrShortBuffer } + + i.SetLast() return n, nil } func (i *InterUdpConn) Write(p []byte) (int, error) { + i.SetLast() + binary.BigEndian.PutUint32(p, i.id) if err := i.conn.SendDatagram(p); err != nil { return 0, err diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go index cb801aa4..c1e8d150 100644 --- a/transport/internet/hysteria/dialer.go +++ b/transport/internet/hysteria/dialer.go @@ -26,15 +26,23 @@ import ( "github.com/xtls/xray-core/transport/internet/tls" ) -type udpSessionManager struct { +type udpSessionManagerClient struct { conn *quic.Conn m map[uint32]*InterUdpConn - nextId uint32 + next uint32 closed bool mutex sync.RWMutex } -func (m *udpSessionManager) run() { +func (m *udpSessionManagerClient) close(udpConn *InterUdpConn) { + if !udpConn.closed { + udpConn.closed = true + close(udpConn.ch) + delete(m.m, udpConn.id) + } +} + +func (m *udpSessionManagerClient) run() { for { d, err := m.conn.ReceiveDatagram(context.Background()) if err != nil { @@ -44,29 +52,22 @@ func (m *udpSessionManager) run() { if len(d) < 4 { continue } - sessionId := binary.BigEndian.Uint32(d[:4]) + id := binary.BigEndian.Uint32(d[:4]) - m.feed(sessionId, d) + m.feed(id, d) } m.mutex.Lock() defer m.mutex.Unlock() m.closed = true + for _, udpConn := range m.m { m.close(udpConn) } } -func (m *udpSessionManager) close(udpConn *InterUdpConn) { - if !udpConn.closed { - udpConn.closed = true - close(udpConn.ch) - delete(m.m, udpConn.id) - } -} - -func (m *udpSessionManager) udp() (*InterUdpConn, error) { +func (m *udpSessionManagerClient) udp() (*InterUdpConn, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -79,7 +80,7 @@ func (m *udpSessionManager) udp() (*InterUdpConn, error) { local: m.conn.LocalAddr(), remote: m.conn.RemoteAddr(), - id: m.nextId, + id: m.next, ch: make(chan []byte, udpMessageChanSize), } udpConn.closeFunc = func() { @@ -87,17 +88,17 @@ func (m *udpSessionManager) udp() (*InterUdpConn, error) { defer m.mutex.Unlock() m.close(udpConn) } - m.m[m.nextId] = udpConn - m.nextId++ + m.m[m.next] = udpConn + m.next++ return udpConn, nil } -func (m *udpSessionManager) feed(sessionId uint32, d []byte) { +func (m *udpSessionManagerClient) feed(id uint32, d []byte) { m.mutex.RLock() defer m.mutex.RUnlock() - udpConn, ok := m.m[sessionId] + udpConn, ok := m.m[id] if !ok { return } @@ -117,7 +118,7 @@ type client struct { tlsConfig *go_tls.Config socketConfig *internet.SocketConfig udpmaskManager *finalmask.UdpmaskManager - udpSM *udpSessionManager + udpSM *udpSessionManagerClient mutex sync.Mutex } @@ -269,10 +270,10 @@ func (c *client) dial() error { c.pktConn = pktConn c.conn = quicConn if serverUdp { - c.udpSM = &udpSessionManager{ - conn: quicConn, - m: make(map[uint32]*InterUdpConn), - nextId: 1, + c.udpSM = &udpSessionManagerClient{ + conn: quicConn, + m: make(map[uint32]*InterUdpConn), + next: 1, } go c.udpSM.run() } @@ -307,6 +308,8 @@ func (c *client) tcp() (stat.Connection, error) { stream: stream, local: c.conn.LocalAddr(), remote: c.conn.RemoteAddr(), + + client: true, }, nil } diff --git a/transport/internet/hysteria/hub.go b/transport/internet/hysteria/hub.go new file mode 100644 index 00000000..32a4def9 --- /dev/null +++ b/transport/internet/hysteria/hub.go @@ -0,0 +1,412 @@ +package hysteria + +import ( + "context" + gotls "crypto/tls" + "encoding/binary" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/apernet/quic-go" + "github.com/apernet/quic-go/http3" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/proxy/hysteria/account" + hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/tls" +) + +type udpSessionManagerServer struct { + conn *quic.Conn + m map[uint32]*InterUdpConn + addConn internet.ConnHandler + stopCh chan struct{} + udpIdleTimeout time.Duration + mutex sync.RWMutex + + user *protocol.MemoryUser +} + +func (m *udpSessionManagerServer) close(udpConn *InterUdpConn) { + if !udpConn.closed { + udpConn.closed = true + close(udpConn.ch) + delete(m.m, udpConn.id) + } +} + +func (m *udpSessionManagerServer) clean() { + ticker := time.NewTicker(idleCleanupInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + m.mutex.RLock() + now := time.Now() + timeoutConn := make([]*InterUdpConn, 0, len(m.m)) + for _, udpConn := range m.m { + if now.Sub(udpConn.GetLast()) > m.udpIdleTimeout { + timeoutConn = append(timeoutConn, udpConn) + } + } + m.mutex.RUnlock() + + for _, udpConn := range timeoutConn { + m.mutex.Lock() + m.close(udpConn) + m.mutex.Unlock() + } + case <-m.stopCh: + return + } + } +} + +func (m *udpSessionManagerServer) run() { + for { + d, err := m.conn.ReceiveDatagram(context.Background()) + if err != nil { + break + } + + if len(d) < 4 { + continue + } + id := binary.BigEndian.Uint32(d[:4]) + + m.feed(id, d) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + close(m.stopCh) + + for _, udpConn := range m.m { + m.close(udpConn) + } +} + +func (m *udpSessionManagerServer) feed(id uint32, d []byte) { + m.mutex.RLock() + udpConn, ok := m.m[id] + m.mutex.RUnlock() + + if !ok { + m.mutex.Lock() + udpConn, ok = m.m[id] + if !ok { + udpConn = &InterUdpConn{ + conn: m.conn, + local: m.conn.LocalAddr(), + remote: m.conn.RemoteAddr(), + + id: id, + ch: make(chan []byte, udpMessageChanSize), + last: time.Now(), + + user: m.user, + } + udpConn.closeFunc = func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.close(udpConn) + } + m.m[id] = udpConn + m.addConn(udpConn) + } + m.mutex.Unlock() + } + + select { + case udpConn.ch <- d: + default: + } +} + +type httpHandler struct { + ctx context.Context + conn *quic.Conn + addConn internet.ConnHandler + + config *Config + validator *account.Validator + masqHandler http.Handler + + auth bool + mutex sync.Mutex + user *protocol.MemoryUser +} + +func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.Host == URLHost && r.URL.Path == URLPath { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.auth { + w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) + w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.config.Down, 10)) + w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) + w.WriteHeader(StatusAuthOK) + return + } + + auth := r.Header.Get(RequestHeaderAuth) + clientDown, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64) + + var user *protocol.MemoryUser + var ok bool + if h.validator != nil { + user = h.validator.Get(auth) + } else if auth == h.config.Auth { + ok = true + } + + if user != nil || ok { + h.auth = true + h.user = user + + switch h.config.Congestion { + case "reno": + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") + case "bbr": + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr") + congestion.UseBBR(h.conn) + case "brutal", "": + if h.config.Up == 0 || clientDown == 0 { + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr") + congestion.UseBBR(h.conn) + } else { + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", min(h.config.Up, clientDown)) + congestion.UseBrutal(h.conn, min(h.config.Up, clientDown)) + } + case "force-brutal": + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", h.config.Up) + congestion.UseBrutal(h.conn, h.config.Up) + default: + errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno") + } + + if hyCtx.RequireDatagramFromContext(h.ctx) { + udpSM := &udpSessionManagerServer{ + conn: h.conn, + m: make(map[uint32]*InterUdpConn), + addConn: h.addConn, + stopCh: make(chan struct{}), + udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second, + + user: h.user, + } + go udpSM.clean() + go udpSM.run() + } + + w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx))) + w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.config.Down, 10)) + w.Header().Set(CommonHeaderPadding, authResponsePadding.String()) + w.WriteHeader(StatusAuthOK) + return + } + } + + h.masqHandler.ServeHTTP(w, r) +} + +func (h *httpHandler) ProxyStreamHijacker(ft http3.FrameType, id quic.ConnectionTracingID, stream *quic.Stream, err error) (bool, error) { + if err != nil || !h.auth { + return false, nil + } + + switch ft { + case FrameTypeTCPRequest: + h.addConn(&interConn{ + stream: stream, + local: h.conn.LocalAddr(), + remote: h.conn.RemoteAddr(), + + user: h.user, + }) + return true, nil + default: + return false, nil + } +} + +type Listener struct { + ctx context.Context + pktConn net.PacketConn + listener *quic.Listener + addConn internet.ConnHandler + + config *Config + validator *account.Validator + masqHandler http.Handler +} + +func (l *Listener) handleClient(conn *quic.Conn) { + handler := &httpHandler{ + ctx: l.ctx, + conn: conn, + addConn: l.addConn, + + config: l.config, + validator: l.validator, + masqHandler: l.masqHandler, + } + h3 := http3.Server{ + Handler: handler, + StreamHijacker: handler.ProxyStreamHijacker, + } + err := h3.ServeQUICConn(conn) + errors.LogDebug(context.Background(), conn.RemoteAddr(), " disconnected with err ", err) + _ = conn.CloseWithError(closeErrCodeOK, "") +} + +func (l *Listener) keepAccepting() { + for { + conn, err := l.listener.Accept(context.Background()) + if err != nil { + errors.LogInfoInner(context.Background(), err, "failed to accept QUIC connection") + break + } + go l.handleClient(conn) + } +} + +func (l *Listener) Addr() net.Addr { + return l.listener.Addr() +} + +func (l *Listener) Close() error { + err := l.listener.Close() + _ = l.pktConn.Close() + return err +} + +func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { + if address.Family().IsDomain() { + return nil, errors.New("address is domain") + } + + tlsConfig := tls.ConfigFromStreamSettings(streamSettings) + if tlsConfig == nil { + return nil, errors.New("tls config is nil") + } + + config := streamSettings.ProtocolSettings.(*Config) + + validator := hyCtx.ValidatorFromContext(ctx) + + if config.Auth == "" && validator == nil { + return nil, errors.New("validator is nil") + } + + var masqHandler http.Handler + switch strings.ToLower(config.MasqType) { + case "", "404": + masqHandler = http.NotFoundHandler() + case "file": + masqHandler = http.FileServer(http.Dir(config.MasqFile)) + case "proxy": + u, err := url.Parse(config.MasqUrl) + if err != nil { + return nil, err + } + transport := http.DefaultTransport.(*http.Transport) + if config.MasqUrlInsecure { + transport = transport.Clone() + transport.TLSClientConfig = &gotls.Config{ + InsecureSkipVerify: true, + } + } + masqHandler = &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(u) + if !config.MasqUrlRewriteHost { + pr.Out.Host = pr.In.Host + } + }, + Transport: transport, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + }, + } + case "string": + masqHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range config.MasqStringHeaders { + w.Header().Set(k, v) + } + if config.MasqStringStatusCode != 0 { + w.WriteHeader(int(config.MasqStringStatusCode)) + } else { + w.WriteHeader(http.StatusOK) + } + _, _ = w.Write([]byte(config.MasqString)) + }) + default: + return nil, errors.New("unknown masq type") + } + + raw, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings) + if err != nil { + return nil, err + } + + var pktConn net.PacketConn + pktConn = raw + + if streamSettings.UdpmaskManager != nil { + pktConn, err = streamSettings.UdpmaskManager.WrapPacketConnServer(raw) + if err != nil { + raw.Close() + return nil, errors.New("mask err").Base(err) + } + } + + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: config.InitStreamReceiveWindow, + MaxStreamReceiveWindow: config.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: config.InitConnReceiveWindow, + MaxConnectionReceiveWindow: config.MaxConnReceiveWindow, + MaxIdleTimeout: time.Duration(config.MaxIdleTimeout) * time.Second, + MaxIncomingStreams: config.MaxIncomingStreams, + DisablePathMTUDiscovery: config.DisablePathMtuDiscovery, + EnableDatagrams: true, + MaxDatagramFrameSize: MaxDatagramFrameSize, + DisablePathManager: true, + } + + qListener, err := quic.Listen(pktConn, tlsConfig.GetTLSConfig(), quicConfig) + if err != nil { + _ = pktConn.Close() + return nil, err + } + + listener := &Listener{ + ctx: ctx, + pktConn: pktConn, + listener: qListener, + addConn: handler, + + config: config, + validator: validator, + masqHandler: masqHandler, + } + + go listener.keepAccepting() + + return listener, nil +} + +func init() { + common.Must(internet.RegisterTransportListener(protocolName, Listen)) +}