feat: seamless location switch for WG

This commit is contained in:
cd-amn
2026-05-06 15:05:50 +04:00
parent 5af1edbc54
commit ce754d45e9
12 changed files with 62 additions and 11 deletions

View File

@@ -421,7 +421,23 @@ ErrorCode SubscriptionController::updateServiceFromGateway(int serverIndex, cons
} }
const bool isTestPurchase = apiV2->apiConfig.isTestPurchase; const bool isTestPurchase = apiV2->apiConfig.isTestPurchase;
QString serviceProtocol = apiV2->serviceProtocol(); QString serviceProtocol = apiV2->serviceProtocol();
ProtocolData protocolData = generateProtocolData(serviceProtocol); ProtocolData protocolData;
if (serviceProtocol == configKey::awg && !isConnectEvent) {
DockerContainer container = apiV2->defaultContainer;
ContainerConfig containerConfig = apiV2->containerConfig(container);
AwgProtocolConfig* awgConfig = containerConfig.protocolConfig.as<AwgProtocolConfig>();
if (awgConfig && awgConfig->hasClientConfig()) {
AwgClientConfig clientConfig = awgConfig->clientConfig.value();
QString clientPrivKey = clientConfig.clientPrivateKey;
QString clientPubKey = clientConfig.clientPublicKey;
protocolData.wireGuardClientPubKey = clientPubKey;
protocolData.wireGuardClientPrivKey = clientPrivKey;
} else {
protocolData = generateProtocolData(serviceProtocol);
}
} else {
protocolData = generateProtocolData(serviceProtocol);
}
QJsonObject authDataJson = apiV2->authData.toJson(); QJsonObject authDataJson = apiV2->authData.toJson();
GatewayRequestData gatewayRequestData { QSysInfo::productType(), GatewayRequestData gatewayRequestData { QSysInfo::productType(),

View File

@@ -31,6 +31,7 @@ ConnectionController::ConnectionController(SecureServersRepository* serversRepos
{ {
connect(m_vpnConnection, &VpnConnection::connectionStateChanged, this, &ConnectionController::connectionStateChanged); connect(m_vpnConnection, &VpnConnection::connectionStateChanged, this, &ConnectionController::connectionStateChanged);
connect(this, &ConnectionController::openConnectionRequested, m_vpnConnection, &VpnConnection::connectToVpn, Qt::QueuedConnection); connect(this, &ConnectionController::openConnectionRequested, m_vpnConnection, &VpnConnection::connectToVpn, Qt::QueuedConnection);
connect(this, &ConnectionController::switchConnectionRequested, m_vpnConnection, &VpnConnection::switchToVpn, Qt::QueuedConnection);
connect(this, &ConnectionController::closeConnectionRequested, m_vpnConnection, &VpnConnection::disconnectFromVpn, Qt::QueuedConnection); connect(this, &ConnectionController::closeConnectionRequested, m_vpnConnection, &VpnConnection::disconnectFromVpn, Qt::QueuedConnection);
connect(this, &ConnectionController::setConnectionStateRequested, m_vpnConnection, &VpnConnection::setConnectionState, Qt::QueuedConnection); connect(this, &ConnectionController::setConnectionStateRequested, m_vpnConnection, &VpnConnection::setConnectionState, Qt::QueuedConnection);
connect(this, &ConnectionController::killSwitchModeChangedRequested, m_vpnConnection, &VpnConnection::onKillSwitchModeChanged, Qt::QueuedConnection); connect(this, &ConnectionController::killSwitchModeChangedRequested, m_vpnConnection, &VpnConnection::onKillSwitchModeChanged, Qt::QueuedConnection);
@@ -86,8 +87,11 @@ ErrorCode ConnectionController::openConnection(int serverIndex)
if (errorCode != ErrorCode::NoError) { if (errorCode != ErrorCode::NoError) {
return errorCode; return errorCode;
} }
if (isConnected()) {
emit openConnectionRequested(serverIndex, container, vpnConfiguration); emit switchConnectionRequested(serverIndex, container, vpnConfiguration);
} else {
emit openConnectionRequested(serverIndex, container, vpnConfiguration);
}
return ErrorCode::NoError; return ErrorCode::NoError;
} }

View File

@@ -61,6 +61,7 @@ public:
signals: signals:
void connectionStateChanged(Vpn::ConnectionState state); void connectionStateChanged(Vpn::ConnectionState state);
void openConnectionRequested(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration); void openConnectionRequested(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration);
void switchConnectionRequested(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration);
void closeConnectionRequested(); void closeConnectionRequested();
void setConnectionStateRequested(Vpn::ConnectionState state); void setConnectionStateRequested(Vpn::ConnectionState state);
void killSwitchModeChangedRequested(bool enabled); void killSwitchModeChangedRequested(bool enabled);

View File

@@ -7,6 +7,8 @@
#include <QJsonObject> #include <QJsonObject>
#include <QSysInfo> #include <QSysInfo>
#include <QTimer> #include <QTimer>
#include <QStandardPaths>
#include <QTemporaryDir>
#include "amneziaApplication.h" #include "amneziaApplication.h"
#include "logger.h" #include "logger.h"

View File

@@ -160,3 +160,10 @@ bool VpnProtocol::isDisconnected() const
{ {
return m_connectionState == Vpn::ConnectionState::Disconnected; return m_connectionState == Vpn::ConnectionState::Disconnected;
} }
ErrorCode VpnProtocol::switchServer(const QJsonObject &newConfig)
{
stop();
m_rawConfig = newConfig;
return start();
}

View File

@@ -60,6 +60,7 @@ public:
virtual bool isDisconnected() const; virtual bool isDisconnected() const;
virtual ErrorCode start() = 0; virtual ErrorCode start() = 0;
virtual void stop() = 0; virtual void stop() = 0;
virtual ErrorCode switchServer(const QJsonObject &newConfig);
Vpn::ConnectionState connectionState() const; Vpn::ConnectionState connectionState() const;
ErrorCode lastError() const; ErrorCode lastError() const;

View File

@@ -79,3 +79,9 @@ ErrorCode WireguardProtocol::start()
{ {
return startMzImpl(); return startMzImpl();
} }
ErrorCode WireguardProtocol::switchServer(const QJsonObject &newConfig)
{
m_rawConfig = newConfig;
return startMzImpl();
}

View File

@@ -24,6 +24,7 @@ public:
ErrorCode startMzImpl(); ErrorCode startMzImpl();
ErrorCode stopMzImpl(); ErrorCode stopMzImpl();
ErrorCode switchServer(const QJsonObject &newConfig);
private: private:

View File

@@ -501,9 +501,7 @@ bool Daemon::supportServerSwitching(const InterfaceConfig& config) const {
return current.m_privateKey == config.m_privateKey && return current.m_privateKey == config.m_privateKey &&
current.m_deviceIpv4Address == config.m_deviceIpv4Address && current.m_deviceIpv4Address == config.m_deviceIpv4Address &&
current.m_deviceIpv6Address == config.m_deviceIpv6Address && current.m_deviceIpv6Address == config.m_deviceIpv6Address;
current.m_serverIpv4Gateway == config.m_serverIpv4Gateway &&
current.m_serverIpv6Gateway == config.m_serverIpv6Gateway;
} }
bool Daemon::switchServer(const InterfaceConfig& config) { bool Daemon::switchServer(const InterfaceConfig& config) {

View File

@@ -184,24 +184,23 @@ PageType {
imageSource: "qrc:/images/controls/download.svg" imageSource: "qrc:/images/controls/download.svg"
checked: index === ApiCountryModel.currentIndex checked: index === ApiCountryModel.currentIndex
checkable: !ConnectionController.isConnected checkable: !ConnectionController.isConnectionInProgress
onClicked: { onClicked: {
if (ConnectionController.isConnectionInProgress) { if (ConnectionController.isConnectionInProgress) {
PageController.showNotificationMessage(qsTr("Unable change server location while trying to make an active connection")) PageController.showNotificationMessage(qsTr("Unable change server location while trying to make an active connection"))
return return
} }
if (ConnectionController.isConnected) {
PageController.showNotificationMessage(qsTr("Unable change server location while there is an active connection"))
return
}
if (index !== ApiCountryModel.currentIndex) { if (index !== ApiCountryModel.currentIndex) {
PageController.showBusyIndicator(true) PageController.showBusyIndicator(true)
var prevIndex = ApiCountryModel.currentIndex var prevIndex = ApiCountryModel.currentIndex
var wasConnected = ConnectionController.isConnected
ApiCountryModel.currentIndex = index ApiCountryModel.currentIndex = index
if (!SubscriptionUiController.updateServiceFromGateway(ServersUiController.getProcessedServerIndex(), countryCode, countryName)) { if (!SubscriptionUiController.updateServiceFromGateway(ServersUiController.getProcessedServerIndex(), countryCode, countryName)) {
ApiCountryModel.currentIndex = prevIndex ApiCountryModel.currentIndex = prevIndex
} else if (wasConnected) {
ConnectionController.openConnection()
} }
PageController.showBusyIndicator(false) PageController.showBusyIndicator(false)
} }

View File

@@ -197,6 +197,21 @@ void VpnConnection::connectToVpn(int serverIndex, DockerContainer container, con
} }
} }
void VpnConnection::switchToVpn(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration)
{
if (!m_vpnProtocol.isNull() && ContainerUtils::isAwgContainer(container) && m_connectionState == Vpn::ConnectionState::Connected) {
m_remoteAddress = NetworkUtilities::getIPAddress(vpnConfiguration.value(configKey::hostName).toString());
m_trafficGuard->allowEndpoint(m_remoteAddress);
m_vpnConfiguration = vpnConfiguration;
appendKillSwitchConfig();
appendSplitTunnelingConfig();
m_trafficGuard->setConfig(m_vpnConfiguration);
m_vpnProtocol->switchServer(m_vpnConfiguration);
} else {
connectToVpn(serverIndex, container, vpnConfiguration);
}
}
void VpnConnection::createProtocolConnections() void VpnConnection::createProtocolConnections()
{ {
connect(m_vpnProtocol.data(), &VpnProtocol::protocolError, this, &VpnConnection::vpnProtocolError); connect(m_vpnProtocol.data(), &VpnProtocol::protocolError, this, &VpnConnection::vpnProtocolError);

View File

@@ -47,6 +47,7 @@ public:
public slots: public slots:
void setRepositories(SecureServersRepository* serversRepository, SecureAppSettingsRepository* appSettingsRepository); void setRepositories(SecureServersRepository* serversRepository, SecureAppSettingsRepository* appSettingsRepository);
void connectToVpn(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration); void connectToVpn(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration);
void switchToVpn(int serverIndex, DockerContainer container, const QJsonObject &vpnConfiguration);
void reconnectToVpn(); void reconnectToVpn();
void disconnectFromVpn(); void disconnectFromVpn();