map/das-dn/third_party/AdsLib/Standalone/AmsConnection.cpp
2024-12-10 10:14:40 +08:00

384 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "AmsConnection.h"
#include "Log.h"
namespace Beckhoff
{
namespace Ads
{
/**
 * Construct an idle response slot: no request attached, no invoke id pending
 * and the error code set to the WAITING_FOR_RESPONSE sentinel.
 *
 * invokeId is initialized explicitly because a default-constructed
 * std::atomic holds an indeterminate value before C++20, while Wait() and
 * AmsConnection::GetPending() compare this field against 0 / incoming ids.
 * NOTE(review): the initializer order assumes invokeId is declared between
 * request and errorCode in AmsConnection.h -- confirm to avoid -Wreorder.
 */
AmsResponse::AmsResponse()
    : request(nullptr),
    invokeId(0),
    errorCode(WAITING_FOR_RESPONSE)
{}
/**
 * Store the outcome of a request and wake every thread waiting on this
 * response's condition variable.
 * @param error value recorded as the response's error code
 */
void AmsResponse::Notify(const uint32_t error)
{
    // errorCode is updated while holding the mutex paired with cv
    std::lock_guard<std::mutex> guard(mutex);
    errorCode = error;
    cv.notify_all();
}
/**
 * Block the caller until the attached request has been answered or its
 * deadline has passed.
 * @return ADSERR_CLIENT_SYNCTIMEOUT when invokeId is still non-zero after the
 *         timed wait, otherwise the error code stored through Notify().
 */
uint32_t AmsResponse::Wait()
{
std::unique_lock<std::mutex> lock(mutex);
// Phase 1: sleep until the request's deadline, or until invokeId reads 0
// (AmsConnection::GetPending() resets it to 0 when a matching frame arrives).
cv.wait_until(lock, request.load()->deadline, [&]() { return !invokeId.load(); });
// exchange(0) both tests and clears the id, so a late frame can no longer claim this slot
if (invokeId.exchange(0)) {
/* invokeId was not consumed -> the receive path did not get a valid response in time */
return ADSERR_CLIENT_SYNCTIMEOUT;
}
/* Phase 2: the receive path already claimed this response and may still be writing into
   the user supplied buffer; wait (without timeout) until Notify() has stored a result */
cv.wait(lock, [&]() { return errorCode != WAITING_FOR_RESPONSE; });
return errorCode;
}
/**
 * Look up the notification dispatcher registered for @p connection and
 * create one when none exists yet.
 * @param connection key of the dispatcher list; its .first is forwarded as
 *        source port and its .second as AMS address to DeleteNotification()
 * @return the dispatcher stored in dispatcherList for this connection
 */
SharedDispatcher AmsConnection::DispatcherListAdd(const VirtualConnection& connection)
{
    // Fast path: a dispatcher is already registered for this connection.
    if (const auto existing = DispatcherListGet(connection)) {
        return existing;
    }

    // Slow path: build a dispatcher whose delete callback is bound to
    // AmsConnection::DeleteNotification for this very connection.
    std::lock_guard<std::recursive_mutex> guard(dispatcherListMutex);
    auto created = std::make_shared<NotificationDispatcher>(
        std::bind(&AmsConnection::DeleteNotification, this,
                  connection.second,
                  std::placeholders::_1, std::placeholders::_2,
                  connection.first));

    // emplace() leaves an already present entry untouched; in either case
    // the mapped value of the resulting iterator is what gets returned.
    const auto inserted = dispatcherList.emplace(connection, created);
    return inserted.first->second;
}
/**
 * Find the dispatcher registered for @p connection.
 * @return the stored dispatcher, or an empty SharedDispatcher when the
 *         connection is not in dispatcherList
 */
SharedDispatcher AmsConnection::DispatcherListGet(const VirtualConnection& connection)
{
    std::lock_guard<std::recursive_mutex> guard(dispatcherListMutex);

    const auto entry = dispatcherList.find(connection);
    if (entry == dispatcherList.end()) {
        return {};
    }
    return entry->second;
}
/**
 * Open a connection to @p destination. When the socket reports that it is
 * connected, the local/remote addresses are cached and the receiver thread
 * running TryRecv() is started; otherwise nothing further is set up.
 * @param __router      router this connection belongs to
 * @param destination   address information handed to the socket
 */
AmsConnection::AmsConnection(Router& __router, const struct addrinfo* const destination)
    : router(__router),
    socket(destination),
    refCount(0),
    invokeId(0)
{
    if (!socket.IsConnected()) {
        // no receiver thread is started for an unconnected socket
        return;
    }

    localIp = socket.GetLocalSockAddr();
    remoteIp = socket.GetHostSockAddr();
    receiver = std::thread(&AmsConnection::TryRecv, this);

    // remoteIp is passed through htonl() before printing
    // NOTE(review): implies remoteIp is kept in host byte order -- confirm
    struct in_addr peer{htonl(remoteIp)};
    LOG_INFO("Socket connect["<<std::string(inet_ntoa(peer))<<"] is done.");
}
/**
 * Shut the socket down and wait for the receiver thread to finish.
 *
 * The join is keyed on receiver.joinable() rather than on the socket state:
 * the thread is started in the constructor when the socket was connected,
 * but the connection may have been lost since. Destroying a still joinable
 * std::thread calls std::terminate(), so the thread must be joined whenever
 * it was started.
 */
AmsConnection::~AmsConnection()
{
    // Shutdown stays guarded by IsConnected(), matching the previous behavior
    if (socket.IsConnected()) {
        socket.Shutdown();
    }
    if (receiver.joinable()) {
        receiver.join();
    }
}
/**
 * Register @p notification under handle @p hNotify with the dispatcher that
 * serves the notification's connection (created on demand).
 * @param hNotify       notification handle, also stored into the notification
 * @param notification  notification object to register
 * @return the dispatcher the notification was added to
 */
SharedDispatcher AmsConnection::CreateNotifyMapping(uint32_t hNotify, std::shared_ptr<Notification> notification)
{
    SharedDispatcher owner = DispatcherListAdd(notification->connection);

    // the handle is written into the notification before it is published
    notification->hNotify(hNotify);
    owner->Emplace(hNotify, notification);

    return owner;
}
long AmsConnection::DeleteNotification(const AmsAddr& amsAddr, uint32_t hNotify, uint32_t tmms, uint16_t srcPort)
{
AmsRequest request {
amsAddr,
srcPort, AoEHeader::DEL_DEVICE_NOTIFICATION,
0, nullptr, nullptr,
sizeof(hNotify)
};
request.frame.prepend(Beckhoff::htole(hNotify));
return SendRequest(request, tmms);
}
/**
 * Ask the underlying socket whether it is connected to one of the given
 * target addresses.
 * @param targetAddresses address list forwarded unchanged to the socket
 */
bool AmsConnection::IsConnectedTo(const struct addrinfo* targetAddresses) const
{
    const bool sameTarget = socket.IsConnectedTo(targetAddresses);
    return sameTarget;
}
/**
 * @return true only when the socket is valid and reports being connected;
 *         IsConnected() of the socket is not consulted for an invalid socket.
 */
bool AmsConnection::IsConnected() const
{
    if (!socket.IsValid()) {
        return false;
    }
    return socket.IsConnected();
}
/**
 * Prepend the AoE and AMS/TCP headers to the request frame, reserve the
 * response slot of the source port and write the frame to the socket.
 * @param request  request whose frame is completed and sent
 * @param srcAddr  local AMS address used as source of the frame
 * @return the reserved response slot, or nullptr when the slot could not be
 *         reserved or the frame was not written completely
 */
AmsResponse* AmsConnection::Write(AmsRequest& request, const AmsAddr srcAddr)
{
    // AoE header: describes the payload currently held by the frame
    const AoEHeader aoeHeader {
        request.destAddr.netId, request.destAddr.port,
        srcAddr.netId, srcAddr.port,
        request.cmdId,
        static_cast<uint32_t>(request.frame.size()),
        GetInvokeId()
    };
    request.frame.prepend<AoEHeader>(aoeHeader);

    // AMS/TCP header: its length covers AoE header plus payload
    const AmsTcpHeader tcpHeader { static_cast<uint32_t>(request.frame.size()) };
    request.frame.prepend<AmsTcpHeader>(tcpHeader);

    AmsResponse* const slot = Reserve(&request, srcAddr.port);
    if (slot == nullptr) {
        return nullptr;
    }

    // publish the id before sending, so the receiver can match the answer
    slot->invokeId.store(aoeHeader.invokeId());

    const auto expected = request.frame.size();
    if (expected != socket.write(request.frame)) {
        // incomplete write: give the slot back
        slot->Release();
        return nullptr;
    }
    return slot;
}
/**
 * Send @p request and wait for its response.
 * @param request  request to send; its deadline is set from @p timeout
 * @param timeout  timeout value passed to AmsRequest::SetDeadline()
 * @return -1 when the connection is down or the request could not be
 *         written, the non-zero status of Router::GetAmsAddr() when the
 *         source address lookup fails, otherwise the result of
 *         AmsResponse::Wait()
 */
long AmsConnection::SendRequest(AmsRequest& request, const uint32_t timeout)
{
    if (!IsConnected()) {
        return -1;
    }

    AmsAddr srcAddr;
    const auto lookupStatus = router.GetAmsAddr(request.srcPort, &srcAddr);
    if (lookupStatus) {
        return lookupStatus;
    }

    request.SetDeadline(timeout);
    AmsResponse* const pending = Write(request, srcAddr);
    if (!pending) {
        return -1;
    }

    // the slot is released only after the waiter has picked up the result
    const auto result = pending->Wait();
    pending->Release();
    return result;
}
/**
 * Generate the next invoke id for an outgoing request.
 *
 * The id is derived from the value returned by a single fetch_add(), so
 * concurrent callers can never observe the same id (the former separate
 * increment-then-load could hand out duplicates). The value 0 is skipped,
 * because AmsResponse uses invokeId == 0 to mean "no response pending".
 * @return a non-zero id; previous counter value + 1, as before
 */
uint32_t AmsConnection::GetInvokeId()
{
    uint32_t id;
    do {
        id = invokeId.fetch_add(1) + 1;
    } while (id == 0);
    return id;
}
/**
 * Claim the response slot of @p srcPort if it is waiting for invoke id @p id.
 * On success the slot's invokeId is atomically reset to 0.
 * @param id       invoke id of the received frame
 * @param srcPort  port the frame is addressed to
 * @return the claimed slot, or nullptr when the port is out of range or the
 *         slot is not waiting for @p id
 */
AmsResponse* AmsConnection::GetPending(const uint32_t id, const uint16_t srcPort)
{
    const uint16_t portIndex = srcPort - Router::PORT_BASE;
    if (portIndex >= Router::NUM_PORTS_MAX) {
        LOG_WARN("Port 0x" << std::hex << srcPort << " is out of range");
        return nullptr;
    }

    AmsResponse& slot = queue[portIndex];
    // compare_exchange_strong() overwrites its first argument on failure,
    // hence the mutable copy of the id
    auto expected = id;
    const bool claimed = slot.invokeId.compare_exchange_strong(expected, 0);
    return claimed ? &slot : nullptr;
}
/**
 * Reserve the response slot of @p srcPort for @p request.
 * @param request  request to attach to the slot
 * @param srcPort  local port whose slot should be reserved
 * @return the reserved slot, or nullptr when the port is out of range or its
 *         slot is already occupied by another request
 */
AmsResponse* AmsConnection::Reserve(AmsRequest* request, const uint16_t srcPort)
{
    // Same range check as GetPending(): never index queue with an invalid port
    const uint16_t portIndex = srcPort - Router::PORT_BASE;
    if (portIndex >= Router::NUM_PORTS_MAX) {
        LOG_WARN("Port 0x" << std::hex << srcPort << " is out of range");
        return nullptr;
    }

    // A slot without an AmsRequest pointer is idle; atomically place the
    // request there. On failure isFree receives the current occupant.
    AmsRequest* isFree = nullptr;
    if (!queue[portIndex].request.compare_exchange_strong(isFree, request)) {
        LOG_WARN("Port: " << srcPort << " already in use as " << isFree);
        return nullptr;
    }
    return &queue[portIndex];
}
/**
 * Put this response slot back into its idle state so that it can be
 * reserved for the next request.
 */
void AmsResponse::Release()
{
// reset the result before the slot becomes available again
errorCode = WAITING_FOR_RESPONSE;
// a null request pointer is what AmsConnection::Reserve() tests for to detect a free slot
request.store(nullptr);
}
/**
 * Read up to @p bytesToRead bytes from the socket into @p buffer, issuing
 * as many socket reads as necessary.
 * Reading stops early, without any error indication, as soon as a read
 * returns 0 bytes.
 * @param buffer       destination memory
 * @param bytesToRead  number of bytes requested
 * @param timeout      passed unchanged to every socket read
 */
void AmsConnection::Receive(void* buffer, size_t bytesToRead, timeval* timeout)
{
    uint8_t* cursor = reinterpret_cast<uint8_t*>(buffer);
    for (size_t remaining = bytesToRead; remaining != 0; ) {
        const size_t received = socket.read(cursor, remaining, timeout);
        if (received == 0) {
            return;
        }
        cursor += received;
        remaining -= received;
    }
}
void AmsConnection::Receive(void* buffer, size_t bytesToRead, const Timepoint& deadline)
{
const auto now = std::chrono::steady_clock::now();
const auto usec = std::chrono::duration_cast<std::chrono::microseconds>(deadline - now).count();
if (usec <= 0) {
//throw Socket::TimeoutEx("deadline reached already!!!");
return;
}
timeval timeout {(long)(usec / 1000000), (int)(usec % 1000000)};
Receive(buffer, bytesToRead, &timeout);
}
/**
 * Read and discard @p bytesToRead bytes from the socket, using a 1024 byte
 * scratch buffer.
 * @param bytesToRead number of bytes to drop
 */
void AmsConnection::ReceiveJunk(size_t bytesToRead)
{
    uint8_t scratch[1024];

    // full scratch-sized pieces first ...
    for ( ; bytesToRead > sizeof(scratch); bytesToRead -= sizeof(scratch)) {
        Receive(scratch, sizeof(scratch));
    }
    // ... then the remainder, which now fits into the scratch buffer
    Receive(scratch, bytesToRead);
}
/**
 * Receive the body of a response frame: a response header of type T followed
 * by payload that is stored in the buffer of the pending request. The waiter
 * is informed through AmsResponse::Notify() in every case.
 *
 * Frames whose length does not fit are dropped: longer than header plus user
 * buffer, or shorter than the header itself. The latter check prevents
 * `bytesLeft -= sizeof(header)` from wrapping around, which would make the
 * payload read write far beyond the user supplied buffer.
 *
 * @tparam T        response header type, must provide result()
 * @param response  slot claimed for this frame
 * @param bytesLeft number of bytes of this frame still waiting in the socket
 * @param aoeError  error code from the AoE header; non-zero aborts the frame
 */
template<class T>
void AmsConnection::ReceiveFrame(AmsResponse* const response, size_t bytesLeft, uint32_t aoeError)
{
    AmsRequest* const request = response->request.load();
    const auto responseId = response->invokeId.load();
    T header;

    if (aoeError) {
        response->Notify(aoeError);
        ReceiveJunk(bytesLeft);
        return;
    }

    if (bytesLeft < sizeof(header)) {
        LOG_WARN("Frame too short: " << std::dec << bytesLeft << '<' << sizeof(header));
        response->Notify(ADSERR_DEVICE_INVALIDSIZE);
        ReceiveJunk(bytesLeft);
        return;
    }

    if (bytesLeft > sizeof(header) + request->bufferLength) {
        LOG_WARN("Frame too long: " << std::dec << bytesLeft << '>' << sizeof(header) + request->bufferLength);
        response->Notify(ADSERR_DEVICE_INVALIDSIZE);
        ReceiveJunk(bytesLeft);
        return;
    }

    try {
        Receive(&header, sizeof(header), request->deadline);
        bytesLeft -= sizeof(header);
        Receive(request->buffer, bytesLeft, request->deadline);

        if (request->bytesRead) {
            // We already checked bytesLeft <= request->bufferLength
            const auto v = static_cast<std::remove_pointer<decltype(AmsRequest::bytesRead)>::type>(bytesLeft);
            *(request->bytesRead) = v;
        }
        response->Notify(header.result());
    } catch (const Socket::TimeoutEx&) {
        LOG_WARN("InvokeId of response: " << std::dec << responseId << " timed out");
        response->Notify(ADSERR_CLIENT_SYNCTIMEOUT);
        ReceiveJunk(bytesLeft);
    }
}
/**
 * Receive a device notification frame and hand it to the dispatcher
 * registered for (target port, source AMS address) of @p header.
 * The frame is copied into the dispatcher's ring buffer, preceded by its
 * length stored byte-wise with the least significant byte first.
 * @param header AoE header of the notification frame
 * @return true when the frame was stored and the dispatcher notified, false
 *         when it had to be discarded (no dispatcher or not enough free
 *         space in the ring buffer)
 */
bool AmsConnection::ReceiveNotification(const AoEHeader& header)
{
    const auto target = DispatcherListGet(VirtualConnection { header.targetPort(), header.sourceAms() });
    if (!target) {
        ReceiveJunk(header.length());
        LOG_WARN("No dispatcher found for notification");
        return false;
    }

    auto& ring = target->ring;
    auto remaining = header.length();

    // the frame and its length prefix must both fit into the ring buffer
    if (remaining + sizeof(remaining) > ring.BytesFree()) {
        ReceiveJunk(remaining);
        LOG_WARN("port " << std::dec << header.targetPort() << " receive buffer was full");
        return false;
    }

    /** store AoEHeader.length() in ring buffer to support notification parsing */
    for (size_t byteIndex = 0; byteIndex < sizeof(remaining); ++byteIndex) {
        *ring.write = (remaining >> (8 * byteIndex)) & 0xFF;
        ring.Write(1);
    }

    // copy chunk-wise, as long as the frame exceeds what WriteChunk() reports
    for (auto chunk = ring.WriteChunk(); remaining > chunk; chunk = ring.WriteChunk()) {
        Receive(ring.write, chunk);
        ring.Write(chunk);
        // remaining > chunk holds here, the cast only silences MSVC
        remaining -= static_cast<decltype(remaining)>(chunk);
    }
    Receive(ring.write, remaining);
    ring.Write(remaining);

    // trigger the notify callback of the dispatcher
    target->Notify();
    return true;
}
/**
 * Entry point of the receiver thread: runs Recv() and logs the message of
 * any exception that ends it.
 *
 * All std::exception types are caught (std::runtime_error included), since
 * an exception escaping a thread's entry function calls std::terminate().
 */
void AmsConnection::TryRecv()
{
    try {
        Recv();
    } catch (const std::exception& e) {
        LOG_INFO(e.what());
    }
}
void AmsConnection::Recv()
{
AmsTcpHeader amsTcpHeader;
AoEHeader aoeHeader;
for ( ; IsConnected() && localIp; ) {
Receive(amsTcpHeader);
if (amsTcpHeader.length() < sizeof(aoeHeader)) {
LOG_WARN("Frame to short to be AoE");
ReceiveJunk(amsTcpHeader.length());
continue;
}
Receive(aoeHeader);
if (aoeHeader.cmdId() == AoEHeader::DEVICE_NOTIFICATION) {
ReceiveNotification(aoeHeader);
continue;
}
auto response = GetPending(aoeHeader.invokeId(), aoeHeader.targetPort());
if (!response) {
//LOG_WARN("No response pending");
ReceiveJunk(aoeHeader.length());
continue;
}
switch (aoeHeader.cmdId()) {
case AoEHeader::READ_DEVICE_INFO:
case AoEHeader::WRITE:
case AoEHeader::READ_STATE:
case AoEHeader::WRITE_CONTROL:
case AoEHeader::ADD_DEVICE_NOTIFICATION:
case AoEHeader::DEL_DEVICE_NOTIFICATION:
ReceiveFrame<AoEResponseHeader>(response, aoeHeader.length(), aoeHeader.errorCode());
continue;
case AoEHeader::READ:
case AoEHeader::READ_WRITE:
ReceiveFrame<AoEReadResponseHeader>(response, aoeHeader.length(), aoeHeader.errorCode());
continue;
default:
LOG_WARN("Unkown AMS command id");
response->Notify(ADSERR_CLIENT_SYNCRESINVALID);
ReceiveJunk(aoeHeader.length());
}
}
}
}
}