aboutsummaryrefslogtreecommitdiffstats
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/train.py b/train.py
index a9c865b..8b15547 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,7 @@ THETAS_FILE = "thetas.csv"
def normalize(data):
min_val = min(data)
max_val = max(data)
- return [(x - min_val) / (max_val - min_val) for x in data], min_val, max_val
+ return [(x - min_val) / (max_val - min_val) for x in data]
def load_data():
@@ -28,7 +28,7 @@ def estimate_price(mileage, theta0, theta1):
# DV: dependant variable, IV: independant variable
-def train_once(learning_rate, DV, IV, theta0, theta1):
+def compute_gradients(learning_rate, DV, IV, theta0, theta1):
tmp0 = (
learning_rate
* (1.0 / len(DV))
@@ -52,12 +52,14 @@ def denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max):
def train(learning_rate, iterations):
kms, prices = load_data()
- kms_norm, km_min, km_max = normalize(kms)
- prices_norm, price_min, price_max = normalize(prices)
+ km_min, km_max = min(kms), max(kms)
+ price_min, price_max = min(prices), max(prices)
+ kms_norm = normalize(kms)
+ prices_norm = normalize(prices)
t0 = 0.0
t1 = 0.0
for _ in range(iterations):
- grad0, grad1 = train_once(learning_rate, prices_norm, kms_norm, t0, t1)
+ grad0, grad1 = compute_gradients(learning_rate, prices_norm, kms_norm, t0, t1)
t0 -= grad0
t1 -= grad1
return denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max)
/**
 * @file   BobinkClient.cpp
 * @brief  BobinkClient implementation.
 */
#include "BobinkClient.h"
#include "BobinkAuth.h"

#include <QDir>
#include <QOpcUaUserTokenPolicy>
#include <QStandardPaths>

/** @brief Singleton created by create() and returned by instance(); reset to nullptr by ~BobinkClient(). */
BobinkClient *BobinkClient::s_instance = nullptr;

static QString defaultPkiDir()
{
    return QStandardPaths::writableLocation(QStandardPaths::AppDataLocation)
           + QStringLiteral("/pki");
}

/** @brief Create the standard OPC UA PKI directory tree. */
static void ensurePkiDirs(const QString &base)
{
    for (const auto *sub : {"own/certs", "own/private",
                            "trusted/certs", "trusted/crl",
                            "issuers/certs", "issuers/crl"}) {
        QDir().mkpath(base + QLatin1Char('/') + QLatin1String(sub));
    }
}

/**
 * @brief Build the client: prepare the PKI directory tree, create the
 *        open62541 backend, pick up any certificate/key already present in
 *        the PKI tree and hand the resulting PKI configuration to the backend.
 */
BobinkClient::BobinkClient(QObject *parent)
    : QObject(parent)
    , m_provider(new QOpcUaProvider(this))
    , m_pkiDir(defaultPkiDir())
{
    // Each timeout triggers one FindServers request (see startDiscovery()).
    connect(&m_discoveryTimer, &QTimer::timeout, this, &BobinkClient::doDiscovery);

    ensurePkiDirs(m_pkiDir);
    setupClient();   // creates m_client, which applyPki() configures
    autoDetectPki(); // may fill m_certFile / m_keyFile from the PKI tree
    applyPki();
}

/** @brief Clear the global singleton pointer if it refers to this object. */
BobinkClient::~BobinkClient()
{
    // Any other instance leaves the registered singleton alone.
    if (s_instance != this)
        return;
    s_instance = nullptr;
}

/** @brief The singleton created by create(), or nullptr if there is none. */
BobinkClient *BobinkClient::instance()
{
    return BobinkClient::s_instance;
}

/**
 * @brief QML singleton factory: lazily creates the one BobinkClient and
 *        returns it on every call. Both engine arguments are unused.
 */
BobinkClient *BobinkClient::create(QQmlEngine *, QJSEngine *)
{
    if (s_instance)
        return s_instance;

    s_instance = new BobinkClient;
    // Mark the object as C++-owned so the JS engine does not take ownership of it.
    QJSEngine::setObjectOwnership(s_instance, QJSEngine::CppOwnership);
    return s_instance;
}

void BobinkClient::setupClient()
{
    m_client = m_provider->createClient(QStringLiteral("open62541"));
    if (!m_client) {
        qWarning() << "BobinkClient: failed to create open62541 backend";
        return;
    }

    connect(m_client, &QOpcUaClient::stateChanged,
            this, &BobinkClient::handleStateChanged);
    connect(m_client, &QOpcUaClient::endpointsRequestFinished,
            this, &BobinkClient::handleEndpointsReceived);
    connect(m_client, &QOpcUaClient::connectError,
            this, &BobinkClient::handleConnectError);
    connect(m_client, &QOpcUaClient::findServersFinished,
            this, &BobinkClient::handleFindServersFinished);
}

/* ======================================
 *  Connection properties
 * ====================================== */

/** @brief Whether the backend last reported QOpcUaClient::Connected. */
bool BobinkClient::connected() const
{
    return m_connected;
}

/** @brief Server URL used by connectToServer() and connectDirect(). */
QString BobinkClient::serverUrl() const
{
    return m_serverUrl;
}

/** @brief Set the server URL; serverUrlChanged() fires only on an actual change. */
void BobinkClient::setServerUrl(const QString &url)
{
    if (m_serverUrl != url) {
        m_serverUrl = url;
        emit serverUrlChanged();
    }
}

/** @brief Authentication settings applied when connecting (may be null). */
BobinkAuth *BobinkClient::auth() const
{
    return m_auth;
}

/**
 * @brief Set the authentication object; authChanged() fires only on change.
 *
 * NOTE(review): the pointer is stored without taking ownership or tracking
 * destruction; confirm the BobinkAuth object outlives its use here.
 */
void BobinkClient::setAuth(BobinkAuth *auth)
{
    if (m_auth != auth) {
        m_auth = auth;
        emit authChanged();
    }
}

/** @brief Underlying QtOpcUa client, or nullptr if the backend could not be created. */
QOpcUaClient *BobinkClient::opcuaClient() const
{
    return m_client;
}

/* ======================================
 *  Connection methods
 * ====================================== */

void BobinkClient::connectToServer()
{
    if (!m_client) {
        emit connectionError(QStringLiteral("OPC UA backend not available"));
        return;
    }
    if (m_serverUrl.isEmpty()) {
        emit connectionError(QStringLiteral("No server URL set"));
        return;
    }
    if (m_client->state() != QOpcUaClient::Disconnected) {
        emit connectionError(QStringLiteral("Already connected or connecting"));
        return;
    }

    QUrl url(m_serverUrl);
    if (!url.isValid()) {
        emit connectionError(QStringLiteral("Invalid server URL: %1").arg(m_serverUrl));
        return;
    }
    m_client->requestEndpoints(url);
}

/**
 * @brief Translate a BobinkClient::SecurityPolicy value into its OPC UA
 *        security policy URI.
 * @return The URI, or an empty string for a value outside the enum.
 */
static QString securityPolicyUri(BobinkClient::SecurityPolicy policy)
{
    QString uri;
    switch (policy) {
    case BobinkClient::Basic256Sha256:
        uri = QStringLiteral(
            "http://opcfoundation.org/UA/SecurityPolicy#Basic256Sha256");
        break;
    case BobinkClient::Aes128_Sha256_RsaOaep:
        uri = QStringLiteral(
            "http://opcfoundation.org/UA/SecurityPolicy#Aes128_Sha256_RsaOaep");
        break;
    case BobinkClient::Aes256_Sha256_RsaPss:
        uri = QStringLiteral(
            "http://opcfoundation.org/UA/SecurityPolicy#Aes256_Sha256_RsaPss");
        break;
    }
    return uri;
}

/**
 * @brief Connect straight to m_serverUrl with a caller-chosen security
 *        policy and mode, without a GetEndpoints request.
 *
 * The endpoint description is built locally. Its single user-token policy
 * is derived from the BobinkAuth mode (Anonymous when no auth object is set).
 * Emits connectionError() and returns without connecting when:
 *  - the backend is missing,
 *  - no URL is set,
 *  - the client is not Disconnected,
 *  - the URL is invalid, or
 *  - @p policy has no known URI.
 */
void BobinkClient::connectDirect(SecurityPolicy policy, SecurityMode mode)
{
    if (!m_client) {
        emit connectionError(QStringLiteral("OPC UA backend not available"));
        return;
    }
    if (m_serverUrl.isEmpty()) {
        emit connectionError(QStringLiteral("No server URL set"));
        return;
    }
    if (m_client->state() != QOpcUaClient::Disconnected) {
        emit connectionError(QStringLiteral("Already connected or connecting"));
        return;
    }
    // Validate the URL the same way connectToServer() does, so both entry
    // points reject malformed input up front with the same message.
    if (!QUrl(m_serverUrl).isValid()) {
        emit connectionError(QStringLiteral("Invalid server URL: %1").arg(m_serverUrl));
        return;
    }

    // securityPolicyUri() returns an empty string for values outside the enum
    // (e.g. an out-of-range int coming from QML). Refuse those instead of
    // handing the backend an endpoint with no security policy URI.
    const QString policyUri = securityPolicyUri(policy);
    if (policyUri.isEmpty()) {
        emit connectionError(QStringLiteral("Unsupported security policy"));
        return;
    }

    QOpcUaEndpointDescription endpoint;
    endpoint.setEndpointUrl(m_serverUrl);
    endpoint.setSecurityPolicy(policyUri);
    // NOTE(review): this cast assumes SecurityMode's numeric values mirror
    // QOpcUaEndpointDescription::MessageSecurityMode; confirm in the header.
    endpoint.setSecurityMode(
        static_cast<QOpcUaEndpointDescription::MessageSecurityMode>(mode));

    // Offer exactly one user-token policy, matching the configured auth mode.
    QOpcUaUserTokenPolicy tokenPolicy;
    if (m_auth) {
        switch (m_auth->mode()) {
        case BobinkAuth::Anonymous:
            tokenPolicy.setTokenType(QOpcUaUserTokenPolicy::TokenType::Anonymous);
            break;
        case BobinkAuth::UserPass:
            tokenPolicy.setTokenType(QOpcUaUserTokenPolicy::TokenType::Username);
            break;
        case BobinkAuth::Certificate:
            tokenPolicy.setTokenType(QOpcUaUserTokenPolicy::TokenType::Certificate);
            break;
        }
        m_client->setAuthenticationInformation(m_auth->toAuthenticationInformation());
    } else {
        tokenPolicy.setTokenType(QOpcUaUserTokenPolicy::TokenType::Anonymous);
    }
    endpoint.setUserIdentityTokens({tokenPolicy});

    m_client->connectToEndpoint(endpoint);
}

/** @brief Ask the backend to disconnect; a no-op when no backend exists. */
void BobinkClient::disconnectFromServer()
{
    if (!m_client)
        return;
    m_client->disconnectFromEndpoint();
}

void BobinkClient::acceptCertificate()
{
    m_certAccepted = true;
    if (m_certLoop)
        m_certLoop->quit();
}

void BobinkClient::rejectCertificate()
{
    m_certAccepted = false;
    if (m_certLoop)
        m_certLoop->quit();
}

/* ======================================
 *  Discovery properties
 * ====================================== */

/** @brief URL queried with FindServers by doDiscovery(). */
QString BobinkClient::discoveryUrl() const
{
    return m_discoveryUrl;
}

/** @brief Set the discovery URL; discoveryUrlChanged() fires only on change. */
void BobinkClient::setDiscoveryUrl(const QString &url)
{
    if (m_discoveryUrl != url) {
        m_discoveryUrl = url;
        emit discoveryUrlChanged();
    }
}

/** @brief Period, in milliseconds, between FindServers requests. */
int BobinkClient::discoveryInterval() const
{
    return m_discoveryInterval;
}

/**
 * @brief Change the discovery period.
 *
 * A running discovery timer is updated in place. A stopped one picks the
 * value up the next time startDiscovery() starts it.
 */
void BobinkClient::setDiscoveryInterval(int ms)
{
    if (ms == m_discoveryInterval)
        return;

    m_discoveryInterval = ms;
    emit discoveryIntervalChanged();

    if (!m_discoveryTimer.isActive())
        return;
    m_discoveryTimer.setInterval(ms);
}

/** @brief Whether periodic discovery is on (set by startDiscovery(), cleared by stopDiscovery()). */
bool BobinkClient::discovering() const
{
    return m_discovering;
}

/** @brief Servers returned by the most recent successful FindServers reply. */
const QList<QOpcUaApplicationDescription> &BobinkClient::discoveredServers() const
{
    return m_discoveredServers;
}

/**
 * @brief QML-friendly form of discoveredServers(): one map per server with
 *        "serverName", "applicationUri" and "discoveryUrls" entries.
 */
QVariantList BobinkClient::servers() const
{
    const QVariantList &cached = m_serversCache;
    return cached;
}

/* ======================================
 *  PKI
 * ====================================== */

/** @brief Root of the OPC UA PKI directory tree. */
QString BobinkClient::pkiDir() const
{
    return m_pkiDir;
}

/**
 * @brief Point the client at another PKI root and create its sub-directories.
 *
 * NOTE(review): this setter does not call autoDetectPki() or applyPki(), so
 * the backend keeps the PKI configuration applied earlier; confirm that is
 * intended.
 */
void BobinkClient::setPkiDir(const QString &path)
{
    if (m_pkiDir != path) {
        m_pkiDir = path;
        ensurePkiDirs(m_pkiDir);
        emit pkiDirChanged();
    }
}

/** @brief Path of the client certificate passed to the PKI configuration. */
QString BobinkClient::certFile() const
{
    return m_certFile;
}

/** @brief Set the client certificate path; certFileChanged() fires only on change. */
void BobinkClient::setCertFile(const QString &path)
{
    if (m_certFile != path) {
        m_certFile = path;
        emit certFileChanged();
    }
}

/** @brief Path of the client private key passed to the PKI configuration. */
QString BobinkClient::keyFile() const
{
    return m_keyFile;
}

/** @brief Set the private key path; keyFileChanged() fires only on change. */
void BobinkClient::setKeyFile(const QString &path)
{
    if (m_keyFile != path) {
        m_keyFile = path;
        emit keyFileChanged();
    }
}

void BobinkClient::autoDetectPki()
{
    if (m_pkiDir.isEmpty())
        return;

    QDir certDir(m_pkiDir + QStringLiteral("/own/certs"));
    QStringList certs = certDir.entryList({QStringLiteral("*.der")}, QDir::Files);
    if (!certs.isEmpty())
        setCertFile(certDir.filePath(certs.first()));

    QDir keyDir(m_pkiDir + QStringLiteral("/own/private"));
    QStringList keys = keyDir.entryList(
        {QStringLiteral("*.pem"), QStringLiteral("*.crt")}, QDir::Files);
    if (!keys.isEmpty())
        setKeyFile(keyDir.filePath(keys.first()));
}

void BobinkClient::applyPki()
{
    if (!m_client || m_pkiDir.isEmpty())
        return;

    QOpcUaPkiConfiguration pki;
    if (!m_certFile.isEmpty())
        pki.setClientCertificateFile(m_certFile);
    if (!m_keyFile.isEmpty())
        pki.setPrivateKeyFile(m_keyFile);
    pki.setTrustListDirectory(m_pkiDir + QStringLiteral("/trusted/certs"));
    pki.setRevocationListDirectory(m_pkiDir + QStringLiteral("/trusted/crl"));
    pki.setIssuerListDirectory(m_pkiDir + QStringLiteral("/issuers/certs"));
    pki.setIssuerRevocationListDirectory(m_pkiDir + QStringLiteral("/issuers/crl"));

    m_client->setPkiConfiguration(pki);

    if (pki.isKeyAndCertificateFileSet())
        m_client->setApplicationIdentity(pki.applicationIdentity());
}

/* ======================================
 *  Discovery methods
 * ====================================== */

void BobinkClient::startDiscovery()
{
    if (m_discoveryUrl.isEmpty() || !m_client)
        return;

    doDiscovery();
    m_discoveryTimer.start(m_discoveryInterval);

    if (!m_discovering) {
        m_discovering = true;
        emit discoveringChanged();
    }
}

/** @brief Stop periodic discovery; discoveringChanged() fires only if it was on. */
void BobinkClient::stopDiscovery()
{
    m_discoveryTimer.stop();

    if (!m_discovering)
        return;
    m_discovering = false;
    emit discoveringChanged();
}

void BobinkClient::doDiscovery()
{
    if (!m_client || m_discoveryUrl.isEmpty())
        return;
    QUrl url(m_discoveryUrl);
    if (!url.isValid())
        return;
    m_client->findServers(url);
}

/* ======================================
 *  Private slots
 * ====================================== */

/** @brief Mirror the backend state into the boolean `connected` property. */
void BobinkClient::handleStateChanged(QOpcUaClient::ClientState state)
{
    const bool isConnected = state == QOpcUaClient::Connected;
    if (isConnected == m_connected)
        return;
    m_connected = isConnected;
    emit connectedChanged();
}

void BobinkClient::handleEndpointsReceived(
    const QList<QOpcUaEndpointDescription> &endpoints,
    QOpcUa::UaStatusCode statusCode, const QUrl &)
{
    if (statusCode != QOpcUa::Good || endpoints.isEmpty()) {
        emit connectionError(QStringLiteral("Failed to retrieve endpoints"));
        return;
    }

    QOpcUaEndpointDescription best = endpoints.first();
    for (const auto &ep : endpoints) {
        if (ep.securityLevel() > best.securityLevel())
            best = ep;
    }

    if (m_auth)
        m_client->setAuthenticationInformation(m_auth->toAuthenticationInformation());

    m_client->connectToEndpoint(best);
}

/**
 * @brief React to a connection error reported by the backend.
 *
 * Certificate-validation errors:
 *  - emit certificateTrustRequested() and wait in a local event loop until
 *    acceptCertificate()/rejectCertificate() is called, or for at most 30 s;
 *  - pass the decision back through QOpcUaErrorState::setIgnoreError();
 *  - treat an unanswered request as a rejection.
 *
 * Every other error is reported via connectionError() with the connection
 * step and the status code in hex.
 */
void BobinkClient::handleConnectError(QOpcUaErrorState *errorState)
{
    if (errorState->connectionStep() ==
        QOpcUaErrorState::ConnectionStep::CertificateValidation) {
        // connectError uses BlockingQueuedConnection — the backend thread is
        // blocked waiting for us to return.  The errorState pointer is stack-
        // allocated in the backend, so it is only valid during this call.
        // Spin a local event loop so QML can show a dialog and call
        // acceptCertificate() / rejectCertificate() while we stay in scope.
        m_certAccepted = false;

        QEventLoop loop;
        m_certLoop = &loop;

        // Emit from inside the running loop. If the signal were emitted
        // before exec(), a handler that answers synchronously would find
        // m_certLoop == nullptr. Nothing would quit the loop, and the answer
        // would only take effect when the 30 s timeout fires.
        QTimer::singleShot(0, &loop, [this] {
            emit certificateTrustRequested(
                QStringLiteral("The server certificate is not trusted. Accept?"));
        });
        QTimer::singleShot(30000, &loop, &QEventLoop::quit);
        loop.exec();
        m_certLoop = nullptr;

        errorState->setIgnoreError(m_certAccepted);
    } else {
        emit connectionError(
            QStringLiteral("Connection error at step %1, code 0x%2")
                .arg(static_cast<int>(errorState->connectionStep()))
                .arg(static_cast<uint>(errorState->errorCode()), 8, 16, QLatin1Char('0')));
    }
}

void BobinkClient::handleFindServersFinished(
    const QList<QOpcUaApplicationDescription> &servers,
    QOpcUa::UaStatusCode statusCode, const QUrl &)
{
    if (statusCode != QOpcUa::Good)
        return;

    m_discoveredServers = servers;
    m_serversCache.clear();
    for (const auto &s : m_discoveredServers) {
        QVariantMap entry;
        entry[QStringLiteral("serverName")] = s.applicationName().text();
        entry[QStringLiteral("applicationUri")] = s.applicationUri();
        entry[QStringLiteral("discoveryUrls")] = QVariant::fromValue(s.discoveryUrls());
        m_serversCache.append(entry);
    }
    emit serversChanged();
}