map/das-dn/third_party/AdsLib/Standalone/AmsConnection.cpp

389 lines
12 KiB
C++
Raw Normal View History

2024-12-03 10:36:06 +08:00
// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "AmsConnection.h"
#include "Log.h"
2024-12-09 09:41:04 +08:00
namespace Beckhoff
{
namespace Ads
{
/**
 * Creates an idle response slot: no request attached, no invoke id armed and
 * the error code set to the WAITING_FOR_RESPONSE sentinel.
 */
AmsResponse::AmsResponse()
    : request(nullptr),
    // Before C++20 a default-constructed std::atomic holds an indeterminate
    // value. AmsConnection::GetPending() compare-exchanges against invokeId,
    // so it has to start at 0 ("nothing pending") explicitly.
    invokeId(0),
    errorCode(WAITING_FOR_RESPONSE)
{}
/**
 * Publishes the final result of a request and wakes every thread blocked in
 * Wait().
 * @param error value stored as the response's error code
 */
void AmsResponse::Notify(const uint32_t error)
{
    // errorCode is read by Wait() under the same mutex.
    const std::lock_guard<std::mutex> guard(mutex);
    errorCode = error;
    cv.notify_all();
}
uint32_t AmsResponse::Wait()
{
std::unique_lock<std::mutex> lock(mutex);
cv.wait_until(lock, request.load()->deadline, [&]() { return !invokeId.load(); });
if (invokeId.exchange(0)) {
/* invokeId wasn't consumed -> AmsConnection::recv() didn't got a valid response until now */
return ADSERR_CLIENT_SYNCTIMEOUT;
}
/* AmsConnection::recv() is currently processing a response and using the user supplied buffer, we need to wait until that finished */
cv.wait(lock, [&]() { return errorCode != WAITING_FOR_RESPONSE; });
return errorCode;
}
/**
 * Returns the notification dispatcher registered for a virtual connection,
 * creating and registering one when none exists yet.
 * @param connection key of the dispatcher list; its first/second members are
 *                   forwarded to DeleteNotification() as source port and
 *                   AMS address
 * @return the dispatcher stored in the list for this connection
 */
SharedDispatcher AmsConnection::DispatcherListAdd(const VirtualConnection& connection)
{
    if (const auto existing = DispatcherListGet(connection)) {
        return existing;
    }

    // No dispatcher yet: build one whose callback deletes a notification
    // through this connection.
    std::lock_guard<std::recursive_mutex> guard(dispatcherListMutex);
    auto created = std::make_shared<NotificationDispatcher>(
        std::bind(&AmsConnection::DeleteNotification,
                  this,
                  connection.second,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  connection.first));

    // emplace() keeps an entry that was inserted concurrently; either way the
    // element actually stored in the list is returned.
    const auto inserted = dispatcherList.emplace(connection, created);
    return inserted.first->second;
}
/**
 * Looks up the dispatcher registered for a virtual connection.
 * @param connection key to search for in the dispatcher list
 * @return the stored dispatcher, or an empty SharedDispatcher when the
 *         connection is unknown
 */
SharedDispatcher AmsConnection::DispatcherListGet(const VirtualConnection& connection)
{
    std::lock_guard<std::recursive_mutex> guard(dispatcherListMutex);
    const auto entry = dispatcherList.find(connection);
    return (entry == dispatcherList.end()) ? SharedDispatcher{} : entry->second;
}
/**
 * Opens the socket to the destination and, if the connection is up, records
 * the local/remote addresses and starts the receiver thread.
 * @param __router router used later to resolve source AMS addresses
 * @param destination address list handed to the socket
 */
AmsConnection::AmsConnection(Router& __router, const struct addrinfo* const destination)
    : router(__router),
    socket(destination),
    refCount(0),
    invokeId(0)
{
    if (!socket.IsConnected()) {
        // NOTE(review): localIp/remoteIp are not assigned on this path and no
        // receiver thread is started -- confirm the header gives them defaults.
        return;
    }

    localIp = socket.GetLocalSockAddr();
    remoteIp = socket.GetHostSockAddr();
    // The receiver runs TryRecv() on this object for the rest of its life.
    receiver = std::thread(&AmsConnection::TryRecv, this);

    struct in_addr peer{htonl(remoteIp)};
    LOG_INFO("Socket connect["<<std::string(inet_ntoa(peer))<<"] is done.");
}
/**
 * Shuts the socket down (only while it is still connected) and waits for the
 * receiver thread to finish.
 *
 * The join is decided by the thread's own state rather than by the socket
 * state. If the link dropped after the constructor started the receiver, the
 * thread is still joinable, and destroying a joinable std::thread calls
 * std::terminate. Conversely, if the socket reports connected but no receiver
 * was ever started, an unconditional join() would throw from the destructor.
 */
AmsConnection::~AmsConnection()
{
    if (socket.IsConnected()) {
        // Shutting down the socket is what makes the blocked receiver return.
        socket.Shutdown();
    }
    if (receiver.joinable()) {
        receiver.join();
    }
}
/**
 * Binds a notification handle to its notification object and registers the
 * pair with the dispatcher of the notification's connection.
 * @param hNotify handle stored in the notification and used as dispatcher key
 * @param notification notification object to register
 * @return the dispatcher that now holds the mapping
 */
SharedDispatcher AmsConnection::CreateNotifyMapping(uint32_t hNotify, std::shared_ptr<Notification> notification)
{
    // Creates the dispatcher on first use of this connection.
    auto owner = DispatcherListAdd(notification->connection);

    notification->hNotify(hNotify);
    owner->Emplace(hNotify, notification);
    return owner;
}
2024-12-09 09:41:04 +08:00
long AmsConnection::DeleteNotification(const AmsAddr& amsAddr, uint32_t hNotify, uint32_t tmms, uint16_t srcPort)
2024-12-03 10:36:06 +08:00
{
AmsRequest request {
amsAddr,
2024-12-09 09:41:04 +08:00
srcPort, AoEHeader::DEL_DEVICE_NOTIFICATION,
2024-12-03 10:36:06 +08:00
0, nullptr, nullptr,
sizeof(hNotify)
};
2024-12-09 09:41:04 +08:00
request.frame.prepend(Beckhoff::htole(hNotify));
return SendRequest(request, tmms);
2024-12-03 10:36:06 +08:00
}
/**
 * Asks the underlying socket whether it is connected to one of the given
 * addresses.
 * @param targetAddresses address list passed through to the socket
 * @return the socket's answer
 */
bool AmsConnection::IsConnectedTo(const struct addrinfo* targetAddresses) const
{
    const bool sameTarget = socket.IsConnectedTo(targetAddresses);
    return sameTarget;
}
2024-12-09 09:41:04 +08:00
/**
 * Reports whether the socket is usable.
 * @return true only when the socket is valid and reports itself connected;
 *         the connected state is not queried for an invalid socket
 */
bool AmsConnection::IsConnected() const
{
    return socket.IsValid() && socket.IsConnected();
}
2024-12-03 10:36:06 +08:00
/**
 * Completes the request frame with AoE and AMS/TCP headers, reserves the
 * response slot of the source port and writes the frame to the socket.
 * @param request request whose frame is extended in place and sent
 * @param srcAddr local AMS address used as source in the AoE header
 * @return the reserved response slot, or nullptr when the port is already
 *         busy or the frame could not be written completely
 */
AmsResponse* AmsConnection::Write(AmsRequest& request, const AmsAddr srcAddr)
{
    const AoEHeader aoeHeader {
        request.destAddr.netId, request.destAddr.port,
        srcAddr.netId, srcAddr.port,
        request.cmdId,
        static_cast<uint32_t>(request.frame.size()),
        GetInvokeId()
    };
    request.frame.prepend<AoEHeader>(aoeHeader);
    // The AMS/TCP length covers the AoE header plus payload prepended so far.
    const AmsTcpHeader header { static_cast<uint32_t>(request.frame.size()) };
    request.frame.prepend<AmsTcpHeader>(header);

    auto response = Reserve(&request, srcAddr.port);
    if (!response) {
        return nullptr;
    }

    // Arm the slot before sending so the receiver can match the answer.
    response->invokeId.store(aoeHeader.invokeId());
    if (request.frame.size() != socket.write(request.frame)) {
        // Disarm the slot before releasing it. Otherwise a later frame that
        // carries this invoke id would pass GetPending() and ReceiveFrame()
        // would dereference the request pointer that Release() sets to null.
        response->invokeId.store(0);
        response->Release();
        return nullptr;
    }
    return response;
}
2024-12-09 09:41:04 +08:00
/**
 * Sends a request and blocks until it is answered or times out.
 * @param request request to send; its deadline is set from @p timeout
 * @param timeout timeout forwarded to AmsRequest::SetDeadline()
 * @return -1 when the connection is down or the request could not be
 *         written, the non-zero status of Router::GetAmsAddr() when the
 *         source address cannot be resolved, otherwise the result of
 *         AmsResponse::Wait()
 */
long AmsConnection::SendRequest(AmsRequest& request, const uint32_t timeout)
{
    if (!IsConnected()) {
        return -1;
    }

    AmsAddr srcAddr;
    const auto status = router.GetAmsAddr(request.srcPort, &srcAddr);
    if (status) {
        return status;
    }

    request.SetDeadline(timeout);
    AmsResponse* const pending = Write(request, srcAddr);
    if (!pending) {
        return -1;
    }

    // The slot must be released after waiting so the port can be reused.
    const auto result = pending->Wait();
    pending->Release();
    return result;
}
/**
 * Produces the invoke id for the next request.
 *
 * The id is derived from the value returned by fetch_add() itself, so two
 * threads can never observe the same id; incrementing and then re-reading the
 * counter was not atomic. Zero is skipped on wrap-around because AmsResponse
 * uses an invoke id of 0 to mean "nothing pending".
 * @return a non-zero invoke id; the first call returns 1
 */
uint32_t AmsConnection::GetInvokeId()
{
    uint32_t id;
    do {
        id = invokeId.fetch_add(1) + 1;
    } while (id == 0);
    return id;
}
2024-12-09 09:41:04 +08:00
/**
 * Claims the response slot of a port if it is waiting for the given invoke id.
 * @param id invoke id taken from the received frame
 * @param srcPort port the frame is addressed to
 * @return the slot when its invoke id matched (the id is reset to 0 in the
 *         same atomic step), nullptr when the port is out of range or the id
 *         does not match
 */
AmsResponse* AmsConnection::GetPending(const uint32_t id, const uint16_t srcPort)
{
    // Ports below PORT_BASE wrap to a large unsigned value and are rejected
    // by the same comparison.
    const uint16_t slot = srcPort - Router::PORT_BASE;
    if (slot >= Router::NUM_PORTS_MAX) {
        LOG_WARN("Port 0x" << std::hex << srcPort << " is out of range");
        return nullptr;
    }

    AmsResponse& candidate = queue[slot];
    auto expected = id;
    return candidate.invokeId.compare_exchange_strong(expected, 0) ? &candidate : nullptr;
}
2024-12-09 09:41:04 +08:00
/**
 * Reserves the response slot of a source port for a request.
 * @param request request to attach to the slot
 * @param srcPort source port whose slot should be reserved
 * @return the reserved slot, or nullptr when the port is out of range or its
 *         slot is already occupied by another request
 */
AmsResponse* AmsConnection::Reserve(AmsRequest* request, const uint16_t srcPort)
{
    // Same bounds check as GetPending(): ports below PORT_BASE wrap to a
    // large unsigned index, so one comparison rejects both ends and prevents
    // an out-of-range access to queue.
    const uint16_t portIndex = srcPort - Router::PORT_BASE;
    if (portIndex >= Router::NUM_PORTS_MAX) {
        LOG_WARN("Port 0x" << std::hex << srcPort << " is out of range");
        return nullptr;
    }

    // A slot holding no AmsRequest pointer is idle; atomically place the
    // request there. On failure isFree receives the request currently stored.
    AmsRequest* isFree = nullptr;
    if (!queue[portIndex].request.compare_exchange_strong(isFree, request)) {
        LOG_WARN("Port: " << srcPort << " already in use as " << isFree);
        return nullptr;
    }
    return &queue[portIndex];
}
/**
 * Returns the slot to its idle state: the error code goes back to the
 * WAITING_FOR_RESPONSE sentinel and the request pointer is cleared.
 */
void AmsResponse::Release()
{
    errorCode = WAITING_FOR_RESPONSE;
    // Clearing the pointer last is what lets AmsConnection::Reserve() hand
    // the slot out again.
    request = nullptr;
}
2024-12-09 09:41:04 +08:00
void AmsConnection::Receive(void* buffer, size_t bytesToRead, timeval* timeout)
2024-12-03 10:36:06 +08:00
{
auto pos = reinterpret_cast<uint8_t*>(buffer);
while (bytesToRead) {
const size_t bytesRead = socket.read(pos, bytesToRead, timeout);
2024-12-09 09:41:04 +08:00
if (bytesRead == 0)
break;
2024-12-03 10:36:06 +08:00
bytesToRead -= bytesRead;
pos += bytesRead;
}
}
2024-12-09 09:41:04 +08:00
void AmsConnection::Receive(void* buffer, size_t bytesToRead, const Timepoint& deadline)
2024-12-03 10:36:06 +08:00
{
const auto now = std::chrono::steady_clock::now();
const auto usec = std::chrono::duration_cast<std::chrono::microseconds>(deadline - now).count();
if (usec <= 0) {
2024-12-09 09:41:04 +08:00
//throw Socket::TimeoutEx("deadline reached already!!!");
return;
2024-12-03 10:36:06 +08:00
}
timeval timeout {(long)(usec / 1000000), (int)(usec % 1000000)};
Receive(buffer, bytesToRead, &timeout);
}
2024-12-09 09:41:04 +08:00
void AmsConnection::ReceiveJunk(size_t bytesToRead)
2024-12-03 10:36:06 +08:00
{
uint8_t buffer[1024];
while (bytesToRead > sizeof(buffer)) {
Receive(buffer, sizeof(buffer));
bytesToRead -= sizeof(buffer);
}
Receive(buffer, bytesToRead);
}
template<class T>
2024-12-09 09:41:04 +08:00
void AmsConnection::ReceiveFrame(AmsResponse* const response, size_t bytesLeft, uint32_t aoeError)
2024-12-03 10:36:06 +08:00
{
AmsRequest* const request = response->request.load();
const auto responseId = response->invokeId.load();
T header;
if (aoeError) {
response->Notify(aoeError);
ReceiveJunk(bytesLeft);
return;
}
if (bytesLeft > sizeof(header) + request->bufferLength) {
LOG_WARN("Frame too long: " << std::dec << bytesLeft << '>' << sizeof(header) + request->bufferLength);
response->Notify(ADSERR_DEVICE_INVALIDSIZE);
ReceiveJunk(bytesLeft);
return;
}
try {
Receive(&header, sizeof(header), request->deadline);
bytesLeft -= sizeof(header);
Receive(request->buffer, bytesLeft, request->deadline);
if (request->bytesRead) {
// We already checked bytesLeft <= request->bufferLength
const auto v = static_cast<std::remove_pointer<decltype(AmsRequest::bytesRead)>::type>(bytesLeft);
*(request->bytesRead) = v;
}
response->Notify(header.result());
} catch (const Socket::TimeoutEx&) {
LOG_WARN("InvokeId of response: " << std::dec << responseId << " timed out");
response->Notify(ADSERR_CLIENT_SYNCTIMEOUT);
ReceiveJunk(bytesLeft);
}
}
/**
 * Copies the payload of a DEVICE_NOTIFICATION frame into the ring buffer of
 * the matching dispatcher and triggers the dispatcher.
 * @param header AoE header of the notification frame
 * @return true when the payload was stored and the dispatcher notified,
 *         false when no dispatcher exists or its ring buffer lacks space
 *         (the payload is discarded in both cases)
 */
bool AmsConnection::ReceiveNotification(const AoEHeader& header)
{
    const auto dispatcher = DispatcherListGet(VirtualConnection { header.targetPort(), header.sourceAms() });
    if (!dispatcher) {
        ReceiveJunk(header.length());
        LOG_WARN("No dispatcher found for notification");
        return false;
    }

    auto& ring = dispatcher->ring;
    auto remaining = header.length();
    // The length prefix written below needs room in the ring as well.
    if (remaining + sizeof(remaining) > ring.BytesFree()) {
        ReceiveJunk(remaining);
        LOG_WARN("port " << std::dec << header.targetPort() << " receive buffer was full");
        return false;
    }

    /** prefix the payload with AoEHeader.length(), least significant byte first, to support notification parsing */
    for (size_t shift = 0; shift < 8 * sizeof(remaining); shift += 8) {
        *ring.write = (remaining >> shift) & 0xFF;
        ring.Write(1);
    }

    // Fill the ring chunk by chunk as reported by WriteChunk().
    for (auto chunk = ring.WriteChunk(); remaining > chunk; chunk = ring.WriteChunk()) {
        Receive(ring.write, chunk);
        ring.Write(chunk);
        // remaining > chunk holds here; the cast only silences MSVC
        remaining -= static_cast<decltype(remaining)>(chunk);
    }
    Receive(ring.write, remaining);
    ring.Write(remaining);

    // Invoke the notify callback of the dispatcher.
    dispatcher->Notify();
    return true;
}
/**
 * Entry point of the receiver thread: runs Recv() and logs a
 * std::runtime_error that ends it instead of letting it escape the thread.
 */
void AmsConnection::TryRecv()
{
    try {
        Recv();
    } catch (const std::runtime_error& failure) {
        LOG_INFO(failure.what());
    }
}
void AmsConnection::Recv()
{
AmsTcpHeader amsTcpHeader;
AoEHeader aoeHeader;
2024-12-09 09:41:04 +08:00
for ( ; IsConnected() && localIp; ) {
2024-12-03 10:36:06 +08:00
Receive(amsTcpHeader);
if (amsTcpHeader.length() < sizeof(aoeHeader)) {
LOG_WARN("Frame to short to be AoE");
ReceiveJunk(amsTcpHeader.length());
continue;
}
Receive(aoeHeader);
if (aoeHeader.cmdId() == AoEHeader::DEVICE_NOTIFICATION) {
ReceiveNotification(aoeHeader);
continue;
}
auto response = GetPending(aoeHeader.invokeId(), aoeHeader.targetPort());
if (!response) {
2024-12-09 09:41:04 +08:00
//LOG_WARN("No response pending");
2024-12-03 10:36:06 +08:00
ReceiveJunk(aoeHeader.length());
continue;
}
switch (aoeHeader.cmdId()) {
case AoEHeader::READ_DEVICE_INFO:
case AoEHeader::WRITE:
case AoEHeader::READ_STATE:
case AoEHeader::WRITE_CONTROL:
case AoEHeader::ADD_DEVICE_NOTIFICATION:
case AoEHeader::DEL_DEVICE_NOTIFICATION:
ReceiveFrame<AoEResponseHeader>(response, aoeHeader.length(), aoeHeader.errorCode());
continue;
case AoEHeader::READ:
case AoEHeader::READ_WRITE:
ReceiveFrame<AoEReadResponseHeader>(response, aoeHeader.length(), aoeHeader.errorCode());
continue;
default:
LOG_WARN("Unkown AMS command id");
response->Notify(ADSERR_CLIENT_SYNCRESINVALID);
ReceiveJunk(aoeHeader.length());
}
}
}
2024-12-09 09:41:04 +08:00
}
}