map/das-dn/third_party/AdsLib/Standalone/AmsRouter.cpp
2024-12-30 10:54:46 +08:00

252 lines
7.5 KiB
C++

// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "AmsRouter.h"
#include "Log.h"
#include <algorithm>
#include <iostream>
namespace Beckhoff
{
namespace Ads
{
/**
 * Construct a router whose local AmsNetId starts out as @p netId.
 */
AmsRouter::AmsRouter(AmsNetId netId)
    : localAmsNetId(netId)
{
}
/**
 * Legacy overload taking the target as an IpV4 value.
 *
 * The address is rendered as a dotted-decimal string and forwarded to the
 * host-name based ConnectTarget() overload, whose result is returned.
 */
long AmsRouter::ConnectTarget(AmsNetId ams, const IpV4& ip)
{
    // Retained purely for backwards compatibility, so that downstream
    // projects have time to migrate to the string based interface.
    struct in_addr rawAddress;
    static_assert(sizeof(rawAddress) == sizeof(ip), "Oops sizeof(IpV4) doesn't match sizeof(in_addr)");
    memcpy(&rawAddress, &ip.value, sizeof(rawAddress));

    const std::string dottedHost(inet_ntoa(rawAddress));
    return ConnectTarget(ams, dottedHost);
}
/**
 * Register a route from @p ams to @p host, connecting to the host if needed.
 *
 * @param ams   AmsNetId that should be routed through the connection.
 * @param host  Host name or address passed to Beckhoff::GetHostAddresses().
 * @return ROUTERERR_INVALIDHOST      if the address lookup reports no result;
 *         ROUTERERR_PORTALREADYINUSE if @p ams is already mapped to a
 *                                    connection that reports IsConnected();
 *         ADSERR_NOERR               if an existing connection to the same
 *                                    addresses was reused;
 *         ROUTERERR_HOSTDENY         if the freshly created connection is not
 *                                    connected;
 *         -1                         if the new connection could not be
 *                                    inserted into the connection set;
 *         otherwise `!localIp` of the new connection (0 when a local IP is
 *         known, 1 when it is not).
 *
 * While the new connection is being established the router mutex is released
 * and @p ams is marked in connection_attempts. That marker is removed again on
 * every exit path, otherwise later calls to AwaitConnectionAttempts() for the
 * same AmsNetId would wait forever.
 */
long AmsRouter::ConnectTarget(AmsNetId ams, const std::string& host)
{
    /**
       DNS lookups are pretty time consuming, we shouldn't do them
       with a locked mutex! So instead we do the lookup first and
       use the results, later.
     */
    bool found = false;
    auto hostAddresses = Beckhoff::GetHostAddresses(host, ADS_TCP_SERVER_PORT_STR, &found);
    if (!found) {
        LOG_ERROR("AmsRouter unable to get host addresses.");
        return ROUTERERR_INVALIDHOST;
    }

    std::unique_lock<std::recursive_mutex> lock(mutex);
    AwaitConnectionAttempts(ams, lock);

    const auto oldConnection = GetConnection(ams);
    if (oldConnection && oldConnection->IsConnected()) {
        /**
           There is already a connected route for this AmsNetId
           (regardless of which address it points to). The old
           route has to be deleted, first!
         */
        return ROUTERERR_PORTALREADYINUSE;
    }

    // Reuse (and reference count) a connection that already targets the
    // same remote addresses.
    for (const auto& conn : connections) {
        if (conn->IsConnectedTo(hostAddresses.get())) {
            conn->refCount++;
            connection_mapping[ams] = conn.get();
            return ADSERR_NOERR;
        }
    }

    // Mark the attempt and release the mutex while the connection is created.
    connection_attempts[ams] = {};
    lock.unlock();

    // Re-acquires the mutex, removes the attempt marker and wakes up waiters.
    // Leaves the mutex locked.
    const auto finishAttempt = [&]() {
        lock.lock();
        connection_attempts.erase(ams);
        connection_attempt_events.notify_all();
    };

    std::unique_ptr<AmsConnection> new_connection;
    try {
        // AmsConnection is created here and tries to connect to the remote host.
        new_connection.reset(new AmsConnection {*this, hostAddresses.get()});
    } catch (...) {
        // Never leave a stale attempt marker behind, then let the caller
        // see the original exception.
        finishAttempt();
        throw;
    }

    if (!new_connection->IsConnected()) {
        LOG_ERROR("AmsRouter add AmsConnection failed.");
        finishAttempt();
        // Release the mutex before the failed connection object is torn
        // down at scope exit, as the original code did.
        lock.unlock();
        return ROUTERERR_HOSTDENY;
    }

    finishAttempt();

    auto conn = connections.emplace(std::move(new_connection));
    if (conn.second) {
        // in case no local AmsNetId was set previously, we derive one
        if (!localAmsNetId) {
            localAmsNetId = AmsNetId {conn.first->get()->localIp};
        }
        conn.first->get()->refCount++;
        connection_mapping[ams] = conn.first->get();
        return !conn.first->get()->localIp;
    }
    return -1;
}
void AmsRouter::DisconnectTarget(const AmsNetId& ams)
{
std::unique_lock<std::recursive_mutex> lock(mutex);
AwaitConnectionAttempts(ams, lock);
auto route = connection_mapping.find(ams);
if (route != connection_mapping.end()) {
AmsConnection* conn = route->second;
if (0 == --conn->refCount) {
connection_mapping.erase(route);
DeleteIfLastConnection(conn);
}
}
}
void AmsRouter::DeleteIfLastConnection(const AmsConnection* const conn)
{
if (conn) {
for (const auto& r : connection_mapping) {
if (r.second == conn) {
return;
}
}
for (auto it = connections.begin(); it != connections.end(); ++it) {
if (conn == it->get()) {
connections.erase(it);
return;
}
}
}
}
/**
 * Open the first port slot that is not yet open.
 *
 * @return the value returned by Open() for the slot (called with
 *         PORT_BASE + slot index), or 0 if all NUM_PORTS_MAX slots are open.
 */
uint16_t AmsRouter::OpenPort()
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    for (uint16_t slot = 0; slot < NUM_PORTS_MAX; ++slot) {
        if (ports[slot].IsOpen()) {
            continue;
        }
        return ports[slot].Open(PORT_BASE + slot);
    }
    return 0;
}
/**
 * Close a previously opened port.
 *
 * @param port  Port number in the range [PORT_BASE, PORT_BASE + NUM_PORTS_MAX).
 * @return 0 on success, ADSERR_CLIENT_PORTNOTOPEN if @p port is out of range
 *         or the corresponding slot is not open.
 */
long AmsRouter::ClosePort(uint16_t port)
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    const bool outOfRange = (port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX);
    if (outOfRange) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }
    if (!ports[port - PORT_BASE].IsOpen()) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }

    ports[port - PORT_BASE].Close();
    return 0;
}
/**
 * Fill @p pAddr with the local AmsNetId and @p port.
 *
 * @param port   Port number in the range [PORT_BASE, PORT_BASE + NUM_PORTS_MAX).
 * @param pAddr  Destination address structure; it is dereferenced without a
 *               null check.
 * @return 0 on success, ADSERR_CLIENT_PORTNOTOPEN if @p port is out of range
 *         or not open (in which case @p pAddr is left untouched).
 */
long AmsRouter::GetAmsAddr(uint16_t port, AmsAddr* pAddr)
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    const bool outOfRange = (port < PORT_BASE) || (port >= PORT_BASE + NUM_PORTS_MAX);
    if (outOfRange) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }
    if (!ports[port - PORT_BASE].IsOpen()) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }

    memcpy(&pAddr->netId, &localAmsNetId, sizeof(localAmsNetId));
    pAddr->port = port;
    return 0;
}
/**
 * Replace the router's local AmsNetId with @p netId while holding the
 * router mutex.
 */
void AmsRouter::SetAmsNetId(AmsNetId netId)
{
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    localAmsNetId = netId;
}
/**
 * Read the timeout value stored for @p port.
 *
 * @param port     Port number in the range [PORT_BASE, PORT_BASE + NUM_PORTS_MAX).
 * @param timeout  Receives the slot's `tmms` value on success.
 * @return 0 on success, ADSERR_CLIENT_PORTNOTOPEN if @p port is out of range.
 */
long AmsRouter::GetTimeout(uint16_t port, uint32_t& timeout)
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    const bool inRange = (port >= PORT_BASE) && (port < PORT_BASE + NUM_PORTS_MAX);
    if (inRange) {
        timeout = ports[port - PORT_BASE].tmms;
        return 0;
    }
    return ADSERR_CLIENT_PORTNOTOPEN;
}
/**
 * Store a new timeout value for @p port.
 *
 * @param port     Port number in the range [PORT_BASE, PORT_BASE + NUM_PORTS_MAX).
 * @param timeout  Value written to the slot's `tmms` member.
 * @return 0 on success, ADSERR_CLIENT_PORTNOTOPEN if @p port is out of range.
 */
long AmsRouter::SetTimeout(uint16_t port, uint32_t timeout)
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    const bool inRange = (port >= PORT_BASE) && (port < PORT_BASE + NUM_PORTS_MAX);
    if (inRange) {
        ports[port - PORT_BASE].tmms = timeout;
        return 0;
    }
    return ADSERR_CLIENT_PORTNOTOPEN;
}
/**
 * Look up the connection currently mapped to @p amsDest.
 *
 * @return the mapped connection, or nullptr if no mapping exists.
 */
AmsConnection* AmsRouter::GetConnection(const AmsNetId& amsDest)
{
    std::lock_guard<std::recursive_mutex> lock(mutex);

    const auto route = connection_mapping.find(amsDest);
    return (route == connection_mapping.end()) ? nullptr : route->second;
}
/**
 * Send @p request over the connection routed to its destination NetId.
 *
 * `*request.bytesRead` is reset to 0 first when the pointer is set.
 *
 * @return GLOBALERR_MISSING_ROUTE    if no connection is mapped to the
 *                                    destination NetId;
 *         ADSERR_CLIENT_PORTNOTOPEN  if request.srcPort lies outside
 *                                    [PORT_BASE, PORT_BASE + NUM_PORTS_MAX);
 *         otherwise the result of AmsConnection::SendRequest().
 */
long AmsRouter::SendAdsRequest(AmsRequest& request)
{
    if (request.bytesRead) {
        *request.bytesRead = 0;
    }

    auto conn = GetConnection(request.destAddr.netId);
    if (!conn) {
        return GLOBALERR_MISSING_ROUTE;
    }

    // Validate the source port before using it as an index into ports[],
    // like the other port based accessors of this class do.
    if ((request.srcPort < PORT_BASE) || (request.srcPort >= PORT_BASE + NUM_PORTS_MAX)) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }
    return conn->SendRequest(request, ports[request.srcPort - Router::PORT_BASE].tmms);
}
/**
 * Send a notification request and, on success, register the notification.
 *
 * `*request.bytesRead` is reset to 0 first when the pointer is set. When
 * SendRequest() returns 0, the notification handle is decoded from
 * request.buffer (little endian) into @p pNotification, a dispatcher mapping
 * is created on the connection and the handle is recorded on the source port.
 *
 * @param request        Request to transmit; srcPort selects the port slot.
 * @param pNotification  Receives the notification handle on success; it is
 *                       dereferenced without a null check.
 * @param notify         Notification object handed to CreateNotifyMapping().
 * @return GLOBALERR_MISSING_ROUTE    if no connection is mapped to the
 *                                    destination NetId;
 *         ADSERR_CLIENT_PORTNOTOPEN  if request.srcPort lies outside
 *                                    [PORT_BASE, PORT_BASE + NUM_PORTS_MAX);
 *         otherwise the status returned by AmsConnection::SendRequest().
 */
long AmsRouter::AddNotification(AmsRequest& request, uint32_t* pNotification, std::shared_ptr<Notification> notify)
{
    if (request.bytesRead) {
        *request.bytesRead = 0;
    }

    auto conn = GetConnection(request.destAddr.netId);
    if (!conn) {
        return GLOBALERR_MISSING_ROUTE;
    }

    // Validate the source port before using it as an index into ports[],
    // like the other port based accessors of this class do.
    if ((request.srcPort < PORT_BASE) || (request.srcPort >= PORT_BASE + NUM_PORTS_MAX)) {
        return ADSERR_CLIENT_PORTNOTOPEN;
    }

    auto& port = ports[request.srcPort - Router::PORT_BASE];
    const long status = conn->SendRequest(request, port.tmms);
    if (!status) {
        *pNotification = Beckhoff::letoh<uint32_t>(request.buffer);
        auto dispatcher = conn->CreateNotifyMapping(*pNotification, notify);
        port.AddNotification(request.destAddr, *pNotification, dispatcher);
    }
    return status;
}
long AmsRouter::DelNotification(uint16_t srcPort, const AmsAddr* pAddr, uint32_t hNotification)
{
auto& p = ports[srcPort - Router::PORT_BASE];
return p.DelNotification(*pAddr, hNotification);
}
/**
 * Block until no connection attempt is registered for @p ams.
 *
 * @param ams   AmsNetId whose entry in connection_attempts is awaited.
 * @param lock  Lock handed to the condition variable's wait(); it is released
 *              while waiting and held again on return.
 */
void AmsRouter::AwaitConnectionAttempts(const AmsNetId& ams, std::unique_lock<std::recursive_mutex>& lock)
{
    const auto noAttemptPending = [&]() {
        return connection_attempts.count(ams) == 0;
    };
    connection_attempt_events.wait(lock, noAttemptPending);
}
}
}