map/das-dn/third_party/AdsLib/Sockets.cpp

283 lines
8.1 KiB
C++
Raw Normal View History

2024-12-03 10:36:06 +08:00
// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "Sockets.h"
#include "Log.h"
#include <algorithm>
#include <cstring>
#include <exception>
#include <limits>
#include <sstream>
#include <system_error>
2024-12-05 11:04:47 +08:00
namespace bhf
{
namespace ads
2024-12-03 10:36:06 +08:00
{
/**
 * Splits the provided host string into host and port. If no port was found
 * in the host string, port is returned untouched acting as a default value.
 *
 * Examples:
 *   "192.168.0.1"      -> host "192.168.0.1", port unchanged
 *   "192.168.0.1:851"  -> host "192.168.0.1", port "851"
 *   "fe80::1"          -> host "fe80::1",     port unchanged (bare IPv6, no port)
 *   "[fe80::1]"        -> host "fe80::1",     port unchanged
 *   "[fe80::1]:851"    -> host "fe80::1",     port "851"
 *   ":851"             -> host "",            port "851"
 * A trailing colon ("host:") yields an empty port string.
 * Square brackets around the host are stripped.
 *
 * @param[in,out] host "host[:port]" on input; only the host part on return
 * @param[in,out] port default port on input; replaced by the parsed port if one is present
 */
static void ParseHostAndPort(std::string& host, std::string& port)
{
    if (host.empty()) {
        return;
    }
    auto split = host.rfind(':');
    if (host.find(':') != split) {
        // More than one colon -> IPv6. A port is only present when the last colon
        // comes after the closing bracket, e.g. "[::1]:851"; without a bracket the
        // last colon belongs to the address itself.
        const auto closingBracket = host.rfind(']');
        if ((closingBracket == std::string::npos) || (closingBracket > split)) {
            split = std::string::npos;
        }
    }
    if (split != std::string::npos) {
        port = host.substr(split + 1);
        host.resize(split);
    }
    // Remove IPv6 brackets. The host may be empty here (e.g. input ":851"), so
    // check before touching the first/last character to avoid undefined behaviour.
    if (!host.empty() && (host.back() == ']')) {
        host.pop_back();
    }
    if (!host.empty() && (host.front() == '[')) {
        host.erase(0, 1);
    }
}
2024-12-05 11:04:47 +08:00
/**
 * Resolves a "host[:port]" string into the list returned by getaddrinfo().
 *
 * @param hostPort    host name or address, optionally followed by ":port"
 *                    (IPv6 may be written as "[addr]:port")
 * @param defaultPort service/port used when hostPort carries no port
 * @return the resolved list; AddressList releases it with freeaddrinfo()
 * @throws std::runtime_error if getaddrinfo() reports an error
 *
 * No hints are passed to getaddrinfo(), so the results are not restricted by
 * address family or socket type.
 * NOTE(review): InitSocketLibrary() is not paired with a cleanup call here,
 * unlike getIpv4() - confirm this is intended.
 */
AddressList GetListOfAddresses(const std::string& hostPort, const std::string& defaultPort)
{
    // ParseHostAndPort() rewrites both strings in place, so work on copies.
    std::string host(hostPort);
    std::string service(defaultPort);
    ParseHostAndPort(host, service);

    InitSocketLibrary();

    struct addrinfo* results = nullptr;
    const int status = getaddrinfo(host.c_str(), service.c_str(), nullptr, &results);
    if (status != 0) {
        throw std::runtime_error("Invalid or unknown host: " + host);
    }
    return AddressList { results, [](struct addrinfo* p) { freeaddrinfo(p); }};
}
2024-12-05 11:04:47 +08:00
}
}
// getaddrinfo() hints used by getIpv4(): only IPv4 TCP results are requested.
// Every other field (flags, lengths, pointers) is left zero/null.
static const struct addrinfo addrinfo = [] {
    struct addrinfo hints {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    return hints;
}();
2024-12-03 10:36:06 +08:00
/**
 * Resolves an IPv4 address or host name to an IPv4 address in host byte order.
 *
 * Uses the file-static `addrinfo` hints (AF_INET / SOCK_STREAM / IPPROTO_TCP)
 * and returns the first result.
 *
 * @param addr dotted-quad address or resolvable host name
 * @return the IPv4 address, converted with ntohl() to host byte order
 * @throws std::runtime_error if getaddrinfo() fails
 */
uint32_t getIpv4(const std::string& addr)
{
    // Same convention as Socket::Socket(): InitSocketLibrary() returns 0 on success,
    // and only a successful initialisation may be balanced by WSACleanup().
    const bool wsaInitialized = !InitSocketLibrary();

    struct addrinfo* res = nullptr;
    const auto status = getaddrinfo(addr.c_str(), nullptr, &addrinfo, &res);
    if (status) {
        // Release the socket library before throwing, otherwise every failed lookup leaks it.
        if (wsaInitialized) {
            WSACleanup();
        }
        throw std::runtime_error("Invalid IPv4 address or unknown hostname: " + addr);
    }
    // The AF_INET hint guarantees ai_addr points to a sockaddr_in.
    const auto value = reinterpret_cast<const struct sockaddr_in*>(res->ai_addr)->sin_addr.s_addr;
    freeaddrinfo(res);
    if (wsaInitialized) {
        WSACleanup();
    }
    return ntohl(value);
}
/**
 * Builds an IpV4 from an address string or host name.
 * @throws std::runtime_error (from getIpv4) if the name cannot be resolved
 */
IpV4::IpV4(const std::string& addr)
    : value(getIpv4(addr))
{}

/**
 * Builds an IpV4 from a raw 32-bit value, stored as given. getIpv4() yields
 * host byte order, so callers presumably pass host order too - TODO confirm.
 */
IpV4::IpV4(uint32_t val)
    : value(val)
{}

/** Orders addresses by their numeric value (allows use as an ordered-container key). */
bool IpV4::operator<(const IpV4& ref) const
{
    return value < ref.value;
}

/** Two addresses are equal when their numeric values are equal. */
bool IpV4::operator==(const IpV4& ref) const
{
    return value == ref.value;
}
/**
 * Creates a socket for the first usable entry of a getaddrinfo() result list.
 *
 * For SOCK_STREAM each candidate is connected; the first one that connects is kept.
 * For other types (datagram) the candidate's address length is stored so that
 * write() can pass m_DestAddr/m_DestAddrLen to sendto(). In both cases the chosen
 * address is copied into m_SockAddress.
 *
 * @param host linked list of candidate addresses (as returned by getaddrinfo())
 * @param type socket type passed to socket(), e.g. SOCK_STREAM or SOCK_DGRAM
 * @throws std::system_error carrying the error of the last failed candidate
 */
Socket::Socket(const struct addrinfo* const host, const int type)
    : m_WSAInitialized(!InitSocketLibrary()),
    m_DestAddr(SOCK_DGRAM == type ? reinterpret_cast<const struct sockaddr*>(&m_SockAddress) : nullptr),
    m_DestAddrLen(0)
{
    // Capture the error right where it happens: the logging and closesocket()
    // calls below may overwrite errno / WSAGetLastError() before we throw.
    int lastError = 0;
    for (auto rp = host; rp; rp = rp->ai_next) {
        m_Socket = socket(rp->ai_family, type, 0);
        if (INVALID_SOCKET == m_Socket) {
            lastError = WSAGetLastError();
            continue;
        }
        if (SOCK_STREAM == type) {
            if (::connect(m_Socket, rp->ai_addr, rp->ai_addrlen)) {
                lastError = WSAGetLastError();
                LOG_WARN("Socket(): connect failed");
                closesocket(m_Socket);
                m_Socket = INVALID_SOCKET;
                continue;
            }
        } else { /*if (SOCK_DGRAM == type)*/
#if defined(_WIN32) || defined(__CYGWIN__)
            // MSVC on Windows is the only platform using different types for connect() and ai_addrlen ...
            m_DestAddrLen = static_cast<decltype(m_DestAddrLen)>(rp->ai_addrlen);
#else
            m_DestAddrLen = rp->ai_addrlen;
#endif
        }
        memcpy(&m_SockAddress, rp->ai_addr, std::min<size_t>(sizeof(m_SockAddress), rp->ai_addrlen));
        return;
    }
    if (0 == lastError) {
        // No candidate was attempted (empty list): report whatever the library holds.
        lastError = WSAGetLastError();
    }
    LOG_ERROR("Unable to create socket");
    // A throwing constructor never runs the destructor, so balance
    // InitSocketLibrary() here instead of relying on ~Socket().
    if (m_WSAInitialized) {
        WSACleanup();
    }
    throw std::system_error(lastError, std::system_category());
}
/**
 * Shuts down both directions of the socket and closes it. If the constructor's
 * InitSocketLibrary() call succeeded, the socket library is released as well.
 */
Socket::~Socket()
{
    Shutdown();
    closesocket(m_Socket);
    if (!m_WSAInitialized) {
        return;
    }
    WSACleanup();
}
void Socket::Shutdown()
{
2024-12-05 11:04:47 +08:00
shutdown(m_Socket, SHUT_RDWR);
2024-12-03 10:36:06 +08:00
}
2024-12-05 11:04:47 +08:00
size_t Socket::read(uint8_t* buffer, size_t maxBytes, timeval* timeout) const
2024-12-03 10:36:06 +08:00
{
if (!Select(timeout)) {
return 0;
}
const auto msvcMaxBytes = static_cast<int>(std::min<size_t>(std::numeric_limits<int>::max(), maxBytes));
const int bytesRead = recv(m_Socket, reinterpret_cast<char*>(buffer), msvcMaxBytes, 0);
if (bytesRead > 0) {
return bytesRead;
}
2024-12-05 11:04:47 +08:00
const auto lastError = WSAGetLastError();
if ((0 == bytesRead) || (lastError == CONNECTION_CLOSED) || (lastError == CONNECTION_ABORTED)) {
throw std::runtime_error("connection closed by remote");
2024-12-03 10:36:06 +08:00
} else {
2024-12-05 11:04:47 +08:00
LOG_ERROR("read frame failed with error: " << std::dec << std::strerror(lastError));
2024-12-03 10:36:06 +08:00
}
return 0;
}
2024-12-05 11:04:47 +08:00
/**
 * Reads into the whole capacity of frame.
 * On success the frame is limited to the number of received bytes;
 * if nothing was received it is cleared.
 * @return the same frame, for chaining
 */
Frame& Socket::read(Frame& frame, timeval* timeout) const
{
    const size_t received = read(frame.rawData(), frame.capacity(), timeout);
    return received ? frame.limit(received) : frame.clear();
}
2024-12-05 11:04:47 +08:00
bool Socket::Select(timeval* timeout) const
2024-12-03 10:36:06 +08:00
{
/* prepare socket set for select() */
fd_set readSockets;
FD_ZERO(&readSockets);
FD_SET(m_Socket, &readSockets);
/* wait for receive data */
const int state = NATIVE_SELECT(m_Socket + 1, &readSockets, nullptr, nullptr, timeout);
if (0 == state) {
2024-12-05 11:04:47 +08:00
LOG_ERROR("select() timeout");
throw TimeoutEx("select() timeout");
2024-12-03 10:36:06 +08:00
}
2024-12-05 11:04:47 +08:00
const auto lastError = WSAGetLastError();
if (lastError == WSAENOTSOCK) {
throw std::runtime_error("connection closed");
2024-12-03 10:36:06 +08:00
}
/* and check if socket was correct */
2024-12-05 11:04:47 +08:00
if ((1 != state) || (!FD_ISSET(m_Socket, &readSockets))) {
LOG_ERROR("something strange happen while waiting for socket in state: " <<
state << " with error: " << std::strerror(lastError));
2024-12-03 10:36:06 +08:00
return false;
}
return true;
}
2024-12-05 11:04:47 +08:00
size_t Socket::write(const Frame& frame) const
2024-12-03 10:36:06 +08:00
{
if (frame.size() > std::numeric_limits<int>::max()) {
2024-12-05 11:04:47 +08:00
LOG_ERROR("frame length: " << frame.size() << " exceeds maximum length for sockets");
2024-12-03 10:36:06 +08:00
return 0;
}
2024-12-05 20:23:37 +08:00
if (m_Socket == INVALID_SOCKET) return 0;
if (m_DestAddr == 0) return 0;
if (m_DestAddrLen <= 0) return 0;
2024-12-03 10:36:06 +08:00
const int bufferLength = static_cast<int>(frame.size());
const char* const buffer = reinterpret_cast<const char*>(frame.data());
2024-12-05 11:04:47 +08:00
const int status = sendto(m_Socket, buffer, bufferLength, 0, m_DestAddr, m_DestAddrLen);
2024-12-03 10:36:06 +08:00
2024-12-05 11:04:47 +08:00
if (SOCKET_ERROR == status) {
LOG_ERROR("write frame failed with error: " << std::strerror(WSAGetLastError()));
2024-12-03 10:36:06 +08:00
return 0;
}
2024-12-05 11:04:47 +08:00
return status;
2024-12-03 10:36:06 +08:00
}
2024-12-05 11:04:47 +08:00
/**
 * Connects a TCP socket to the first reachable entry of host (see Socket::Socket)
 * and enables TCP_NODELAY on it. A failure to set the option is logged but not fatal.
 */
TcpSocket::TcpSocket(const struct addrinfo* const host)
    : Socket(host, SOCK_STREAM)
{
    // AdsDll.lib seems to use TCP_NODELAY, we use it to be compatible.
    // A non-zero value enables the option (disables Nagle's algorithm); 0 would
    // merely restore the default, contradicting this comment and the warning below.
    const int enable = 1;
    if (setsockopt(m_Socket, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&enable), sizeof(enable))) {
        LOG_WARN("Enabling TCP_NODELAY failed");
    }
}
2024-12-05 11:04:47 +08:00
uint32_t TcpSocket::Connect() const
2024-12-03 10:36:06 +08:00
{
struct sockaddr_storage source;
socklen_t len = sizeof(source);
if (getsockname(m_Socket, reinterpret_cast<sockaddr*>(&source), &len)) {
2024-12-05 11:04:47 +08:00
LOG_ERROR("Read local tcp/ip address failed");
throw std::runtime_error("Read local tcp/ip address failed");
2024-12-03 10:36:06 +08:00
}
switch (source.ss_family) {
case AF_INET:
return ntohl(reinterpret_cast<sockaddr_in*>(&source)->sin_addr.s_addr);
case AF_INET6:
return 0xffffffff;
default:
return 0;
}
}
2024-12-05 11:04:47 +08:00
bool TcpSocket::IsConnectedTo(const struct addrinfo* const targetAddresses) const
2024-12-03 10:36:06 +08:00
{
2024-12-05 11:04:47 +08:00
for (auto rp = targetAddresses; rp; rp = rp->ai_next) {
if (m_SockAddress.ss_family == rp->ai_family) {
if (!memcmp(&m_SockAddress, rp->ai_addr, std::min<size_t>(sizeof(m_SockAddress), rp->ai_addrlen))) {
return true;
}
}
2024-12-03 10:36:06 +08:00
}
2024-12-05 11:04:47 +08:00
return false;
2024-12-03 10:36:06 +08:00
}
UdpSocket::UdpSocket(const struct addrinfo* const host)
: Socket(host, SOCK_DGRAM)
{}