Files
LNXRNT/Sources/socket_bridge.cpp

948 lines
30 KiB
C++
Raw Normal View History

2026-02-20 23:40:15 -08:00
#include "socket_bridge.h"
#include <cstdio>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <kinc/log.h>
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "wsock32.lib")
#ifdef WITH_SSL
#ifdef _WIN32
#define SECURITY_WIN32
#include <wincrypt.h>
#include <schannel.h>
#include <security.h>
#include <sspi.h>
#pragma comment(lib, "crypt32.lib")
#pragma comment(lib, "secur32.lib")
#endif
#endif
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <poll.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#ifdef WITH_SSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#endif
#endif
/// State for one native TCP socket created or accepted through this bridge, plus its
/// optional TLS session (only when built WITH_SSL).
///
/// Owns `socket_fd` and, with WITH_SSL, the TLS objects. Both are released by close(),
/// which the destructor also calls. Copying or moving would close the same descriptor
/// twice, so those operations are deleted; instances are held through std::unique_ptr.
class RunTSocket {
public:
    int id;                 // Handle value exposed to callers of the runt_socket_* API.
    bool isBlocking;        // Last mode requested via runt_socket_set_blocking().
    bool isConnected;       // Set after connect()/accept(); cleared by close() or on disconnect.
    bool isBound;
    bool isListening;
    std::string lastError;  // NOTE(review): not written by any code visible in this file.
#ifdef _WIN32
    SOCKET socket_fd;
#else
    int socket_fd;
#endif
#ifdef WITH_SSL
#ifdef _WIN32
    // SChannel state kept as opaque pointers: ssl_cred_handle_ points to a heap CredHandle,
    // ssl_context_handle_ to a heap CtxtHandle (both allocated by runt_socket_enable_ssl).
    void* ssl_cred_handle_;
    void* ssl_context_handle_;
    bool ssl_context_initialized_;
    std::vector<char> ssl_buffer_;  // Received ciphertext that has not been decrypted yet.
#else
    SSL* ssl;
    SSL_CTX* ssl_ctx;
#endif
    bool useSSL;
#endif
    explicit RunTSocket(int socket_id) : id(socket_id), isBlocking(true), isConnected(false),
        isBound(false), isListening(false) {
#ifdef _WIN32
        socket_fd = INVALID_SOCKET;
#else
        socket_fd = -1;
#endif
#ifdef WITH_SSL
#ifdef _WIN32
        ssl_cred_handle_ = nullptr;
        ssl_context_handle_ = nullptr;
        ssl_context_initialized_ = false;
#else
        ssl = nullptr;
        ssl_ctx = nullptr;
#endif
        useSSL = false;
#endif
    }
    RunTSocket(const RunTSocket&) = delete;
    RunTSocket& operator=(const RunTSocket&) = delete;
    RunTSocket(RunTSocket&&) = delete;
    RunTSocket& operator=(RunTSocket&&) = delete;
    ~RunTSocket() {
        close();
    }
    /// Releases the TLS session (if any), shuts down and closes the descriptor, and
    /// clears the connection flags. Idempotent: calling it again does nothing harmful.
    void close() {
#ifdef WITH_SSL
        // Released unconditionally (not only while the fd is valid) so TLS objects can
        // never outlive the socket.
        releaseSSL();
#endif
        if (isValid()) {
#ifdef _WIN32
            shutdown(socket_fd, SD_SEND);
            closesocket(socket_fd);
            socket_fd = INVALID_SOCKET;
#else
            shutdown(socket_fd, SHUT_WR);
            ::close(socket_fd);
            socket_fd = -1;
#endif
        }
        isConnected = false;
        isBound = false;
        isListening = false;
    }
    /// True while the object holds an open descriptor.
    bool isValid() const {
#ifdef _WIN32
        return socket_fd != INVALID_SOCKET;
#else
        return socket_fd >= 0;
#endif
    }

private:
#ifdef WITH_SSL
    // Frees every TLS resource held by this socket; safe to call repeatedly.
    void releaseSSL() {
#ifdef _WIN32
        // Previously these pointers were only nulled, leaking both the heap handles and
        // the SSPI security context / credentials they refer to.
        if (ssl_context_handle_) {
            CtxtHandle* context = static_cast<CtxtHandle*>(ssl_context_handle_);
            DeleteSecurityContext(context);
            delete context;
            ssl_context_handle_ = nullptr;
        }
        if (ssl_cred_handle_) {
            CredHandle* credentials = static_cast<CredHandle*>(ssl_cred_handle_);
            FreeCredentialsHandle(credentials);
            delete credentials;
            ssl_cred_handle_ = nullptr;
        }
        ssl_context_initialized_ = false;
        ssl_buffer_.clear();
#else
        if (ssl) {
            // close_notify can only be written while the descriptor is still open.
            if (isValid()) {
                SSL_shutdown(ssl);
            }
            SSL_free(ssl);
            ssl = nullptr;
        }
        if (ssl_ctx) {
            SSL_CTX_free(ssl_ctx);
            ssl_ctx = nullptr;
        }
#endif
        useSSL = false;
    }
#endif
};
// Registry of every live socket, keyed by the integer handle returned to callers.
static std::map<int, std::unique_ptr<RunTSocket>> g_sockets{};
// Next handle to hand out; the code visible in this file only ever increments it.
static int g_next_socket_id{1};
// Set once initialize_networking() succeeds (WSAStartup on Windows, plus SSL library
// setup when WITH_SSL is defined); cleared again by cleanup_networking().
static bool g_winsock_initialized{false};
/// One-time process-wide network setup: WSAStartup on Windows and, when WITH_SSL is
/// defined, SSL library initialisation. Repeated calls after a successful setup return
/// true immediately.
/// @return false only when WSAStartup fails.
static bool initialize_networking() {
    if (!g_winsock_initialized) {
#ifdef _WIN32
        WSADATA wsa_data;
        const int startup_rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
        if (startup_rc != 0) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "WSAStartup failed: %d", startup_rc);
            return false;
        }
#endif
#ifdef WITH_SSL
#ifdef _WIN32
        // SChannel has no global init step; credentials are acquired per socket.
        kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support initialized");
#else
        SSL_library_init();
        SSL_load_error_strings();
        OpenSSL_add_all_algorithms();
        kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support initialized");
#endif
#endif
        g_winsock_initialized = true;
        kinc_log(KINC_LOG_LEVEL_INFO, "Networking initialized successfully");
    }
    return true;
}
/// Undoes initialize_networking(). Every registered socket is destroyed (and thereby
/// closed) first, so no socket outlives WSACleanup(). No-op when not initialised.
static void cleanup_networking() {
    if (g_winsock_initialized) {
        g_sockets.clear();
#ifdef WITH_SSL
#ifdef _WIN32
        kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support cleaned up");
#else
        EVP_cleanup();
        kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support cleaned up");
#endif
#endif
#ifdef _WIN32
        WSACleanup();
#endif
        g_winsock_initialized = false;
        kinc_log(KINC_LOG_LEVEL_INFO, "Networking cleaned up");
    }
}
// C++ Bridge Function Implementations
extern "C" int runt_socket_create() {
if (!initialize_networking()) {
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to initialize networking");
return -1;
}
int socket_id = g_next_socket_id++;
auto socket = std::make_unique<RunTSocket>(socket_id);
#ifdef _WIN32
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (socket->socket_fd == INVALID_SOCKET) {
int error = WSAGetLastError();
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %d", error);
return -1;
}
#else
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
if (socket->socket_fd < 0) {
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %s", strerror(errno));
return -1;
}
#endif
int nodelay = 1;
setsockopt(socket->socket_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
int reuse = 1;
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse));
#if defined(__linux__) && defined(SO_REUSEPORT)
int reuseport = 1;
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEPORT, (const char*)&reuseport, sizeof(reuseport));
#endif
int keepalive = 1;
setsockopt(socket->socket_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
// uWebsockets uses 1mb socket buffer sizes for high throughput
int sndbuf = 1024 * 1024;
int rcvbuf = 1024 * 1024;
setsockopt(socket->socket_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
setsockopt(socket->socket_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
#ifdef __APPLE__
// prevent SIGPIPE crashes on macOS
int no_sigpipe = 1;
setsockopt(socket->socket_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
#endif
g_sockets[socket_id] = std::move(socket);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "Socket created successfully: ID %d", socket_id);
#endif
return socket_id;
}
/// Binds a socket to a local IPv4 address and port.
/// @param socket_id handle from runt_socket_create().
/// @param address   dotted-quad IPv4 address; null or "" binds to INADDR_ANY.
/// @param port      local port, must be within 0..65535.
/// @return true on success; every failure is logged.
extern "C" bool runt_socket_bind(int socket_id, const char* address, int port) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
        return false;
    }
    auto& socket = it->second;
    if (!socket->isValid()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for bind: %d", socket_id);
        return false;
    }
    // htons() would silently wrap an out-of-range value onto a different port.
    if (port < 0 || port > 65535) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid port: %d", port);
        return false;
    }
    sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(static_cast<unsigned short>(port));
    const bool use_any_address = (address == nullptr || address[0] == '\0');
    if (use_any_address) {
        addr.sin_addr.s_addr = htonl(INADDR_ANY);
    } else if (inet_pton(AF_INET, address, &addr.sin_addr) <= 0) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid address: %s", address);
        return false;
    }
    if (bind(socket->socket_fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
#ifdef _WIN32
        int error = WSAGetLastError();
        kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %d", error);
#else
        kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %s", strerror(errno));
#endif
        return false;
    }
    socket->isBound = true;
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Socket bound successfully: %s:%d", use_any_address ? "0.0.0.0" : address, port);
#endif
    return true;
}
/// Puts a bound socket into listening mode.
/// @param backlog pending-connection queue length, passed straight to listen().
/// @return false for an unknown, invalid or unbound socket, or when listen() fails
///         (logged together with the OS error, like bind/connect failures).
extern "C" bool runt_socket_listen(int socket_id, int backlog) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) return false;
    auto& socket = it->second;
    if (!socket->isValid() || !socket->isBound) return false;
    if (listen(socket->socket_fd, backlog) < 0) {
#ifdef _WIN32
        const int error = WSAGetLastError();
        kinc_log(KINC_LOG_LEVEL_ERROR, "Listen failed: %d", error);
#else
        kinc_log(KINC_LOG_LEVEL_ERROR, "Listen failed: %s", strerror(errno));
#endif
        return false;
    }
    socket->isListening = true;
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Socket listening: ID %d", socket_id);
#endif
    return true;
}
extern "C" int runt_socket_accept(int socket_id) {
auto it = g_sockets.find(socket_id);
if (it == g_sockets.end()) return -1;
auto& socket = it->second;
if (!socket->isValid() || !socket->isListening) return -1;
sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
#ifdef _WIN32
SOCKET client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
if (client_fd == INVALID_SOCKET) return -1;
#else
int client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
if (client_fd < 0) return -1;
#endif
int client_id = g_next_socket_id++;
auto client_socket = std::make_unique<RunTSocket>(client_id);
client_socket->socket_fd = client_fd;
client_socket->isConnected = true;
int nodelay = 1;
setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
int keepalive = 1;
setsockopt(client_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
int sndbuf = 1024 * 1024;
int rcvbuf = 1024 * 1024;
setsockopt(client_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
setsockopt(client_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
#ifdef __APPLE__
int no_sigpipe = 1;
setsockopt(client_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
#endif
g_sockets[client_id] = std::move(client_socket);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "Socket accepted: ID %d", client_id);
#endif
return client_id;
}
/// Connects the socket to `hostname`:`port`, trying each IPv4 address that
/// getaddrinfo() returns until one accepts.
/// @param hostname host name or dotted-quad address (null is passed to getaddrinfo as-is).
/// @param port     remote port, must be within 0..65535.
/// @return true once connected; every failure is logged with the OS error of the last
///         connect() attempt.
extern "C" bool runt_socket_connect(int socket_id, const char* hostname, int port) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
        return false;
    }
    auto& socket = it->second;
    if (!socket->isValid()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for connect: %d", socket_id);
        return false;
    }
    // Passing a null pointer to "%s" is undefined behaviour; log a placeholder instead.
    const char* display_host = hostname ? hostname : "(null)";
    if (port < 0 || port > 65535) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid port: %d", port);
        return false;
    }
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Connecting to %s:%d", display_host, port);
#endif
    struct addrinfo hints, *result;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    char port_str[16];
    snprintf(port_str, sizeof(port_str), "%d", port);
    int status = getaddrinfo(hostname, port_str, &hints, &result);
    if (status != 0) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "getaddrinfo failed for %s: %s", display_host, gai_strerror(status));
        return false;
    }
    bool connected = false;
    // Captured right after each failing connect(): later calls such as freeaddrinfo()
    // are not guaranteed to preserve errno / WSAGetLastError().
    int last_error = 0;
    // NOTE(review): after a failed connect() POSIX leaves the socket state unspecified;
    // retrying the next address on the same descriptor may fail spuriously. A robust fix
    // would create a fresh socket per attempt.
    for (struct addrinfo* rp = result; rp != nullptr && !connected; rp = rp->ai_next) {
        if (connect(socket->socket_fd, rp->ai_addr, static_cast<int>(rp->ai_addrlen)) == 0) {
            connected = true;
        } else {
#ifdef _WIN32
            last_error = WSAGetLastError();
#else
            last_error = errno;
#endif
        }
    }
    freeaddrinfo(result);
    if (!connected) {
#ifdef _WIN32
        kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %d", display_host, port, last_error);
#else
        kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %s", display_host, port, strerror(last_error));
#endif
        return false;
    }
    socket->isConnected = true;
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Connected successfully to %s:%d", display_host, port);
#endif
    return true;
}
#if defined(WITH_SSL) && defined(_WIN32)
// Hands every byte of one encrypted TLS record to the kernel. A record that is only
// partly written corrupts the TLS stream, so short writes are retried and, on a
// non-blocking socket, WSAEWOULDBLOCK waits for writability instead of failing.
// Returns false on any other socket error.
static bool runt_ssl_send_fully(SOCKET fd, const char* buf, int len) {
    int offset = 0;
    while (offset < len) {
        const int n = send(fd, buf + offset, len - offset, 0);
        if (n > 0) {
            offset += n;
            continue;
        }
        if (n == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
            fd_set write_set;
            FD_ZERO(&write_set);
            FD_SET(fd, &write_set);
            if (select(0, nullptr, &write_set, nullptr, nullptr) > 0) {
                continue;
            }
        }
        return false;
    }
    return true;
}
#endif

/// Sends `length` bytes from `data` on a connected socket, through TLS once
/// runt_socket_enable_ssl() has succeeded on it.
/// @return number of bytes accepted (always `length` on the SChannel path); 0 when a
///         non-blocking socket / TLS session cannot take data right now (retry later);
///         -1 on error, invalid arguments, or an unknown/unconnected socket. Reset or
///         aborted connections on the plain path also clear isConnected.
extern "C" int runt_socket_send(int socket_id, const char* data, int length) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) return -1;
    auto& socket = it->second;
    if (!socket->isValid() || !socket->isConnected) return -1;
    // A negative length would reach send()/memcpy as a huge size; null data is unusable.
    if (length < 0 || (data == nullptr && length > 0)) return -1;
#ifdef WITH_SSL
    if (socket->useSSL) {
#ifdef _WIN32
        // send implementation adapted from httprequest.cpp
        if (!socket->ssl_context_initialized_) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
            return -1;
        }
        PCtxtHandle context = reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_);
        SecPkgContext_StreamSizes stream_sizes;
        SECURITY_STATUS status = QueryContextAttributesA(context, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
        if (status != SEC_E_OK) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to get stream sizes: 0x%x", status);
            return -1;
        }
        // EncryptMessage rejects payloads larger than cbMaximumMessage, so larger sends
        // are split into several records. One scratch buffer (header|payload|trailer)
        // is reused for every record.
        const int max_chunk = static_cast<int>(stream_sizes.cbMaximumMessage);
        std::vector<char> record(stream_sizes.cbHeader + stream_sizes.cbMaximumMessage + stream_sizes.cbTrailer);
        int offset = 0;
        while (offset < length) {
            const int remaining = length - offset;
            const int chunk = remaining < max_chunk ? remaining : max_chunk;
            char* header = record.data();
            char* payload = header + stream_sizes.cbHeader;
            memcpy(payload, data + offset, chunk);
            SecBuffer message_buffers[4];
            message_buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
            message_buffers[0].pvBuffer = header;
            message_buffers[0].cbBuffer = stream_sizes.cbHeader;
            message_buffers[1].BufferType = SECBUFFER_DATA;
            message_buffers[1].pvBuffer = payload;
            message_buffers[1].cbBuffer = static_cast<unsigned long>(chunk);
            message_buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
            message_buffers[2].pvBuffer = payload + chunk;
            message_buffers[2].cbBuffer = stream_sizes.cbTrailer;
            message_buffers[3].BufferType = SECBUFFER_EMPTY;
            message_buffers[3].pvBuffer = nullptr;
            message_buffers[3].cbBuffer = 0;
            SecBufferDesc message_desc;
            message_desc.ulVersion = SECBUFFER_VERSION;
            message_desc.cBuffers = 4;
            message_desc.pBuffers = message_buffers;
            status = EncryptMessage(context, 0, &message_desc, 0);
            if (status != SEC_E_OK) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel encrypt failed: 0x%x", status);
                return -1;
            }
            const int total_encrypted = static_cast<int>(message_buffers[0].cbBuffer + message_buffers[1].cbBuffer + message_buffers[2].cbBuffer);
            if (!runt_ssl_send_fully(socket->socket_fd, header, total_encrypted)) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send encrypted SSL data");
                return -1;
            }
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "SSL sent %d bytes (encrypted: %d bytes)", chunk, total_encrypted);
#endif
            offset += chunk;
        }
        return length;
#else
        if (socket->ssl) {
            // Nothing to write; avoids SSL_write's special handling of num == 0.
            if (length == 0) return 0;
            int sent = SSL_write(socket->ssl, data, length);
            if (sent <= 0) {
                int ssl_error = SSL_get_error(socket->ssl, sent);
                // Non-blocking socket not ready: report 0 like the plain path does for
                // EWOULDBLOCK. NOTE(review): unless SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER is
                // set, OpenSSL expects the retry to pass the same buffer pointer.
                if (ssl_error == SSL_ERROR_WANT_WRITE || ssl_error == SSL_ERROR_WANT_READ) {
                    return 0;
                }
                kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_write failed: %d", ssl_error);
                return -1;
            }
            return sent;
        }
#endif
    }
#endif
#ifdef _WIN32
    int sent = send(socket->socket_fd, data, length, 0);
#else
    int sent = send(socket->socket_fd, data, length, MSG_NOSIGNAL);
#endif
    if (sent < 0) {
#ifdef _WIN32
        int error = WSAGetLastError();
        if (error == WSAEWOULDBLOCK) {
            return 0;
        }
        // connection reset/aborted we mark as disconnected
        if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
            socket->isConnected = false;
        }
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %d", error);
#endif
#else
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
            return 0;
        }
        if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
            socket->isConnected = false;
        }
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %s", strerror(errno));
#endif
#endif
        return -1;
    }
#ifdef DEBUG_NETWORK
    if (sent > 10) {
        kinc_log(KINC_LOG_LEVEL_INFO, "Sent %d bytes", sent);
    }
#endif
    return sent;
}
#if defined(WITH_SSL) && defined(_WIN32)
// Decrypted bytes that did not fit into an earlier caller buffer, keyed by socket id.
// DecryptMessage always yields a whole TLS record (up to ~16 KiB), which can exceed
// max_length; previously the excess was silently discarded.
// NOTE(review): an entry stays here if the caller closes the socket before draining it;
// since socket ids are only ever incremented in this file it should never be read again,
// but its memory is held until process exit.
static std::map<int, std::vector<char>> g_ssl_pending_plaintext;

// Moves up to max_length bytes from the front of `plaintext` into a new[] buffer that is
// published through *out_data. Returns the number of bytes moved.
static int runt_ssl_take_plaintext(std::vector<char>& plaintext, int max_length, char** out_data) {
    const int available = static_cast<int>(plaintext.size());
    const int count = available < max_length ? available : max_length;
    char* buffer = new char[count];
    memcpy(buffer, plaintext.data(), static_cast<size_t>(count));
    plaintext.erase(plaintext.begin(), plaintext.begin() + count);
    *out_data = buffer;
    return count;
}

// SChannel receive path of runt_socket_recv (same return-code contract).
static int runt_ssl_recv_schannel(RunTSocket& sock, int socket_id, int max_length, char** out_data) {
    if (!sock.ssl_context_initialized_) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
        return -1;
    }
    // Serve plaintext left over from an earlier oversized record first.
    auto pending = g_ssl_pending_plaintext.find(socket_id);
    if (pending != g_ssl_pending_plaintext.end()) {
        const int count = runt_ssl_take_plaintext(pending->second, max_length, out_data);
        if (pending->second.empty()) {
            g_ssl_pending_plaintext.erase(pending);
        }
        return count;
    }
    PCtxtHandle context = reinterpret_cast<PCtxtHandle>(sock.ssl_context_handle_);
    std::vector<char>& cipher = sock.ssl_buffer_;
    for (;;) {
        if (!cipher.empty()) {
            SecBuffer message_buffers[4];
            message_buffers[0].BufferType = SECBUFFER_DATA;
            message_buffers[0].pvBuffer = cipher.data();
            message_buffers[0].cbBuffer = static_cast<unsigned long>(cipher.size());
            for (int i = 1; i < 4; ++i) {
                message_buffers[i].BufferType = SECBUFFER_EMPTY;
                message_buffers[i].pvBuffer = nullptr;
                message_buffers[i].cbBuffer = 0;
            }
            SecBufferDesc message_desc;
            message_desc.ulVersion = SECBUFFER_VERSION;
            message_desc.cBuffers = 4;
            message_desc.pBuffers = message_buffers;
            const SECURITY_STATUS status = DecryptMessage(context, &message_desc, 0, nullptr);
            if (status == SEC_E_OK) {
                SecBuffer* data_buffer = nullptr;
                SecBuffer* extra_buffer = nullptr;
                for (SecBuffer& b : message_buffers) {
                    if (b.BufferType == SECBUFFER_DATA) {
                        data_buffer = &b;
                    } else if (b.BufferType == SECBUFFER_EXTRA) {
                        extra_buffer = &b;
                    }
                }
                // DecryptMessage works in place, so the plaintext lives inside `cipher`:
                // copy it out before the buffer is compacted below.
                std::vector<char> plaintext;
                if (data_buffer && data_buffer->cbBuffer > 0) {
                    const char* first = static_cast<const char*>(data_buffer->pvBuffer);
                    plaintext.assign(first, first + data_buffer->cbBuffer);
                }
                // Keep the bytes of following records, even when this record carried no
                // application data (they used to be dropped in that case).
                if (extra_buffer && extra_buffer->cbBuffer > 0) {
                    const size_t extra = extra_buffer->cbBuffer;
                    std::memmove(cipher.data(), extra_buffer->pvBuffer, extra);
                    cipher.resize(extra);
                } else {
                    cipher.clear();
                }
                if (plaintext.empty()) {
                    continue;  // record without data: decrypt/read the next one
                }
                const int count = runt_ssl_take_plaintext(plaintext, max_length, out_data);
                if (!plaintext.empty()) {
                    g_ssl_pending_plaintext[socket_id] = std::move(plaintext);
                }
#ifdef DEBUG_NETWORK
                kinc_log(KINC_LOG_LEVEL_INFO, "SSL received %d bytes (decrypted)", count);
#endif
                return count;
            }
            if (status == SEC_I_CONTEXT_EXPIRED) {
                // The server sent a TLS close_notify alert: report an orderly close.
                cipher.clear();
                sock.isConnected = false;
                return -2;
            }
            if (status != SEC_E_INCOMPLETE_MESSAGE) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel decrypt failed: 0x%x", status);
                return -1;
            }
            // SEC_E_INCOMPLETE_MESSAGE: the record continues in bytes not received yet.
        }
        char recv_buffer[8192];
        const int received = recv(sock.socket_fd, recv_buffer, sizeof(recv_buffer), 0);
        if (received > 0) {
            cipher.insert(cipher.end(), recv_buffer, recv_buffer + received);
            continue;
        }
        if (received == 0) {
            sock.isConnected = false;
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "SSL connection closed by server");
#endif
            return -2;
        }
        const int error = WSAGetLastError();
        if (error == WSAEWOULDBLOCK) {
            return 0;  // partial record stays buffered for the next call
        }
        if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
            sock.isConnected = false;
            return -2;
        }
        kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data for SSL read");
        return -1;
    }
}
#endif

/// Receives up to `max_length` bytes from a connected socket (through TLS when enabled).
/// On success *out_data receives a buffer allocated with new[]; presumably the caller
/// releases it with delete[] — confirm against the bridge's caller.
/// @return >0 bytes received; 0 when nothing is available yet on a non-blocking socket
///         (or TLS layer); -2 when the peer closed or reset the connection (isConnected
///         is cleared); -1 on other errors, invalid arguments, or an unknown/unconnected
///         socket.
extern "C" int runt_socket_recv(int socket_id, int max_length, char** out_data) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) return -1;
    auto& socket = it->second;
    if (!socket->isValid() || !socket->isConnected) return -1;
    // recv() with a zero-sized buffer returns 0, which would be misread below as
    // "peer closed"; a negative size would make new[] throw across extern "C".
    if (out_data == nullptr || max_length <= 0) return -1;
#ifdef WITH_SSL
    if (socket->useSSL) {
#ifdef _WIN32
        return runt_ssl_recv_schannel(*socket, socket_id, max_length, out_data);
#else
        if (socket->ssl) {
            std::unique_ptr<char[]> buffer(new char[max_length]);
            const int received = SSL_read(socket->ssl, buffer.get(), max_length);
            if (received > 0) {
                *out_data = buffer.release();
                return received;
            }
            const int ssl_error = SSL_get_error(socket->ssl, received);
            if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) {
                return 0;
            }
            if (ssl_error == SSL_ERROR_ZERO_RETURN) {
                // Peer sent close_notify: same "connection closed" code as the plain path.
                socket->isConnected = false;
                return -2;
            }
            kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_read failed: %d", ssl_error);
            return -1;
        }
#endif
    }
#endif
    std::unique_ptr<char[]> buffer(new char[max_length]);
    const int received = recv(socket->socket_fd, buffer.get(), max_length, 0);
    if (received > 0) {
        *out_data = buffer.release();
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "Received %d bytes", received);
#endif
        return received;
    }
    if (received == 0) {
        socket->isConnected = false;
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "Connection closed by peer");
#endif
        return -2;  // Distinct value for connection closed
    }
#ifdef _WIN32
    const int error = WSAGetLastError();
    if (error == WSAEWOULDBLOCK) {
        return 0;
    }
    // WSAECONNRESET (10054) = connection reset by peer
    // WSAECONNABORTED (10053) = connection aborted
    // WSAENETRESET (10052) = network dropped connection
    // All are treated as "disconnected".
    if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
        socket->isConnected = false;
        return -2;
    }
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %d", error);
#endif
#else
    if (errno == EAGAIN || errno == EWOULDBLOCK) {
        return 0;
    }
    if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
        socket->isConnected = false;
        return -2;
    }
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %s", strerror(errno));
#endif
#endif
    return -1;
}
/// Closes the socket behind `socket_id` and removes it from the registry.
/// Unknown ids are ignored; the handle must not be used afterwards.
extern "C" void runt_socket_close(int socket_id) {
    const auto entry = g_sockets.find(socket_id);
    if (entry != g_sockets.end()) {
        entry->second->close();
        g_sockets.erase(entry);
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "Socket closed: ID %d", socket_id);
#endif
    }
}
/// True when `socket_id` names an open socket whose connected flag is still set.
extern "C" bool runt_socket_is_connected(int socket_id) {
    const auto entry = g_sockets.find(socket_id);
    return entry != g_sockets.end()
        && entry->second->isValid()
        && entry->second->isConnected;
}
extern "C" void runt_socket_set_blocking(int socket_id, bool blocking) {
auto it = g_sockets.find(socket_id);
if (it == g_sockets.end()) return;
auto& socket = it->second;
if (!socket->isValid()) return;
#ifdef _WIN32
u_long mode = blocking ? 0 : 1;
ioctlsocket(socket->socket_fd, FIONBIO, &mode);
#else
int flags = fcntl(socket->socket_fd, F_GETFL, 0);
if (blocking) {
flags &= ~O_NONBLOCK;
} else {
flags |= O_NONBLOCK;
}
fcntl(socket->socket_fd, F_SETFL, flags);
#endif
socket->isBlocking = blocking;
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "Socket blocking mode set: %s", blocking ? "blocking" : "non-blocking");
#endif
}
/// Waits up to `timeout_sec` seconds for `socket_id` to become readable.
/// Negative or NaN timeouts poll without waiting; timeouts are capped at 2147483 s
/// (~24.8 days) so the conversions below cannot overflow.
/// @return true when a read will not block waiting for the peer: data, EOF or a pending
///         error on the descriptor, or (OpenSSL) decrypted bytes already buffered in the
///         TLS session. False on timeout, error, or an unknown/invalid socket.
extern "C" bool runt_socket_select(int socket_id, double timeout_sec) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) return false;
    auto& socket = it->second;
    if (!socket->isValid()) return false;
#if defined(WITH_SSL) && !defined(_WIN32)
    // Bytes OpenSSL has already decrypted are invisible to the kernel, so the descriptor
    // can look idle although SSL_read would return data immediately.
    if (socket->useSSL && socket->ssl && SSL_pending(socket->ssl) > 0) return true;
#endif
    // Negative values and NaN both fail this comparison.
    if (!(timeout_sec > 0.0)) timeout_sec = 0.0;
    const double max_timeout_sec = 2147483.0;
    if (timeout_sec > max_timeout_sec) timeout_sec = max_timeout_sec;
#ifdef _WIN32
    // A Winsock fd_set is an array of handles, so FD_SET has no descriptor-value limit.
    fd_set readfds;
    FD_ZERO(&readfds);
    FD_SET(socket->socket_fd, &readfds);
    struct timeval tv;
    tv.tv_sec = static_cast<long>(timeout_sec);
    tv.tv_usec = static_cast<long>((timeout_sec - tv.tv_sec) * 1000000);
    // Winsock ignores the first argument.
    const int result = ::select(0, &readfds, nullptr, nullptr, &tv);
    return result > 0 && FD_ISSET(socket->socket_fd, &readfds);
#else
    // poll() instead of select(): FD_SET with a descriptor >= FD_SETSIZE writes out of
    // bounds, which a process holding many sockets can reach.
    pollfd pfd;
    pfd.fd = socket->socket_fd;
    pfd.events = POLLIN;
    pfd.revents = 0;
    const double timeout_ms_exact = timeout_sec * 1000.0;
    int timeout_ms = static_cast<int>(timeout_ms_exact);
    if (timeout_ms < timeout_ms_exact) ++timeout_ms;  // round up so short waits still wait
    const int result = ::poll(&pfd, 1, timeout_ms);
    // POLLHUP/POLLERR match select()'s "readable" report for EOF and pending errors.
    return result > 0 && (pfd.revents & (POLLIN | POLLHUP | POLLERR)) != 0;
#endif
}
#ifdef WITH_SSL
#ifdef _WIN32
// Sends an entire SChannel handshake token, retrying short writes and waiting for
// writability on non-blocking sockets. Returns false on any other socket error.
static bool runt_ssl_handshake_send(SOCKET fd, const char* buf, unsigned long len) {
    unsigned long offset = 0;
    while (offset < len) {
        const int n = send(fd, buf + offset, static_cast<int>(len - offset), 0);
        if (n > 0) {
            offset += static_cast<unsigned long>(n);
            continue;
        }
        if (n == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
            fd_set write_set;
            FD_ZERO(&write_set);
            FD_SET(fd, &write_set);
            if (select(0, nullptr, &write_set, nullptr, nullptr) > 0) {
                continue;
            }
        }
        return false;
    }
    return true;
}

// Appends the next chunk of server handshake data to `buffer`, waiting for readability
// on non-blocking sockets. Returns false if the peer closed the connection or a socket
// error occurred.
static bool runt_ssl_handshake_recv(SOCKET fd, std::vector<char>& buffer) {
    char chunk[8192];
    for (;;) {
        const int n = recv(fd, chunk, sizeof(chunk), 0);
        if (n > 0) {
            buffer.insert(buffer.end(), chunk, chunk + n);
            return true;
        }
        if (n == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
            fd_set read_set;
            FD_ZERO(&read_set);
            FD_SET(fd, &read_set);
            if (select(0, &read_set, nullptr, nullptr, nullptr) > 0) {
                continue;
            }
        }
        return false;
    }
}
#else
// Blocks until `fd` is ready for the direction OpenSSL asked for (SSL_ERROR_WANT_READ /
// SSL_ERROR_WANT_WRITE) so SSL_connect can be retried on a non-blocking socket.
// NOTE(review): waits without a timeout, like SSL_connect on a blocking socket does.
static bool runt_ssl_wait_for_handshake_io(int fd, int ssl_error) {
    pollfd pfd;
    pfd.fd = fd;
    pfd.events = (ssl_error == SSL_ERROR_WANT_WRITE) ? POLLOUT : POLLIN;
    pfd.revents = 0;
    return ::poll(&pfd, 1, -1) > 0;
}
#endif

/// Upgrades a connected socket to a client-side TLS session and runs the handshake to
/// completion (SChannel on Windows, OpenSSL elsewhere).
/// @return true when TLS is active (also when it already was); false on any failure,
///         in which case the socket is left without a TLS session.
extern "C" bool runt_socket_enable_ssl(int socket_id) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) return false;
    auto& socket = it->second;
    if (!socket->isValid() || !socket->isConnected) return false;
#ifdef _WIN32
    if (socket->useSSL) return true;
    if (!socket->ssl_cred_handle_) {
        socket->ssl_cred_handle_ = new CredHandle();
        memset(socket->ssl_cred_handle_, 0, sizeof(CredHandle));
        SCHANNEL_CRED credentials;
        memset(&credentials, 0, sizeof(credentials));
        credentials.dwVersion = SCHANNEL_CRED_VERSION;
        credentials.grbitEnabledProtocols = SP_PROT_TLS1_2;
        // NOTE(review): SCH_CRED_MANUAL_CRED_VALIDATION turns off SChannel's certificate
        // checks and nothing here validates the server certificate manually, so any
        // certificate is accepted. Fixing this requires the real host name (see TODO).
        credentials.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
        TimeStamp expiry;
        SECURITY_STATUS status = AcquireCredentialsHandleA(
            nullptr, const_cast<char*>(UNISP_NAME_A), SECPKG_CRED_OUTBOUND,
            nullptr, &credentials, nullptr, nullptr,
            reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_), &expiry);
        if (status != SEC_E_OK) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "AcquireCredentialsHandleA failed: 0x%x", status);
            delete reinterpret_cast<CredHandle*>(socket->ssl_cred_handle_);
            socket->ssl_cred_handle_ = nullptr;
            return false;
        }
        kinc_log(KINC_LOG_LEVEL_INFO, "SSL credentials acquired successfully");
    }
    // Discard a context left behind by an earlier, failed attempt instead of leaking it.
    if (socket->ssl_context_handle_) {
        CtxtHandle* stale = reinterpret_cast<CtxtHandle*>(socket->ssl_context_handle_);
        DeleteSecurityContext(stale);
        delete stale;
        socket->ssl_context_handle_ = nullptr;
    }
    socket->ssl_context_initialized_ = false;

    PCredHandle credentials_handle = reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_);
    CtxtHandle* context = new CtxtHandle();
    bool context_created = false;
    // Server bytes are accumulated in ssl_buffer_ so that any application data which
    // arrives together with the final handshake message is kept for runt_socket_recv.
    std::vector<char>& handshake_in = socket->ssl_buffer_;
    handshake_in.clear();
    auto abandon_handshake = [&]() {
        if (context_created) {
            DeleteSecurityContext(context);
        }
        delete context;
        handshake_in.clear();
        return false;
    };
    // TODO: Use actual hostname (it is also the SNI name and the certificate subject to check)
    char target_name[] = "localhost";
    const unsigned long request_flags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT |
                                        ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
                                        ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
    SecBuffer out_token;
    out_token.BufferType = SECBUFFER_TOKEN;
    out_token.pvBuffer = nullptr;
    out_token.cbBuffer = 0;
    SecBufferDesc out_desc;
    out_desc.ulVersion = SECBUFFER_VERSION;
    out_desc.cBuffers = 1;
    out_desc.pBuffers = &out_token;
    DWORD context_attributes = 0;
    TimeStamp expiry;
    // First step has no server input and produces the ClientHello.
    SECURITY_STATUS status = InitializeSecurityContextA(
        credentials_handle, nullptr, target_name, request_flags,
        0, SECURITY_NATIVE_DREP, nullptr, 0,
        context, &out_desc, &context_attributes, &expiry);
    // Previously the context was marked usable right after the ClientHello; the handshake
    // is now driven until InitializeSecurityContext reports SEC_E_OK.
    for (;;) {
        const bool step_ok = (status == SEC_E_OK || status == SEC_I_CONTINUE_NEEDED);
        if (step_ok) {
            context_created = true;
        }
        if (out_token.pvBuffer != nullptr) {
            bool sent = true;
            if (step_ok && out_token.cbBuffer > 0) {
                sent = runt_ssl_handshake_send(socket->socket_fd,
                                               static_cast<const char*>(out_token.pvBuffer),
                                               out_token.cbBuffer);
            }
            FreeContextBuffer(out_token.pvBuffer);
            out_token.pvBuffer = nullptr;
            out_token.cbBuffer = 0;
            if (!sent) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send SSL handshake data");
                return abandon_handshake();
            }
        }
        if (status == SEC_E_OK) {
            break;
        }
        if (!step_ok && status != SEC_E_INCOMPLETE_MESSAGE) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "InitializeSecurityContextA failed: 0x%x", status);
            return abandon_handshake();
        }
        // Read more when nothing is buffered or the buffered message is incomplete.
        if (handshake_in.empty() || status == SEC_E_INCOMPLETE_MESSAGE) {
            if (!runt_ssl_handshake_recv(socket->socket_fd, handshake_in)) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive SSL handshake data");
                return abandon_handshake();
            }
        }
        SecBuffer in_buffers[2];
        in_buffers[0].BufferType = SECBUFFER_TOKEN;
        in_buffers[0].pvBuffer = handshake_in.data();
        in_buffers[0].cbBuffer = static_cast<unsigned long>(handshake_in.size());
        in_buffers[1].BufferType = SECBUFFER_EMPTY;
        in_buffers[1].pvBuffer = nullptr;
        in_buffers[1].cbBuffer = 0;
        SecBufferDesc in_desc;
        in_desc.ulVersion = SECBUFFER_VERSION;
        in_desc.cBuffers = 2;
        in_desc.pBuffers = in_buffers;
        out_token.BufferType = SECBUFFER_TOKEN;
        status = InitializeSecurityContextA(
            credentials_handle, context, target_name, request_flags,
            0, SECURITY_NATIVE_DREP, &in_desc, 0,
            context, &out_desc, &context_attributes, &expiry);
        if (status != SEC_E_INCOMPLETE_MESSAGE) {
            // Keep only the bytes SChannel did not consume (SECBUFFER_EXTRA): they start
            // the next handshake message or, after SEC_E_OK, encrypted application data.
            if (in_buffers[1].BufferType == SECBUFFER_EXTRA && in_buffers[1].cbBuffer > 0) {
                const size_t extra = in_buffers[1].cbBuffer;
                std::memmove(handshake_in.data(), handshake_in.data() + handshake_in.size() - extra, extra);
                handshake_in.resize(extra);
            } else {
                handshake_in.clear();
            }
        }
    }
    socket->ssl_context_handle_ = context;
    socket->ssl_context_initialized_ = true;
    kinc_log(KINC_LOG_LEVEL_INFO, "SSL handshake completed successfully");
#else
    if (socket->ssl) return true;
    socket->ssl_ctx = SSL_CTX_new(TLS_client_method());
    if (!socket->ssl_ctx) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_CTX_new failed");
        return false;
    }
    // NOTE(review): neither SSL_CTX_set_verify() nor SNI (SSL_set_tlsext_host_name) is
    // configured, so the server certificate is not verified; the host name would need to
    // be recorded at connect time to fix this.
    socket->ssl = SSL_new(socket->ssl_ctx);
    if (!socket->ssl) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_new failed");
        SSL_CTX_free(socket->ssl_ctx);
        socket->ssl_ctx = nullptr;
        return false;
    }
    auto abandon_handshake = [&socket]() {
        SSL_free(socket->ssl);
        SSL_CTX_free(socket->ssl_ctx);
        socket->ssl = nullptr;
        socket->ssl_ctx = nullptr;
        return false;
    };
    // Lets a send retried after SSL_ERROR_WANT_WRITE pass a different buffer address.
    SSL_set_mode(socket->ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
    if (SSL_set_fd(socket->ssl, socket->socket_fd) != 1) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_set_fd failed");
        return abandon_handshake();
    }
    for (;;) {
        const int ret = SSL_connect(socket->ssl);
        if (ret == 1) {
            break;
        }
        const int ssl_error = SSL_get_error(socket->ssl, ret);
        // Non-blocking socket: wait for the requested direction, then resume the handshake.
        if ((ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) &&
            runt_ssl_wait_for_handshake_io(socket->socket_fd, ssl_error)) {
            continue;
        }
        kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_connect failed: %d", ssl_error);
        return abandon_handshake();
    }
#endif
    socket->useSSL = true;
    kinc_log(KINC_LOG_LEVEL_INFO, "SSL enabled successfully for socket %d", socket_id);
    return true;
}
#endif
/// Exported teardown hook: closes every registered socket and releases process-wide
/// networking state. Calling it again without re-initialising is a no-op.
extern "C" void runt_socket_cleanup() { cleanup_networking(); }