// SPDX-License-Identifier: MIT /** Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG */ #include "AmsRouter.h" #include "Log.h" #include #include namespace Beckhoff { namespace Ads { AmsRouter::AmsRouter(AmsNetId netId) : localAmsNetId(netId) {} long AmsRouter::ConnectTarget(AmsNetId ams, const IpV4& ip) { /** * We keep this madness only for backwards compatibility, to give * downstream projects time to migrate to the much saner interface. */ struct in_addr addr; static_assert(sizeof(addr) == sizeof(ip), "Oops sizeof(IpV4) doesn't match sizeof(in_addr)"); memcpy(&addr, &ip.value, sizeof(addr)); return ConnectTarget(ams, std::string(inet_ntoa(addr))); } long AmsRouter::ConnectTarget(AmsNetId ams, const std::string& host) { /** DNS lookups are pretty time consuming, we shouldn't do them with a looked mutex! So instead we do the lookup first and use the results, later. */ bool found = false; auto hostAddresses = Beckhoff::GetHostAddresses(host, ADS_TCP_SERVER_PORT_STR, &found); if(found == false){ LOG_ERROR("AmsRouter unable to get host addresses."); return ROUTERERR_INVALIDHOST; } std::unique_lock lock(mutex); AwaitConnectionAttempts(ams, lock); const auto oldConnection = GetConnection(ams); //if (oldConnection && !oldConnection->IsConnectedTo(hostAddresses.get())) { if (oldConnection && oldConnection->IsConnected()) { /** There is already a route for this AmsNetId, but with a different IP. The old route has to be deleted, first! */ return ROUTERERR_PORTALREADYINUSE; } //increase the ref count if remote target had been connected for (const auto& conn : connections) { if (conn->IsConnectedTo(hostAddresses.get())) { conn->refCount++; connection_mapping[ams] = conn.get(); return ADSERR_NOERR; } } connection_attempts[ams] = {}; lock.unlock(); // try { //AmsConnection is created and try to connect remote host auto new_connection = std::unique_ptr(new AmsConnection {*this, hostAddresses.get()}); if(new_connection.get() == nullptr || new_connection.get()->IsConnected() == false){ LOG_ERROR("AmsRouter add AmsConnection failed."); return ROUTERERR_HOSTDENY; } lock.lock(); connection_attempts.erase(ams); connection_attempt_events.notify_all(); auto conn = connections.emplace(std::move(new_connection)); if (conn.second) { // in case no local AmsNetId was set previously, we derive one if (!localAmsNetId) { localAmsNetId = AmsNetId {conn.first->get()->localIp}; } conn.first->get()->refCount++; connection_mapping[ams] = conn.first->get(); return !conn.first->get()->localIp; } return -1;// // } catch (std::exception& e) { // std::cout<<"AmsRouter::AddRoute(Exception) is occured."<<"\n"; // lock.lock(); // connection_attempts.erase(ams); // connection_attempt_events.notify_all(); // throw e; // } } void AmsRouter::DisconnectTarget(const AmsNetId& ams) { std::unique_lock lock(mutex); AwaitConnectionAttempts(ams, lock); auto route = connection_mapping.find(ams); if (route != connection_mapping.end()) { AmsConnection* conn = route->second; if (0 == --conn->refCount) { connection_mapping.erase(route); DeleteIfLastConnection(conn); } } } void AmsRouter::DeleteIfLastConnection(const AmsConnection* const conn) { if (conn) { for (const auto& r : connection_mapping) { if (r.second == conn) { return; } } for (auto it = connections.begin(); it != connections.end(); ++it) { if (conn == it->get()) { connections.erase(it); return; } } } } uint16_t AmsRouter::OpenPort() { std::lock_guard lock(mutex); for (uint16_t i = 0; i < NUM_PORTS_MAX; ++i) { if (!ports[i].IsOpen()) { return ports[i].Open(PORT_BASE + i); } } return 0; } long AmsRouter::ClosePort(uint16_t port) { std::lock_guard lock(mutex); if ((port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX) || !ports[port - PORT_BASE].IsOpen()) { return ADSERR_CLIENT_PORTNOTOPEN; } ports[port - PORT_BASE].Close(); return 0; } long AmsRouter::GetAmsAddr(uint16_t port, AmsAddr* pAddr) { std::lock_guard lock(mutex); if ((port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX)) { return ADSERR_CLIENT_PORTNOTOPEN; } if (ports[port - PORT_BASE].IsOpen()) { memcpy(&pAddr->netId, &localAmsNetId, sizeof(localAmsNetId)); pAddr->port = port; return 0; } return ADSERR_CLIENT_PORTNOTOPEN; } void AmsRouter::SetAmsNetId(AmsNetId netId) { std::lock_guard lock(mutex); localAmsNetId = netId; } long AmsRouter::GetTimeout(uint16_t port, uint32_t& timeout) { std::lock_guard lock(mutex); if ((port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX)) { return ADSERR_CLIENT_PORTNOTOPEN; } timeout = ports[port - PORT_BASE].tmms; return 0; } long AmsRouter::SetTimeout(uint16_t port, uint32_t timeout) { std::lock_guard lock(mutex); if ((port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX)) { return ADSERR_CLIENT_PORTNOTOPEN; } ports[port - PORT_BASE].tmms = timeout; return 0; } AmsConnection* AmsRouter::GetConnection(const AmsNetId& amsDest) { std::lock_guard lock(mutex); const auto it = connection_mapping.find(amsDest); if (it != connection_mapping.end()) { return it->second; } return nullptr; } long AmsRouter::SendAdsRequest(AmsRequest& request) { if (request.bytesRead) { *request.bytesRead = 0; } auto conn = GetConnection(request.destAddr.netId); if (!conn) { return GLOBALERR_MISSING_ROUTE; } return conn->SendRequest(request, ports[request.srcPort - Router::PORT_BASE].tmms); } long AmsRouter::AddNotification(AmsRequest& request, uint32_t* pNotification, std::shared_ptr notify) { if (request.bytesRead) { *request.bytesRead = 0; } auto conn = GetConnection(request.destAddr.netId); if (!conn) { return GLOBALERR_MISSING_ROUTE; } auto& port = ports[request.srcPort - Router::PORT_BASE]; const long status = conn->SendRequest(request, port.tmms); if (!status) { *pNotification = Beckhoff::letoh(request.buffer); auto dispatcher = conn->CreateNotifyMapping(*pNotification, notify); port.AddNotification(request.destAddr, *pNotification, dispatcher); } return status; } long AmsRouter::DelNotification(uint16_t srcPort, const AmsAddr* pAddr, uint32_t hNotification) { auto& p = ports[srcPort - Router::PORT_BASE]; return p.DelNotification(*pAddr, hNotification); } void AmsRouter::AwaitConnectionAttempts(const AmsNetId& ams, std::unique_lock& lock) { connection_attempt_events.wait(lock, [&]() { return connection_attempts.find(ams) == connection_attempts.end(); }); } } }