map/das-dn/third_party/AdsLib/Sockets.cpp
2024-12-26 14:54:03 +08:00

354 lines
9.8 KiB
C++

// SPDX-License-Identifier: MIT
/**
Copyright (c) 2015 - 2022 Beckhoff Automation GmbH & Co. KG
*/
#include "Sockets.h"
#include "Log.h"
#include <algorithm>
#include <cstring>
#include <exception>
#include <limits>
#include <sstream>
#include <system_error>
#include <iostream>
static const struct addrinfo addrinfo = []() {
struct addrinfo a;
memset(&a, 0, sizeof(a));
a.ai_family = AF_INET;
a.ai_socktype = SOCK_STREAM;
a.ai_protocol = IPPROTO_TCP;
return a;
} ();
namespace Beckhoff
{
/**
 * Splits the provided host string into host and port. If no port was found
 * in the host string, port is returned untouched acting as a default value.
 *
 * Accepted forms are "host", "host:port", a bare IPv6 address, "[ipv6]" and
 * "[ipv6]:port". Surrounding square brackets are stripped from the host.
 *
 * @param host in: combined string, out: host part only (may become empty,
 *             e.g. for ":1234")
 * @param port in: default port, out: port taken from host if one was present
 */
static void ParseHostAndPort(std::string& host, std::string& port)
{
    if (host.empty()) {
        return;
    }

    auto split = host.find_last_of(":");
    if (host.find_first_of(":") != split) {
        // more than one colon -> IPv6
        const auto closingBracket = host.find_last_of("]");
        if (closingBracket > split) {
            // IPv6 without port: either no bracket at all (npos compares
            // greater) or the last colon lies inside the brackets
            split = host.npos;
        }
    }

    if (split != host.npos) {
        port = host.substr(split + 1);
        host.resize(split);
    }

    // remove brackets; the host may be empty by now (":1234") or become
    // empty after dropping "]", so check before touching back()/front()
    if (!host.empty() && (host.back() == ']')) {
        host.pop_back();
    }
    if (!host.empty() && (host.front() == '[')) {
        host.erase(host.begin());
    }
}
/**
 * Resolves host (optionally carrying a ":port" suffix) through getaddrinfo().
 *
 * @param host  host name or address, split by ParseHostAndPort()
 * @param port  default service/port used when host carries none
 * @param found optional out flag: set to true when the lookup succeeded and
 *              to false when it failed; may be nullptr
 * @return the resolved address list; holds a null pointer when the lookup
 *         failed, so callers must be prepared for an empty list
 */
AddressList GetHostAddresses(const std::string& host, const std::string& port, bool* found)
{
    auto node = std::string(host);
    auto service = std::string(port);
    ParseHostAndPort(node, service);

    InitSocketLibrary();
    struct addrinfo* results = nullptr;
    const bool resolved = (0 == getaddrinfo(node.c_str(), service.c_str(), nullptr, &results));
    if (!resolved) {
        //throw std::runtime_error("Invalid or unknown host: " + node);
        std::cout<<"Invalid or unknown host: "<<node<<'\n';
        // never hand an indeterminate pointer to the deleter below
        results = nullptr;
    }
    if (found) {
        *found = resolved;
    }
    return AddressList { results, [](struct addrinfo* p) {
                             if (p) {
                                 freeaddrinfo(p);
                             }
                         }};
}
/**
 * Resolves addr with the file-level IPv4/TCP hints and returns the first
 * result as a 32-bit address in host byte order.
 *
 * @param addr IPv4 address literal or host name
 * @return the resolved IPv4 address, converted with ntohl()
 * @throws std::runtime_error when getaddrinfo() reports a failure
 */
uint32_t getIpv4(const std::string& addr)
{
    struct addrinfo* res = nullptr;

    InitSocketLibrary();
    const auto status = getaddrinfo(addr.c_str(), nullptr, &addrinfo, &res);
    if (status) {
        // balance the InitSocketLibrary() call above, exactly like the
        // success path does, before leaving through the exception
        WSACleanup();
        throw std::runtime_error("Invalid IPv4 address or unknown hostname: " + addr);
    }

    // the hints restrict results to AF_INET, hence the sockaddr_in view
    const auto value = reinterpret_cast<const struct sockaddr_in*>(res->ai_addr)->sin_addr.s_addr;
    freeaddrinfo(res);
    WSACleanup();
    return ntohl(value);
}
// Builds the address from a textual form; the conversion (and any exception
// it raises) comes from getIpv4().
IpV4::IpV4(const std::string& addr)
    : value(getIpv4(addr))
{
}
// Wraps an already numeric address value without any conversion.
// The parameter was formerly spelled "__val"; identifiers containing a double
// underscore are reserved for the implementation, so it is named "val" here.
IpV4::IpV4(uint32_t val)
    : value(val)
{}
// Orders addresses by their numeric value.
bool IpV4::operator<(const IpV4& ref) const
{
    return ref.value > value;
}
// Two addresses are equal when their numeric values match.
bool IpV4::operator==(const IpV4& ref) const
{
    return ref.value == value;
}
/**
 * Creates a socket of the given type for the first usable entry of host.
 *
 * For SOCK_STREAM every entry is tried with connect() until one succeeds;
 * for any other type the first entry for which socket() succeeds is used and
 * its length is remembered as destination length for sendto().
 * No exception is thrown: when no entry works (or the list is empty) the
 * object is left invalid (IsValid() == false) and the error code is
 * available through GetError().
 *
 * @param host linked list of candidate addresses, may be nullptr
 * @param type SOCK_STREAM or SOCK_DGRAM
 */
Socket::Socket(const struct addrinfo* const host, const int type) :
    m_WSAInitialized(!InitSocketLibrary()),
    m_DestAddr(SOCK_DGRAM == type ? reinterpret_cast<const struct sockaddr*>(&m_SockAddress) : nullptr),
    m_DestAddrLen(0),
    m_LastError(0),
    m_type(type),
    m_Connected(false)
{
    // Give every member that is read later a defined value, even if the
    // loop below never runs (empty list) or never succeeds.
    m_Socket = INVALID_SOCKET;
    memset(&m_SockAddress, 0, sizeof(m_SockAddress));
    memset(&m_HostAddr, 0, sizeof(m_HostAddr));

    for (auto rp = host; rp; rp = rp->ai_next) {
        m_Socket = socket(rp->ai_family, type, 0);
        if (INVALID_SOCKET == m_Socket) {
            continue;
        }
        if (SOCK_STREAM == type) {
            if (::connect(m_Socket, rp->ai_addr, rp->ai_addrlen)) {
                // NOTE(review): the address is formatted as IPv4 regardless of rp->ai_family
                LOG_WARN("Socket connect["<<std::string(inet_ntoa(reinterpret_cast<sockaddr_in*>(rp->ai_addr)->sin_addr)) << "] timeout");
                closesocket(m_Socket);
                m_Socket = INVALID_SOCKET;
                continue;
            } else {
                m_Connected = true;
                m_HostAddr = *(reinterpret_cast<sockaddr_in*>(rp->ai_addr));
            }
        } else { /*if (SOCK_DGRAM == type)*/
#if defined(_WIN32) || defined(__CYGWIN__)
            // MSVC on Windows is the only platform using different types for connect() and ai_addrlen ...
            m_DestAddrLen = static_cast<decltype(m_DestAddrLen)>(rp->ai_addrlen);
#else
            m_DestAddrLen = rp->ai_addrlen;
#endif
        }
        // remember the address in use; it backs m_DestAddr and IsConnectedTo()
        memcpy(&m_SockAddress, rp->ai_addr, std::min<size_t>(sizeof(m_SockAddress), rp->ai_addrlen));
        return;
    }
    m_LastError = WSAGetLastError();
    LOG_ERROR("Unable to create socket with error: " << std::strerror(m_LastError));
    //throw std::system_error(WSAGetLastError(), std::system_category());
}
// Releases the descriptor through Shutdown() and, if the constructor recorded
// a successful InitSocketLibrary() call, undoes it with WSACleanup().
Socket::~Socket()
{
    Shutdown();
    if (!m_WSAInitialized) {
        return;
    }
    WSACleanup();
}
// True while the object holds a descriptor other than INVALID_SOCKET.
bool Socket::IsValid() const
{
    return INVALID_SOCKET != m_Socket;
}
int Socket::GetError() const
{
return m_LastError;
}
// A socket counts as connected only if it is still valid and the
// connected flag is set.
bool Socket::IsConnected() const
{
    if (!IsValid()) {
        return false;
    }
    return m_Connected;
}
// Stops traffic in both directions, closes the descriptor and marks the
// object invalid and disconnected. Does nothing on an already invalid socket,
// so repeated calls are harmless.
void Socket::Shutdown()
{
    if (!IsValid()) {
        return;
    }

    shutdown(m_Socket, SHUT_RDWR);
    closesocket(m_Socket);
    m_Connected = false;
    m_Socket = INVALID_SOCKET;
}
// Checks whether the address stored in m_SockAddress matches any entry of the
// given list. Entries of a different address family are skipped; matching
// ones are compared bytewise over the shorter of the two address lengths.
bool Socket::IsConnectedTo(const struct addrinfo* const targetAddresses) const
{
    const struct addrinfo* candidate = targetAddresses;
    while (candidate) {
        const bool sameFamily = (m_SockAddress.ss_family == candidate->ai_family);
        if (sameFamily) {
            const size_t compareLen = std::min<size_t>(sizeof(m_SockAddress), candidate->ai_addrlen);
            if (0 == memcmp(&m_SockAddress, candidate->ai_addr, compareLen)) {
                return true;
            }
        }
        candidate = candidate->ai_next;
    }
    return false;
}
// Waits (via Select) for incoming data and receives up to maxBytes into
// buffer. Returns the number of bytes received, or 0 when the stream socket
// is not usable, nothing became readable, or recv() reported end-of-stream or
// an error. In the latter case the error code is stored and the socket is
// shut down.
size_t Socket::read(uint8_t* buffer, size_t maxBytes, timeval* timeout)
{
    const bool streamUnusable = (m_type == SOCK_STREAM) && (!m_Connected || !IsValid());
    if (streamUnusable) {
        return 0;
    }
    if (!Select(timeout)) {
        return 0;
    }

    // recv() takes an int length, so clamp oversized requests
    const size_t clampedLen = std::min<size_t>(std::numeric_limits<int>::max(), maxBytes);
    const auto requestLen = static_cast<int>(clampedLen);
    char* const target = reinterpret_cast<char*>(buffer);
    const int bytesRead = recv(m_Socket, target, requestLen, 0);
    if (bytesRead > 0) {
        return bytesRead;
    }

    m_LastError = WSAGetLastError();
    const bool closedByPeer = (0 == bytesRead) || (m_LastError == CONNECTION_CLOSED) || (m_LastError == CONNECTION_ABORTED);
    if (closedByPeer) {
        LOG_INFO("read error, Socket: " << m_Socket << " connection closed by remote with error: " << std::dec << std::strerror(m_LastError));
    } else {
        LOG_ERROR("Socket read frame failed with error: " << std::dec << std::strerror(m_LastError));
    }
    Shutdown();
    return 0;
}
// Fills frame from the socket: on success the frame is limited to the number
// of bytes received, otherwise it is cleared. Returns the same frame.
Frame& Socket::read(Frame& frame, timeval* timeout)
{
    const size_t received = read(frame.rawData(), frame.capacity(), timeout);
    if (0 == received) {
        return frame.clear();
    }
    return frame.limit(received);
}
/**
 * Waits until the socket becomes readable or the timeout elapses.
 *
 * @param timeout maximum wait time handed to NATIVE_SELECT unchanged
 * @return true only if the wait reported exactly this socket as readable;
 *         false on an invalid socket, on timeout and on error. The error
 *         code observed after the wait is stored for GetError(); if the wait
 *         failed because the descriptor is not a socket, the socket is shut
 *         down.
 */
bool Socket::Select(timeval* timeout)
{
    // Check first: an invalid descriptor must never be handed to FD_SET.
    if (!IsValid()) return false;

    /* prepare socket set for select() */
    fd_set readSockets;
    FD_ZERO(&readSockets);
    FD_SET(m_Socket, &readSockets);

    /* wait for receive data */
    const int state = NATIVE_SELECT(m_Socket + 1, &readSockets, nullptr, nullptr, timeout);
    if (0 == state) {
        LOG_ERROR("Socket select() timeout.");
        //throw TimeoutEx("select() timeout");
    }
    m_LastError = WSAGetLastError();
    // The error code is only meaningful when the wait itself failed; a stale
    // value must not tear down a healthy socket after a timeout or success.
    if ((state < 0) && (m_LastError == WSAENOTSOCK)) {
        //throw std::runtime_error("connection closed");
        Shutdown();
    }

    /* and check if socket was correct */
    if (1 != state) {
        LOG_ERROR("Socket select something strange happen while waiting for socket in state: " << state << " with error: " << std::strerror(m_LastError));
        return false;
    }
    return FD_ISSET(m_Socket, &readSockets) != 0;
}
size_t Socket::write(const Frame& frame)
{
if (frame.size() > std::numeric_limits<int>::max()) {
LOG_ERROR("Socket write frame length: " << frame.size() << " exceeds maximum length.");
return 0;
}
if (!IsValid()) return 0;
const int bufferLength = static_cast<int>(frame.size());
const char* const buffer = reinterpret_cast<const char*>(frame.data());
const int len = sendto(m_Socket, buffer, bufferLength, 0, m_DestAddr, m_DestAddrLen);
if (SOCKET_ERROR == len) {
LOG_ERROR("Socket write frame failed with error: " << std::strerror(WSAGetLastError()));
return 0;
}
if(0 == len){
LOG_INFO("write error, Socket: " << m_Socket << " connection closed by remote with error: " << std::dec << std::strerror(m_LastError));
Shutdown();
return 0;
}
//LOG_INFO("Socket frame is sent to remote host data length "<<len);
return len;
}
// Opens a stream socket through the Socket base constructor and, if that
// produced a connection, applies the TCP_NODELAY setting below.
TcpSocket::TcpSocket(const struct addrinfo* const host)
    : Socket(host, SOCK_STREAM)
{
    if (!m_Connected) {
        return;
    }

    // AdsDll.lib seems to use TCP_NODELAY, we use it to be compatible
    // NOTE(review): the value passed is 0, which by the usual setsockopt()
    // convention switches the option off rather than on -- confirm intent.
    const int optionValue = 0;
    const auto* const rawValue = reinterpret_cast<const char*>(&optionValue);
    const int rc = setsockopt(m_Socket, IPPROTO_TCP, TCP_NODELAY, rawValue, sizeof(optionValue));
    if (rc) {
        LOG_WARN("Socket enabling TCP_NODELAY failed.");
    }
}
/**
 * Queries the local address the socket is bound to.
 *
 * @return the IPv4 address in host byte order, 0xffffffff for an IPv6
 *         socket, and 0 for any other family or when the address cannot be
 *         read (e.g. on an invalid socket)
 */
uint32_t TcpSocket::GetLocalSockAddr() const
{
    if (m_Socket == INVALID_SOCKET) {
        LOG_WARN("SocketTcp is invalid.");
    }

    struct sockaddr_storage source;
    memset(&source, 0, sizeof(source));
    socklen_t len = sizeof(source);
    if (getsockname(m_Socket, reinterpret_cast<sockaddr*>(&source), &len)) {
        LOG_ERROR("TcpSocket read local tcp/ip address failed.");
        //throw std::runtime_error("Read local tcp/ip address failed");
        // nothing was written to source, so do not interpret its contents
        return 0;
    }

    switch (source.ss_family) {
    case AF_INET:
        return ntohl(reinterpret_cast<sockaddr_in*>(&source)->sin_addr.s_addr);

    case AF_INET6:
        return 0xffffffff;

    default:
        return 0;
    }
}
// Reports the remote address stored in m_HostAddr: the IPv4 address in host
// byte order, 0xffffffff when the family is AF_INET6, otherwise 0.
uint32_t TcpSocket::GetHostSockAddr() const
{
    const auto family = m_HostAddr.sin_family;
    if (AF_INET == family) {
        return ntohl(m_HostAddr.sin_addr.s_addr);
    }
    if (AF_INET6 == family) {
        return 0xffffffff;
    }
    return 0;
}
// Datagram flavour of Socket: all work is done by the base constructor,
// which is asked for a SOCK_DGRAM socket.
UdpSocket::UdpSocket(const struct addrinfo* const host) : Socket(host, SOCK_DGRAM)
{
}
}