map/das-dn/third_party/AdsLib/Standalone/AmsRouter.cpp

238 lines
6.6 KiB
C++
Raw Normal View History

2024-12-03 10:36:06 +08:00
// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "AmsRouter.h"
#include "Log.h"
#include <algorithm>
2024-12-05 11:04:47 +08:00
AmsRouter::AmsRouter(AmsNetId netId)
: localAddr(netId)
2024-12-03 10:36:06 +08:00
{}
2024-12-05 11:04:47 +08:00
/**
 * Legacy overload: add a route to `ams` given a raw IPv4 address.
 *
 * We keep this madness only for backwards compatibility, to give
 * downstream projects time to migrate to the much saner interface
 * taking a host string. The address bytes are reinterpreted as an
 * in_addr, rendered in dotted notation and forwarded to that overload.
 */
long AmsRouter::AddRoute(AmsNetId ams, const IpV4& ip)
{
    static_assert(sizeof(in_addr) == sizeof(IpV4), "Oops sizeof(IpV4) doesn't match sizeof(in_addr)");

    in_addr rawAddress;
    memcpy(&rawAddress, &ip.value, sizeof(rawAddress));

    // Copy the text right away: inet_ntoa() hands back a library-owned buffer.
    const std::string dottedHost{ inet_ntoa(rawAddress) };
    return AddRoute(ams, dottedHost);
}
2024-12-05 11:04:47 +08:00
/**
 * Add (or share) a route from AmsNetId `ams` to `host` (ADS TCP port 48898).
 *
 * Return values / errors:
 *  - ROUTERERR_PORTALREADYINUSE if `ams` is already routed to a different host.
 *  - 0 if an existing connection to the same host is reused.
 *  - For a freshly created connection: the value of `!ownIp` of that connection.
 *  - -1 if the new connection could not be inserted into `connections`.
 *  - Exceptions from address lookup or AmsConnection construction propagate
 *    unchanged (rethrown with their original dynamic type).
 */
long AmsRouter::AddRoute(AmsNetId ams, const std::string& host)
{
    /**
       DNS lookups are pretty time consuming, we shouldn't do them
       with a locked mutex! So instead we do the lookup first and
       use the results later.
     */
    auto hostAddresses = bhf::ads::GetListOfAddresses(host, "48898");

    std::unique_lock<std::recursive_mutex> lock(mutex);

    // Serialize with any other thread that is currently connecting `ams`.
    AwaitConnectionAttempts(ams, lock);

    const auto oldConnection = GetConnection(ams);
    if (oldConnection && !oldConnection->IsConnectedTo(hostAddresses.get())) {
        /**
           There is already a route for this AmsNetId, but with
           a different IP. The old route has to be deleted, first!
         */
        return ROUTERERR_PORTALREADYINUSE;
    }

    // Reuse an existing connection to the same host, if there is one.
    for (const auto& conn : connections) {
        if (conn->IsConnectedTo(hostAddresses.get())) {
            conn->refCount++;
            mapping[ams] = conn.get();
            return 0;
        }
    }

    /*
     * Mark `ams` as "connection in progress" and establish the connection
     * without holding the router mutex. Add/DelRoute calls for the same
     * AmsNetId wait in AwaitConnectionAttempts() until the marker is gone.
     */
    connection_attempts[ams] = {};
    lock.unlock();

    try {
        auto new_connection = std::unique_ptr<AmsConnection>(new AmsConnection { *this, hostAddresses.get()});

        lock.lock();
        connection_attempts.erase(ams);
        connection_attempt_events.notify_all();

        auto conn = connections.emplace(std::move(new_connection));
        if (conn.second) {
            /** in case no local AmsNetId was set previously, we derive one */
            if (!localAddr) {
                localAddr = AmsNetId {conn.first->get()->ownIp};
            }
            conn.first->get()->refCount++;
            mapping[ams] = conn.first->get();
            return !conn.first->get()->ownIp;
        }

        return -1;
    } catch (...) {
        // The failure may occur before or after the lock was re-acquired in
        // the try block; locking an already owned unique_lock would throw.
        if (!lock.owns_lock()) {
            lock.lock();
        }
        // Always clear the in-progress marker, otherwise every later
        // Add/DelRoute for this AmsNetId would block forever.
        connection_attempts.erase(ams);
        connection_attempt_events.notify_all();
        // Rethrow the original exception object (no slicing, no copy).
        throw;
    }
}
2024-12-05 11:04:47 +08:00
void AmsRouter::DelRoute(const AmsNetId& ams)
2024-12-03 10:36:06 +08:00
{
std::unique_lock<std::recursive_mutex> lock(mutex);
2024-12-05 11:04:47 +08:00
2024-12-03 10:36:06 +08:00
AwaitConnectionAttempts(ams, lock);
2024-12-05 11:04:47 +08:00
auto route = mapping.find(ams);
if (route != mapping.end()) {
2024-12-03 10:36:06 +08:00
AmsConnection* conn = route->second;
if (0 == --conn->refCount) {
2024-12-05 11:04:47 +08:00
mapping.erase(route);
2024-12-03 10:36:06 +08:00
DeleteIfLastConnection(conn);
}
}
}
void AmsRouter::DeleteIfLastConnection(const AmsConnection* const conn)
{
if (conn) {
2024-12-05 11:04:47 +08:00
for (const auto& r : mapping) {
2024-12-03 10:36:06 +08:00
if (r.second == conn) {
return;
}
}
for (auto it = connections.begin(); it != connections.end(); ++it) {
if (conn == it->get()) {
connections.erase(it);
return;
}
}
}
}
uint16_t AmsRouter::OpenPort()
{
std::lock_guard<std::recursive_mutex> lock(mutex);
2024-12-05 11:04:47 +08:00
2024-12-03 10:36:06 +08:00
for (uint16_t i = 0; i < NUM_PORTS_MAX; ++i) {
if (!ports[i].IsOpen()) {
return ports[i].Open(PORT_BASE + i);
}
}
return 0;
}
/**
 * Close local `port`.
 *
 * Returns 0 on success, or ADSERR_CLIENT_PORTNOTOPEN when `port` lies
 * outside [PORT_BASE, PORT_BASE + NUM_PORTS_MAX) or is not currently open.
 */
long AmsRouter::ClosePort(uint16_t port)
{
    std::lock_guard<std::recursive_mutex> guard(mutex);

    const bool inRange = (port >= PORT_BASE) && (port < PORT_BASE + NUM_PORTS_MAX);
    if (inRange && ports[port - PORT_BASE].IsOpen()) {
        ports[port - PORT_BASE].Close();
        return 0;
    }
    return ADSERR_CLIENT_PORTNOTOPEN;
}
2024-12-05 11:04:47 +08:00
long AmsRouter::GetLocalAddress(uint16_t port, AmsAddr* pAddr)
2024-12-03 10:36:06 +08:00
{
std::lock_guard<std::recursive_mutex> lock(mutex);
if ((port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX)) {
return ADSERR_CLIENT_PORTNOTOPEN;
}
if (ports[port - PORT_BASE].IsOpen()) {
2024-12-05 11:04:47 +08:00
memcpy(&pAddr->netId, &localAddr, sizeof(localAddr));
2024-12-03 10:36:06 +08:00
pAddr->port = port;
return 0;
}
return ADSERR_CLIENT_PORTNOTOPEN;
}
2024-12-05 11:04:47 +08:00
/**
 * Replace the router's local AmsNetId with `netId`.
 * The write happens under the router mutex.
 */
void AmsRouter::SetLocalAddress(AmsNetId netId)
{
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    localAddr = netId;
}
/**
 * Read the timeout (`tmms`) configured for local `port` into `timeout`.
 *
 * Returns 0 on success, or ADSERR_CLIENT_PORTNOTOPEN when `port` is outside
 * [PORT_BASE, PORT_BASE + NUM_PORTS_MAX). The port does not need to be open.
 */
long AmsRouter::GetTimeout(uint16_t port, uint32_t& timeout)
{
    std::lock_guard<std::recursive_mutex> guard(mutex);

    const bool validPort = (port >= PORT_BASE) && (port < PORT_BASE + NUM_PORTS_MAX);
    if (!validPort) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }

    const auto& slot = ports[port - PORT_BASE];
    timeout = slot.tmms;
    return 0;
}
/**
 * Set the timeout (`tmms`) for local `port` to `timeout`.
 *
 * Returns 0 on success, or ADSERR_CLIENT_PORTNOTOPEN when `port` is outside
 * [PORT_BASE, PORT_BASE + NUM_PORTS_MAX). The port does not need to be open.
 */
long AmsRouter::SetTimeout(uint16_t port, uint32_t timeout)
{
    std::lock_guard<std::recursive_mutex> guard(mutex);

    const bool validPort = (port >= PORT_BASE) && (port < PORT_BASE + NUM_PORTS_MAX);
    if (!validPort) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }

    auto& slot = ports[port - PORT_BASE];
    slot.tmms = timeout;
    return 0;
}
/**
 * Look up the connection routed for `amsDest`.
 * Returns a non-owning pointer, or nullptr if no route exists.
 */
AmsConnection* AmsRouter::GetConnection(const AmsNetId& amsDest)
{
    std::lock_guard<std::recursive_mutex> guard(mutex);

    const auto route = mapping.find(amsDest);
    return (route == mapping.end()) ? nullptr : route->second;
}
2024-12-05 11:04:47 +08:00
/**
 * Forward `request` over the connection routed for request.destAddr.netId.
 *
 * Resets *request.bytesRead (if provided) to 0 first. Returns:
 *  - GLOBALERR_MISSING_ROUTE if no route exists for the destination,
 *  - ADSERR_CLIENT_PORTNOTOPEN if request.port is outside the local port range,
 *  - otherwise the status of AmsConnection::AdsRequest(), called with the
 *    timeout configured for request.port.
 */
long AmsRouter::AdsRequest(AmsRequest& request)
{
    if (request.bytesRead) {
        *request.bytesRead = 0;
    }

    auto ads = GetConnection(request.destAddr.netId);
    if (!ads) {
        return GLOBALERR_MISSING_ROUTE;
    }

    // request.port selects an entry of ports[]; a value below PORT_BASE used
    // to produce a negative index. GetTimeout() range-checks the port and
    // reads tmms under the router mutex (the same mutex SetTimeout() uses).
    uint32_t timeout = 0;
    const long portError = GetTimeout(request.port, timeout);
    if (portError) {
        return portError;
    }
    return ads->AdsRequest(request, timeout);
}
/**
 * Send an add-notification `request` and register the returned handle.
 *
 * Resets *request.bytesRead (if provided) to 0 first. Returns:
 *  - GLOBALERR_MISSING_ROUTE if no route exists for the destination,
 *  - ADSERR_CLIENT_PORTNOTOPEN if request.port is outside the local port range,
 *  - otherwise the status of the underlying AdsRequest.
 * On success, *pNotification receives the handle decoded from request.buffer.
 * `notify` is bound to that handle via CreateNotifyMapping(), and the
 * resulting dispatcher is recorded on the local port.
 */
long AmsRouter::AddNotification(AmsRequest& request, uint32_t* pNotification, std::shared_ptr<Notification> notify)
{
    if (request.bytesRead) {
        *request.bytesRead = 0;
    }

    auto ads = GetConnection(request.destAddr.netId);
    if (!ads) {
        return GLOBALERR_MISSING_ROUTE;
    }

    // Validate request.port before it is used as an index into ports[], and
    // read its timeout under the router mutex.
    uint32_t timeout = 0;
    const long portError = GetTimeout(request.port, timeout);
    if (portError) {
        return portError;
    }

    auto& port = ports[request.port - Router::PORT_BASE];
    const long status = ads->AdsRequest(request, timeout);
    if (!status) {
        *pNotification = bhf::ads::letoh<uint32_t>(request.buffer);
        auto dispatcher = ads->CreateNotifyMapping(*pNotification, notify);
        port.AddNotification(request.destAddr, *pNotification, dispatcher);
    }
    return status;
}
2024-12-05 11:04:47 +08:00
long AmsRouter::DelNotification(uint16_t port, const AmsAddr* pAddr, uint32_t hNotification)
2024-12-03 10:36:06 +08:00
{
2024-12-05 11:04:47 +08:00
auto& p = ports[port - Router::PORT_BASE];
2024-12-03 10:36:06 +08:00
return p.DelNotification(*pAddr, hNotification);
}
/**
 * Block until no connection attempt for `ams` is registered in
 * `connection_attempts`.
 *
 * `lock` must own the router mutex on entry (as in AddRoute/DelRoute). It is
 * released while waiting on `connection_attempt_events` and held again on return.
 */
void AmsRouter::AwaitConnectionAttempts(const AmsNetId& ams, std::unique_lock<std::recursive_mutex>& lock)
{
    // Explicit form of wait(lock, predicate); re-checking in a loop also
    // absorbs spurious wakeups.
    while (connection_attempts.find(ams) != connection_attempts.end()) {
        connection_attempt_events.wait(lock);
    }
}