// SPDX-License-Identifier: MIT
/**
   Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
 */
#include "AmsConnection.h"
#include "Log.h"

// NOTE(review): this copy of the file appears to have been damaged in extraction.
// - Text between angle brackets is missing, e.g. `std::make_shared(...)`,
//   `static_cast(...)`, `reinterpret_cast(buffer)` and `template void ...ReceiveFrame`.
// - A span after the "Socket connect[" log message in the constructor is lost.
// As shown it will not compile; restore it from version control before building.

namespace Beckhoff
{
namespace Ads
{
// Creates an idle response slot: no request attached and errorCode set to
// WAITING_FOR_RESPONSE.
AmsResponse::AmsResponse()
    : request(nullptr),
    errorCode(WAITING_FOR_RESPONSE)
{}

// Publishes the final result for the request held in this slot and wakes every
// thread blocked in Wait(). Called from the receive path (ReceiveFrame() and
// Recv() further down).
// error: result code handed back to the waiting requester.
void AmsResponse::Notify(const uint32_t error)
{
    std::unique_lock lock(mutex);
    errorCode = error;
    cv.notify_all();
}

// Blocks the requesting thread until its reply has been handled.
//
// Phase 1 waits until either of these happens:
//   - the receive thread claims this slot (GetPending() resets invokeId to 0), or
//   - the request's deadline passes.
// If invokeId is still set afterwards, it is cleared here, so a late reply can no
// longer be matched, and ADSERR_CLIENT_SYNCTIMEOUT is returned.
// Phase 2 runs when the receive thread has already claimed the slot. It waits,
// without any deadline, until Notify() stores a result.
// Returns the code passed to Notify(), or ADSERR_CLIENT_SYNCTIMEOUT.
//
// NOTE(review): phase 2 has no timeout. It relies on every path that claims a slot
// eventually calling Notify(). ReceiveFrame() only catches Socket::TimeoutEx, so
// any other exception thrown while reading the payload skips Notify() and would
// leave this call blocked forever.
uint32_t AmsResponse::Wait()
{
    std::unique_lock lock(mutex);
    cv.wait_until(lock, request.load()->deadline, [&]() { return !invokeId.load(); });

    if (invokeId.exchange(0)) {
        /* invokeId was not consumed -> AmsConnection::Recv() has not received a matching response in time */
        return ADSERR_CLIENT_SYNCTIMEOUT;
    }

    /* AmsConnection::Recv() is currently processing the response and writing into the user-supplied buffer; wait until it has finished */
    cv.wait(lock, [&]() { return errorCode != WAITING_FOR_RESPONSE; });
    return errorCode;
}

// Returns the notification dispatcher for `connection`, creating and registering
// one if none exists yet.
//
// A new dispatcher receives a bound AmsConnection::DeleteNotification callback.
// The bind pre-fills:
//   - amsAddr = connection.second
//   - srcPort = connection.first
// Whoever invokes the callback supplies hNotify (_1) and tmms (_2).
SharedDispatcher AmsConnection::DispatcherListAdd(const VirtualConnection& connection)
{
    const auto dispatcher = DispatcherListGet(connection);
    if (dispatcher) {
        return dispatcher;
    }

    // Add a new dispatcher for this connection; it carries a DeleteNotification
    // callback. The lookup above and this insert take the lock separately, so
    // another thread may insert the same key in between. emplace() then keeps the
    // existing entry (assuming map-like semantics for dispatcherList), and that
    // entry is the one returned.
    std::lock_guard lock(dispatcherListMutex);
    auto shared_disp = std::make_shared(std::bind(&AmsConnection::DeleteNotification, this, connection.second, std::placeholders::_1, std::placeholders::_2, connection.first));
    return dispatcherList.emplace(connection, shared_disp).first->second;
}

// Returns the dispatcher registered for `connection`, or an empty
// SharedDispatcher if there is none.
SharedDispatcher AmsConnection::DispatcherListGet(const VirtualConnection& connection)
{
    std::lock_guard lock(dispatcherListMutex);
    const auto it = dispatcherList.find(connection);
    if (it != dispatcherList.end()) {
        return it->second;
    }
    return {};
}

// Creates a connection whose socket is constructed from `destination`.
// Only if the socket reports connected does it:
//   - record the local and remote addresses, and
//   - start the receive thread running TryRecv().
AmsConnection::AmsConnection(Router& __router, const struct addrinfo* const destination)
    : router(__router),
    socket(destination),
    refCount(0),
    invokeId(0)
{
    if(socket.IsConnected()) {
        localIp = socket.GetLocalSockAddr();
        remoteIp = socket.GetHostSockAddr();
        receiver = std::thread(&AmsConnection::TryRecv, this);
        struct in_addr ss{htonl(remoteIp)};
        // NOTE(review): source text is missing from here on. Lost are the rest of
        // this log statement, the rest of the constructor (including its closing
        // braces), and the signature of the following function. That function's
        // body uses `hNotify` and `notification`: it attaches `notification` to the
        // dispatcher for `notification->connection` and returns that dispatcher.
        LOG_INFO("Socket connect["< notification) {
    auto dispatcher = DispatcherListAdd(notification->connection);
    notification->hNotify(hNotify);
    dispatcher->Emplace(hNotify, notification);
    return dispatcher;
}

// Asks the device at `amsAddr` to delete notification `hNotify`, sending from
// local port `srcPort`.
// The handle is converted with Beckhoff::htole and prepended to the frame as the
// request payload.
// tmms: timeout forwarded to SendRequest() (presumably milliseconds — confirm).
// Returns the result of SendRequest().
long AmsConnection::DeleteNotification(const AmsAddr& amsAddr, uint32_t hNotify, uint32_t tmms, uint16_t srcPort)
{
    AmsRequest request {
        amsAddr,
        srcPort,
        AoEHeader::DEL_DEVICE_NOTIFICATION,
        0, nullptr, nullptr,
        sizeof(hNotify)
    };
    request.frame.prepend(Beckhoff::htole(hNotify));
    return SendRequest(request, tmms);
}

// True if the underlying socket reports being connected to one of
// `targetAddresses`.
bool AmsConnection::IsConnectedTo(const struct addrinfo* targetAddresses) const
{
    return socket.IsConnectedTo(targetAddresses);
}

// True only if the socket is valid and reports being connected.
bool AmsConnection::IsConnected() const
{
    if(socket.IsValid())
        return socket.IsConnected();
    else
        return false;
}

// Frames `request` and writes it to the socket.
//
// Steps:
//   1. Prepend the AoE header, which carries a new invoke id from GetInvokeId().
//   2. Prepend the AMS/TCP header. Its length covers the AoE header plus payload.
//   3. Reserve the response slot for srcAddr.port and arm it with the invoke id.
// Returns the reserved slot, or nullptr in either of these cases:
//   - the port's slot is already in use;
//   - the socket wrote fewer bytes than the frame holds (the slot is released
//     again first).
AmsResponse* AmsConnection::Write(AmsRequest& request, const AmsAddr srcAddr)
{
    const AoEHeader aoeHeader { request.destAddr.netId, request.destAddr.port, srcAddr.netId, srcAddr.port, request.cmdId, static_cast(request.frame.size()), GetInvokeId() };
    request.frame.prepend(aoeHeader);

    const AmsTcpHeader header { static_cast(request.frame.size()) };
    request.frame.prepend(header);

    auto response = Reserve(&request, srcAddr.port);
    if (!response) {
        return nullptr;
    }

    // Arm the slot before sending. The receive thread may see the reply as soon as
    // the frame is on the wire, and GetPending() only matches an armed invoke id.
    response->invokeId.store(aoeHeader.invokeId());
    if (request.frame.size() != socket.write(request.frame)) {
        response->Release();
        return nullptr;
    }
    return response;
}

// Sends `request` and blocks until its reply is processed or its deadline passes.
// timeout: passed to AmsRequest::SetDeadline() (presumably milliseconds — confirm).
// Returns:
//   - -1 if not connected, or if Write() could not reserve the slot or send;
//   - the router's status if no AMS address is known for request.srcPort;
//   - otherwise the code returned by AmsResponse::Wait().
long AmsConnection::SendRequest(AmsRequest& request, const uint32_t timeout)
{
    if(IsConnected() == false)
        return -1;
    AmsAddr srcAddr;
    const auto status = router.GetAmsAddr(request.srcPort, &srcAddr);
    if (status) {
        return status;
    }
    request.SetDeadline(timeout);
    AmsResponse* response = Write(request, srcAddr);
    if (response) {
        const auto errorCode = response->Wait();
        // Free the port's slot so the next request on this port can Reserve() it.
        response->Release();
        return errorCode;
    }
    return -1;
}

// Returns the invoke id to use for a new request.
//
// NOTE(review): two defects to fix here:
//  - The increment and the read below are two separate atomic operations. Two
//    threads calling concurrently can both read the value left by the second
//    increment and send the same id. Use the value returned by fetch_add instead
//    (`invokeId.fetch_add(1) + 1`).
//  - If invokeId is a 32-bit unsigned counter (its type is declared elsewhere),
//    it eventually wraps to 0. But 0 means "no pending request" to GetPending()
//    and AmsResponse::Wait(). Loop until a nonzero id is obtained.
uint32_t AmsConnection::GetInvokeId()
{
    invokeId.fetch_add(1);
    return invokeId;
}

// Called by the receive thread for each reply. Claims the response slot of local
// port `srcPort` if that slot is armed with invoke id `id`.
//
// The claim is an atomic compare-exchange that resets the slot's invokeId to 0.
// This tells AmsResponse::Wait() the reply is being processed.
// Returns the claimed slot, or nullptr if the port is out of range or the slot is
// not waiting for `id` (e.g. a reply that arrives after Wait() timed out).
AmsResponse* AmsConnection::GetPending(const uint32_t id, const uint16_t srcPort)
{
    // Ports below PORT_BASE wrap around in this uint16_t conversion. They are then
    // rejected by the range check, given PORT_BASE + NUM_PORTS_MAX fits in 16 bits.
    const uint16_t portIndex = srcPort - Router::PORT_BASE;
    if (portIndex >= Router::NUM_PORTS_MAX) {
        LOG_WARN("Port 0x" << std::hex << srcPort << " is out of range");
        return nullptr;
    }

    auto currentId = id;
    if (queue[portIndex].invokeId.compare_exchange_strong(currentId, 0)) {
        return &queue[portIndex];
    }
    return nullptr;
}

// Claims the response slot of local port `srcPort` for `request`.
// Returns the slot, or nullptr if another request already occupies that port.
//
// NOTE(review): unlike GetPending(), this does not range-check the index
// srcPort - Router::PORT_BASE before indexing `queue`. SendRequest() validates the
// port through router.GetAmsAddr() first, presumably — confirm for other callers.
AmsResponse* AmsConnection::Reserve(AmsRequest* request, const uint16_t srcPort)
{
    // A null request pointer means the port's slot is currently idle, so atomically
    // store `request` there to claim it. On failure, isFree receives the pointer
    // of the request that occupies the slot.
    AmsRequest* isFree = nullptr;
    if (!queue[srcPort - Router::PORT_BASE].request.compare_exchange_strong(isFree, request)) {
        LOG_WARN("Port: " << srcPort << " already in use as " << isFree);
        return nullptr;
    }
    return &queue[srcPort - Router::PORT_BASE];
}

// Returns this slot to the idle state so Reserve() can hand it out again.
// It resets errorCode and clears the request pointer; invokeId is not touched.
void AmsResponse::Release()
{
    errorCode = WAITING_FOR_RESPONSE;
    request.store(nullptr);
}

// Reads up to `bytesToRead` bytes into `buffer`, looping over partial reads.
//
// NOTE(review): the loop stops silently when socket.read() returns 0. Callers
// therefore cannot tell a short read from a complete one, and callers such as
// ReceiveFrame() treat the buffer as fully filled afterwards.
void AmsConnection::Receive(void* buffer, size_t bytesToRead, timeval* timeout)
{
    auto pos = reinterpret_cast(buffer);
    while (bytesToRead) {
        const size_t bytesRead = socket.read(pos, bytesToRead, timeout);
        if (bytesRead == 0)
            break;
        bytesToRead -= bytesRead;
        pos += bytesRead;
    }
}

// Reads `bytesToRead` bytes, giving up at the absolute `deadline`. The remaining
// time is converted to a relative timeval for the overload above.
//
// NOTE(review): if the deadline has already passed, this returns without reading
// and without reporting anything, because the TimeoutEx throw below is commented
// out. As a result:
//   - ReceiveFrame() then calls Notify(header.result()) on a header it never read;
//   - the frame's bytes stay in the stream, so Recv() parses its next header from
//     the middle of this frame.
// Consider restoring the throw; ReceiveFrame() already catches Socket::TimeoutEx.
void AmsConnection::Receive(void* buffer, size_t bytesToRead, const Timepoint& deadline)
{
    const auto now = std::chrono::steady_clock::now();
    const auto usec = std::chrono::duration_cast(deadline - now).count();
    if (usec <= 0) {
        //throw Socket::TimeoutEx("deadline reached already!!!");
        return;
    }

    timeval timeout {(long)(usec / 1000000), (int)(usec % 1000000)};
    Receive(buffer, bytesToRead, &timeout);
}

// Reads and discards `bytesToRead` bytes, at most 1024 at a time, so that the
// stream is positioned at the next frame boundary.
void AmsConnection::ReceiveJunk(size_t bytesToRead)
{
    uint8_t buffer[1024];
    while (bytesToRead > sizeof(buffer)) {
        Receive(buffer, sizeof(buffer));
        bytesToRead -= sizeof(buffer);
    }
    Receive(buffer, bytesToRead);
}

// Completes the pending request in `response` from the reply payload.
//
// T is the command-specific response header that precedes the payload.
// On success:
//   - the payload is written to request->buffer;
//   - its length is written to *request->bytesRead, if that pointer is set;
//   - header.result() is passed to Notify().
// AoE-level errors, oversized frames and timeouts are reported through Notify()
// instead, and the rest of the frame is discarded.
// bytesLeft: AoE data length of the reply (Recv() passes aoeHeader.length()).
// aoeError:  error code from the AoE header; nonzero skips the payload entirely.
//
// NOTE(review): a frame shorter than sizeof(header) is not rejected. In that case:
//   - the header read consumes bytes of the next frame;
//   - `bytesLeft -= sizeof(header)` wraps around, so the following payload read
//     has no real bound and can overrun request->buffer.
// Reject bytesLeft < sizeof(header) the same way as the "too long" case.
// NOTE(review): GetPending() has already reset invokeId to 0 before this runs, so
// `responseId`, which is used only in the timeout log, is always 0.
template void AmsConnection::ReceiveFrame(AmsResponse* const response, size_t bytesLeft, uint32_t aoeError)
{
    AmsRequest* const request = response->request.load();
    const auto responseId = response->invokeId.load();
    T header;

    if (aoeError) {
        response->Notify(aoeError);
        ReceiveJunk(bytesLeft);
        return;
    }

    if (bytesLeft > sizeof(header) + request->bufferLength) {
        LOG_WARN("Frame too long: " << std::dec << bytesLeft << '>' << sizeof(header) + request->bufferLength);
        response->Notify(ADSERR_DEVICE_INVALIDSIZE);
        ReceiveJunk(bytesLeft);
        return;
    }

    try {
        Receive(&header, sizeof(header), request->deadline);
        bytesLeft -= sizeof(header);
        Receive(request->buffer, bytesLeft, request->deadline);

        if (request->bytesRead) {
            // We already checked bytesLeft <= request->bufferLength
            const auto v = static_cast::type>(bytesLeft);
            *(request->bytesRead) = v;
        }
        response->Notify(header.result());
    } catch (const Socket::TimeoutEx&) {
        LOG_WARN("InvokeId of response: " << std::dec << responseId << " timed out");
        response->Notify(ADSERR_CLIENT_SYNCTIMEOUT);
        ReceiveJunk(bytesLeft);
    }
}

// Moves one DEVICE_NOTIFICATION payload from the socket into the ring buffer of
// the dispatcher registered for (target port, source AMS address), then signals
// that dispatcher.
//
// The payload is prefixed with its length, written as sizeof(bytesLeft) bytes,
// least significant byte first, so the consumer can split the ring back into
// individual notifications.
// Returns true if the payload was queued. Returns false if no dispatcher exists or
// the ring has too little free space; the payload is discarded in both cases.
bool AmsConnection::ReceiveNotification(const AoEHeader& header)
{
    const auto dispatcher = DispatcherListGet(VirtualConnection { header.targetPort(), header.sourceAms() });
    if (!dispatcher) {
        ReceiveJunk(header.length());
        LOG_WARN("No dispatcher found for notification");
        return false;
    }

    auto& ring = dispatcher->ring;
    auto bytesLeft = header.length();
    if (bytesLeft + sizeof(bytesLeft) > ring.BytesFree()) {
        ReceiveJunk(bytesLeft);
        LOG_WARN("port " << std::dec << header.targetPort() << " receive buffer was full");
        return false;
    }

    /** store AoEHeader.length() in ring buffer to support notification parsing */
    for (size_t i = 0; i < sizeof(bytesLeft); ++i) {
        *ring.write = (bytesLeft >> (8 * i)) & 0xFF;
        ring.Write(1);
    }

    // Copy the payload in contiguous pieces. WriteChunk() presumably returns the
    // space left before the ring's write position wraps — confirm.
    auto chunk = ring.WriteChunk();
    while (bytesLeft > chunk) {
        Receive(ring.write, chunk);
        ring.Write(chunk);
        // We already checked bytesLeft > chunk, well it was not obvious enough for MSVC
        bytesLeft -= static_cast(chunk);
        chunk = ring.WriteChunk();
    }
    Receive(ring.write, bytesLeft);
    ring.Write(bytesLeft);
    // Call the dispatcher's notify callback.
    dispatcher->Notify();
    return true;
}

// Entry point of the receive thread started by the constructor. Runs Recv() and
// logs the message of any std::runtime_error that ends it.
//
// NOTE(review): an exception not derived from std::runtime_error escapes the
// thread function, and std::thread then calls std::terminate().
void AmsConnection::TryRecv()
{
    try {
        Recv();
    } catch (const std::runtime_error& e) {
        LOG_INFO(e.what());
    }
}

// Receive loop. Runs while the socket is connected and localIp is set. For each
// frame it reads one AMS/TCP header and one AoE header, then:
//  - queues DEVICE_NOTIFICATION frames via ReceiveNotification();
//  - hands replies to the pending request matching (invoke id, target port);
//  - discards frames that are too short or have no matching pending request.
//
// NOTE(review): aoeHeader.length() is trusted as the payload size. It is never
// cross-checked against amsTcpHeader.length() - sizeof(aoeHeader).
void AmsConnection::Recv()
{
    AmsTcpHeader amsTcpHeader;
    AoEHeader aoeHeader;
    for ( ; IsConnected() && localIp; ) {
        Receive(amsTcpHeader);
        if (amsTcpHeader.length() < sizeof(aoeHeader)) {
            LOG_WARN("Frame to short to be AoE");
            ReceiveJunk(amsTcpHeader.length());
            continue;
        }

        Receive(aoeHeader);
        if (aoeHeader.cmdId() == AoEHeader::DEVICE_NOTIFICATION) {
            ReceiveNotification(aoeHeader);
            continue;
        }

        auto response = GetPending(aoeHeader.invokeId(), aoeHeader.targetPort());
        if (!response) {
            //LOG_WARN("No response pending");
            ReceiveJunk(aoeHeader.length());
            continue;
        }

        // NOTE(review): ReceiveFrame's T cannot be deduced from these arguments.
        // The two case groups below must have passed explicit (presumably
        // different) response-header template arguments, and those are missing in
        // this copy.
        switch (aoeHeader.cmdId()) {
        case AoEHeader::READ_DEVICE_INFO:
        case AoEHeader::WRITE:
        case AoEHeader::READ_STATE:
        case AoEHeader::WRITE_CONTROL:
        case AoEHeader::ADD_DEVICE_NOTIFICATION:
        case AoEHeader::DEL_DEVICE_NOTIFICATION:
            ReceiveFrame(response, aoeHeader.length(), aoeHeader.errorCode());
            continue;

        case AoEHeader::READ:
        case AoEHeader::READ_WRITE:
            ReceiveFrame(response, aoeHeader.length(), aoeHeader.errorCode());
            continue;

        default:
            LOG_WARN("Unkown AMS command id");
            response->Notify(ADSERR_CLIENT_SYNCRESINVALID);
            ReceiveJunk(aoeHeader.length());
        }
    }
}
}
}