Here comes he RunT!
This commit is contained in:
947
Sources/socket_bridge.cpp
Normal file
947
Sources/socket_bridge.cpp
Normal file
@ -0,0 +1,947 @@
|
||||
#include "socket_bridge.h"
|
||||
#include <kinc/log.h>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#pragma comment(lib, "ws2_32.lib")
|
||||
#pragma comment(lib, "wsock32.lib")
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
#define SECURITY_WIN32
|
||||
#include <wincrypt.h>
|
||||
#include <schannel.h>
|
||||
#include <security.h>
|
||||
#include <sspi.h>
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#pragma comment(lib, "secur32.lib")
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
#ifdef WITH_SSL
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
class RunTSocket {
|
||||
public:
|
||||
int id;
|
||||
bool isBlocking;
|
||||
bool isConnected;
|
||||
bool isBound;
|
||||
bool isListening;
|
||||
std::string lastError;
|
||||
|
||||
#ifdef _WIN32
|
||||
SOCKET socket_fd;
|
||||
#else
|
||||
int socket_fd;
|
||||
#endif
|
||||
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
// windows SChannel SSL support with opaque pointers
|
||||
void* ssl_cred_handle_;
|
||||
void* ssl_context_handle_;
|
||||
bool ssl_context_initialized_;
|
||||
std::vector<char> ssl_buffer_;
|
||||
#else
|
||||
SSL* ssl;
|
||||
SSL_CTX* ssl_ctx;
|
||||
#endif
|
||||
bool useSSL;
|
||||
#endif
|
||||
|
||||
RunTSocket(int socket_id) : id(socket_id), isBlocking(true), isConnected(false),
|
||||
isBound(false), isListening(false) {
|
||||
#ifdef _WIN32
|
||||
socket_fd = INVALID_SOCKET;
|
||||
#else
|
||||
socket_fd = -1;
|
||||
#endif
|
||||
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
ssl_cred_handle_ = nullptr;
|
||||
ssl_context_handle_ = nullptr;
|
||||
ssl_context_initialized_ = false;
|
||||
#else
|
||||
ssl = nullptr;
|
||||
ssl_ctx = nullptr;
|
||||
#endif
|
||||
useSSL = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
~RunTSocket() {
|
||||
close();
|
||||
}
|
||||
|
||||
void close() {
|
||||
if (isValid()) {
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
if (ssl_context_handle_) {
|
||||
ssl_context_handle_ = nullptr;
|
||||
}
|
||||
if (ssl_cred_handle_) {
|
||||
ssl_cred_handle_ = nullptr;
|
||||
}
|
||||
ssl_context_initialized_ = false;
|
||||
#else
|
||||
if (ssl) {
|
||||
SSL_shutdown(ssl);
|
||||
SSL_free(ssl);
|
||||
ssl = nullptr;
|
||||
}
|
||||
if (ssl_ctx) {
|
||||
SSL_CTX_free(ssl_ctx);
|
||||
ssl_ctx = nullptr;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#ifdef _WIN32
|
||||
shutdown(socket_fd, SD_SEND);
|
||||
closesocket(socket_fd);
|
||||
socket_fd = INVALID_SOCKET;
|
||||
#else
|
||||
shutdown(socket_fd, SHUT_WR);
|
||||
::close(socket_fd);
|
||||
socket_fd = -1;
|
||||
#endif
|
||||
}
|
||||
isConnected = false;
|
||||
isBound = false;
|
||||
isListening = false;
|
||||
}
|
||||
|
||||
bool isValid() const {
|
||||
#ifdef _WIN32
|
||||
return socket_fd != INVALID_SOCKET;
|
||||
#else
|
||||
return socket_fd >= 0;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// global socket
|
||||
static std::map<int, std::unique_ptr<RunTSocket>> g_sockets;
|
||||
static int g_next_socket_id = 1;
|
||||
static bool g_winsock_initialized = false;
|
||||
|
||||
static bool initialize_networking() {
|
||||
if (g_winsock_initialized) return true;
|
||||
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
||||
if (result != 0) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "WSAStartup failed: %d", result);
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
// no global initialization needed for windows
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support initialized");
|
||||
#else
|
||||
SSL_library_init();
|
||||
SSL_load_error_strings();
|
||||
OpenSSL_add_all_algorithms();
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support initialized");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
g_winsock_initialized = true;
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Networking initialized successfully");
|
||||
return true;
|
||||
}
|
||||
|
||||
static void cleanup_networking() {
|
||||
if (!g_winsock_initialized) return;
|
||||
|
||||
g_sockets.clear();
|
||||
|
||||
#ifdef WITH_SSL
|
||||
#ifdef _WIN32
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support cleaned up");
|
||||
#else
|
||||
EVP_cleanup();
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support cleaned up");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
g_winsock_initialized = false;
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Networking cleaned up");
|
||||
}
|
||||
|
||||
// C++ Bridge Function Implementations
|
||||
|
||||
extern "C" int runt_socket_create() {
|
||||
if (!initialize_networking()) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to initialize networking");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int socket_id = g_next_socket_id++;
|
||||
auto socket = std::make_unique<RunTSocket>(socket_id);
|
||||
|
||||
#ifdef _WIN32
|
||||
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
||||
if (socket->socket_fd == INVALID_SOCKET) {
|
||||
int error = WSAGetLastError();
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %d", error);
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket->socket_fd < 0) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %s", strerror(errno));
|
||||
return -1;
|
||||
}
|
||||
#endif
|
||||
|
||||
int nodelay = 1;
|
||||
setsockopt(socket->socket_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
|
||||
|
||||
int reuse = 1;
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse));
|
||||
|
||||
#if defined(__linux__) && defined(SO_REUSEPORT)
|
||||
int reuseport = 1;
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEPORT, (const char*)&reuseport, sizeof(reuseport));
|
||||
#endif
|
||||
|
||||
int keepalive = 1;
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
|
||||
|
||||
// uWebsockets uses 1mb socket buffer sizes for high throughput
|
||||
int sndbuf = 1024 * 1024;
|
||||
int rcvbuf = 1024 * 1024;
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
|
||||
|
||||
#ifdef __APPLE__
|
||||
// prevent SIGPIPE crashes on macOS
|
||||
int no_sigpipe = 1;
|
||||
setsockopt(socket->socket_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
|
||||
#endif
|
||||
|
||||
g_sockets[socket_id] = std::move(socket);
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket created successfully: ID %d", socket_id);
|
||||
#endif
|
||||
return socket_id;
|
||||
}
|
||||
|
||||
extern "C" bool runt_socket_bind(int socket_id, const char* address, int port) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid()) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for bind: %d", socket_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
sockaddr_in addr;
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(port);
|
||||
|
||||
if (address && strlen(address) > 0) {
|
||||
if (inet_pton(AF_INET, address, &addr.sin_addr) <= 0) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid address: %s", address);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
addr.sin_addr.s_addr = INADDR_ANY;
|
||||
}
|
||||
|
||||
if (bind(socket->socket_fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
|
||||
#ifdef _WIN32
|
||||
int error = WSAGetLastError();
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %d", error);
|
||||
#else
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %s", strerror(errno));
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
socket->isBound = true;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket bound successfully: %s:%d", address ? address : "0.0.0.0", port);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" bool runt_socket_listen(int socket_id, int backlog) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return false;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid() || !socket->isBound) return false;
|
||||
|
||||
if (listen(socket->socket_fd, backlog) < 0) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Listen failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
socket->isListening = true;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket listening: ID %d", socket_id);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" int runt_socket_accept(int socket_id) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return -1;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid() || !socket->isListening) return -1;
|
||||
|
||||
sockaddr_in client_addr;
|
||||
socklen_t client_len = sizeof(client_addr);
|
||||
|
||||
#ifdef _WIN32
|
||||
SOCKET client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
|
||||
if (client_fd == INVALID_SOCKET) return -1;
|
||||
#else
|
||||
int client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
|
||||
if (client_fd < 0) return -1;
|
||||
#endif
|
||||
|
||||
int client_id = g_next_socket_id++;
|
||||
auto client_socket = std::make_unique<RunTSocket>(client_id);
|
||||
client_socket->socket_fd = client_fd;
|
||||
client_socket->isConnected = true;
|
||||
|
||||
int nodelay = 1;
|
||||
setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
|
||||
|
||||
int keepalive = 1;
|
||||
setsockopt(client_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
|
||||
|
||||
int sndbuf = 1024 * 1024;
|
||||
int rcvbuf = 1024 * 1024;
|
||||
setsockopt(client_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
|
||||
setsockopt(client_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
|
||||
|
||||
#ifdef __APPLE__
|
||||
int no_sigpipe = 1;
|
||||
setsockopt(client_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
|
||||
#endif
|
||||
|
||||
g_sockets[client_id] = std::move(client_socket);
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket accepted: ID %d", client_id);
|
||||
#endif
|
||||
return client_id;
|
||||
}
|
||||
|
||||
extern "C" bool runt_socket_connect(int socket_id, const char* hostname, int port) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid()) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for connect: %d", socket_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Connecting to %s:%d", hostname, port);
|
||||
#endif
|
||||
|
||||
struct addrinfo hints, *result;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
char port_str[16];
|
||||
snprintf(port_str, sizeof(port_str), "%d", port);
|
||||
|
||||
int status = getaddrinfo(hostname, port_str, &hints, &result);
|
||||
if (status != 0) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "getaddrinfo failed for %s: %s", hostname, gai_strerror(status));
|
||||
return false;
|
||||
}
|
||||
|
||||
bool connected = false;
|
||||
for (struct addrinfo* rp = result; rp != nullptr; rp = rp->ai_next) {
|
||||
if (connect(socket->socket_fd, rp->ai_addr, static_cast<int>(rp->ai_addrlen)) == 0) {
|
||||
connected = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
freeaddrinfo(result);
|
||||
|
||||
if (!connected) {
|
||||
#ifdef _WIN32
|
||||
int error = WSAGetLastError();
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %d", hostname, port, error);
|
||||
#else
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %s", hostname, port, strerror(errno));
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
socket->isConnected = true;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Connected successfully to %s:%d", hostname, port);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" int runt_socket_send(int socket_id, const char* data, int length) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return -1;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid() || !socket->isConnected) return -1;
|
||||
|
||||
#ifdef WITH_SSL
|
||||
if (socket->useSSL) {
|
||||
#ifdef _WIN32
|
||||
// send implementation adapted from httprequest.cpp
|
||||
if (!socket->ssl_context_initialized_) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
|
||||
return -1;
|
||||
}
|
||||
|
||||
SecPkgContext_StreamSizes stream_sizes;
|
||||
SECURITY_STATUS status = QueryContextAttributesA(
|
||||
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
||||
SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
|
||||
|
||||
if (status != SEC_E_OK) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to get stream sizes: 0x%x", status);
|
||||
return -1;
|
||||
}
|
||||
|
||||
int total_size = stream_sizes.cbHeader + length + stream_sizes.cbTrailer;
|
||||
char* encrypt_buffer = new char[total_size];
|
||||
|
||||
SecBufferDesc message_desc;
|
||||
SecBuffer message_buffers[4];
|
||||
|
||||
message_desc.ulVersion = SECBUFFER_VERSION;
|
||||
message_desc.cBuffers = 4;
|
||||
message_desc.pBuffers = message_buffers;
|
||||
|
||||
message_buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
|
||||
message_buffers[0].pvBuffer = encrypt_buffer;
|
||||
message_buffers[0].cbBuffer = stream_sizes.cbHeader;
|
||||
|
||||
message_buffers[1].BufferType = SECBUFFER_DATA;
|
||||
message_buffers[1].pvBuffer = encrypt_buffer + stream_sizes.cbHeader;
|
||||
message_buffers[1].cbBuffer = length;
|
||||
memcpy(message_buffers[1].pvBuffer, data, length);
|
||||
|
||||
message_buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
|
||||
message_buffers[2].pvBuffer = encrypt_buffer + stream_sizes.cbHeader + length;
|
||||
message_buffers[2].cbBuffer = stream_sizes.cbTrailer;
|
||||
|
||||
message_buffers[3].BufferType = SECBUFFER_EMPTY;
|
||||
message_buffers[3].pvBuffer = nullptr;
|
||||
message_buffers[3].cbBuffer = 0;
|
||||
|
||||
status = EncryptMessage(
|
||||
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
||||
0, &message_desc, 0);
|
||||
|
||||
if (status != SEC_E_OK) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel encrypt failed: 0x%x", status);
|
||||
delete[] encrypt_buffer;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int total_encrypted = message_buffers[0].cbBuffer + message_buffers[1].cbBuffer + message_buffers[2].cbBuffer;
|
||||
int sent = send(socket->socket_fd, encrypt_buffer, total_encrypted, 0);
|
||||
|
||||
delete[] encrypt_buffer;
|
||||
|
||||
if (sent != total_encrypted) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send encrypted SSL data");
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL sent %d bytes (encrypted: %d bytes)", length, total_encrypted);
|
||||
#endif
|
||||
return length;
|
||||
#else
|
||||
if (socket->ssl) {
|
||||
int sent = SSL_write(socket->ssl, data, length);
|
||||
if (sent <= 0) {
|
||||
int ssl_error = SSL_get_error(socket->ssl, sent);
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_write failed: %d", ssl_error);
|
||||
return -1;
|
||||
}
|
||||
return sent;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
int sent = send(socket->socket_fd, data, length, 0);
|
||||
#else
|
||||
int sent = send(socket->socket_fd, data, length, MSG_NOSIGNAL);
|
||||
#endif
|
||||
if (sent < 0) {
|
||||
#ifdef _WIN32
|
||||
int error = WSAGetLastError();
|
||||
if (error == WSAEWOULDBLOCK) {
|
||||
return 0;
|
||||
}
|
||||
// connection reset/aborted we mark as disconnected
|
||||
if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
|
||||
socket->isConnected = false;
|
||||
}
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %d", error);
|
||||
#endif
|
||||
#else
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
return 0;
|
||||
}
|
||||
if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
|
||||
socket->isConnected = false;
|
||||
}
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %s", strerror(errno));
|
||||
#endif
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_NETWORK
|
||||
if (sent > 10) {
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Sent %d bytes", sent);
|
||||
}
|
||||
#endif
|
||||
return sent;
|
||||
}
|
||||
|
||||
extern "C" int runt_socket_recv(int socket_id, int max_length, char** out_data) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return -1;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid() || !socket->isConnected) return -1;
|
||||
|
||||
char* buffer = new char[max_length];
|
||||
|
||||
#ifdef WITH_SSL
|
||||
if (socket->useSSL) {
|
||||
#ifdef _WIN32
|
||||
if (!socket->ssl_context_initialized_) {
|
||||
delete[] buffer;
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
|
||||
return -1;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// decrypt for buffered data
|
||||
if (!socket->ssl_buffer_.empty()) {
|
||||
SecBufferDesc message_desc;
|
||||
SecBuffer message_buffers[4];
|
||||
|
||||
// buffer descriptor for decryption
|
||||
message_desc.ulVersion = SECBUFFER_VERSION;
|
||||
message_desc.cBuffers = 4;
|
||||
message_desc.pBuffers = message_buffers;
|
||||
|
||||
// input buffer with encrypted data
|
||||
message_buffers[0].BufferType = SECBUFFER_DATA;
|
||||
message_buffers[0].pvBuffer = socket->ssl_buffer_.data();
|
||||
message_buffers[0].cbBuffer = static_cast<unsigned long>(socket->ssl_buffer_.size());
|
||||
|
||||
// output buffer
|
||||
message_buffers[1].BufferType = SECBUFFER_EMPTY;
|
||||
message_buffers[1].pvBuffer = nullptr;
|
||||
message_buffers[1].cbBuffer = 0;
|
||||
|
||||
// extra buffer for leftover data
|
||||
message_buffers[2].BufferType = SECBUFFER_EMPTY;
|
||||
message_buffers[2].pvBuffer = nullptr;
|
||||
message_buffers[2].cbBuffer = 0;
|
||||
|
||||
// stream buffer
|
||||
message_buffers[3].BufferType = SECBUFFER_EMPTY;
|
||||
message_buffers[3].pvBuffer = nullptr;
|
||||
message_buffers[3].cbBuffer = 0;
|
||||
|
||||
SECURITY_STATUS status = DecryptMessage(
|
||||
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
||||
&message_desc, 0, nullptr);
|
||||
|
||||
if (status == SEC_E_OK) {
|
||||
SecBuffer* data_buffer = nullptr;
|
||||
SecBuffer* extra_buffer = nullptr;
|
||||
|
||||
// find data and extra buffers
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (message_buffers[i].BufferType == SECBUFFER_DATA) {
|
||||
data_buffer = &message_buffers[i];
|
||||
} else if (message_buffers[i].BufferType == SECBUFFER_EXTRA) {
|
||||
extra_buffer = &message_buffers[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (data_buffer && data_buffer->cbBuffer > 0) {
|
||||
int bytes_to_copy = (max_length < static_cast<int>(data_buffer->cbBuffer)) ? max_length : static_cast<int>(data_buffer->cbBuffer);
|
||||
memcpy(buffer, data_buffer->pvBuffer, bytes_to_copy);
|
||||
|
||||
// move extra data to the beginning of the buffer
|
||||
if (extra_buffer && extra_buffer->cbBuffer > 0) {
|
||||
std::memmove(socket->ssl_buffer_.data(), extra_buffer->pvBuffer, extra_buffer->cbBuffer);
|
||||
socket->ssl_buffer_.resize(extra_buffer->cbBuffer);
|
||||
} else {
|
||||
socket->ssl_buffer_.clear();
|
||||
}
|
||||
|
||||
*out_data = buffer;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL received %d bytes (decrypted)", bytes_to_copy);
|
||||
#endif
|
||||
return bytes_to_copy;
|
||||
}
|
||||
|
||||
socket->ssl_buffer_.clear();
|
||||
|
||||
} else if (status == SEC_E_INCOMPLETE_MESSAGE) {
|
||||
// needs more from the socket
|
||||
char recv_buffer[8192];
|
||||
int received = recv(socket->socket_fd, recv_buffer, sizeof(recv_buffer), 0);
|
||||
if (received <= 0) {
|
||||
delete[] buffer;
|
||||
if (received == 0) {
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL connection closed by server");
|
||||
#endif
|
||||
return 0;
|
||||
} else {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data for SSL read");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
socket->ssl_buffer_.insert(socket->ssl_buffer_.end(), recv_buffer, recv_buffer + received);
|
||||
continue;
|
||||
|
||||
} else {
|
||||
delete[] buffer;
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel decrypt failed: 0x%x", status);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
char recv_buffer[8192];
|
||||
int received = recv(socket->socket_fd, recv_buffer, sizeof(recv_buffer), 0);
|
||||
if (received <= 0) {
|
||||
delete[] buffer;
|
||||
if (received == 0) {
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL connection closed by server (no buffered data)");
|
||||
#endif
|
||||
return 0;
|
||||
} else {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
socket->ssl_buffer_.insert(socket->ssl_buffer_.end(), recv_buffer, recv_buffer + received);
|
||||
}
|
||||
#else
|
||||
if (socket->ssl) {
|
||||
int received = SSL_read(socket->ssl, buffer, max_length);
|
||||
if (received <= 0) {
|
||||
delete[] buffer;
|
||||
int ssl_error = SSL_get_error(socket->ssl, received);
|
||||
if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) {
|
||||
return 0;
|
||||
}
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_read failed: %d", ssl_error);
|
||||
return -1;
|
||||
}
|
||||
*out_data = buffer;
|
||||
return received;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
int received = recv(socket->socket_fd, buffer, max_length, 0);
|
||||
if (received < 0) {
|
||||
#ifdef _WIN32
|
||||
int error = WSAGetLastError();
|
||||
if (error == WSAEWOULDBLOCK) {
|
||||
delete[] buffer;
|
||||
return 0;
|
||||
}
|
||||
// WSAECONNRESET (10054) = connection reset by peer
|
||||
// WSAECONNABORTED (10053) = connection aborted
|
||||
// WSAENETRESET (10052) = Network dropped connection
|
||||
// Treat all == disconnected
|
||||
if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
|
||||
delete[] buffer;
|
||||
socket->isConnected = false;
|
||||
return -2;
|
||||
}
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %d", error);
|
||||
#endif
|
||||
#else
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
delete[] buffer;
|
||||
return 0;
|
||||
}
|
||||
if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
|
||||
delete[] buffer;
|
||||
socket->isConnected = false;
|
||||
return -2;
|
||||
}
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %s", strerror(errno));
|
||||
#endif
|
||||
#endif
|
||||
delete[] buffer;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (received == 0) {
|
||||
delete[] buffer;
|
||||
socket->isConnected = false;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Connection closed by peer");
|
||||
#endif
|
||||
return -2; // Distinct value for connection closed
|
||||
}
|
||||
|
||||
*out_data = buffer;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Received %d bytes", received);
|
||||
#endif
|
||||
return received;
|
||||
}
|
||||
|
||||
extern "C" void runt_socket_close(int socket_id) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return;
|
||||
|
||||
it->second->close();
|
||||
g_sockets.erase(it);
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket closed: ID %d", socket_id);
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" bool runt_socket_is_connected(int socket_id) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return false;
|
||||
|
||||
auto& socket = it->second;
|
||||
return socket->isValid() && socket->isConnected;
|
||||
}
|
||||
|
||||
extern "C" void runt_socket_set_blocking(int socket_id, bool blocking) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid()) return;
|
||||
|
||||
#ifdef _WIN32
|
||||
u_long mode = blocking ? 0 : 1;
|
||||
ioctlsocket(socket->socket_fd, FIONBIO, &mode);
|
||||
#else
|
||||
int flags = fcntl(socket->socket_fd, F_GETFL, 0);
|
||||
if (blocking) {
|
||||
flags &= ~O_NONBLOCK;
|
||||
} else {
|
||||
flags |= O_NONBLOCK;
|
||||
}
|
||||
fcntl(socket->socket_fd, F_SETFL, flags);
|
||||
#endif
|
||||
|
||||
socket->isBlocking = blocking;
|
||||
#ifdef DEBUG_NETWORK
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "Socket blocking mode set: %s", blocking ? "blocking" : "non-blocking");
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" bool runt_socket_select(int socket_id, double timeout_sec) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return false;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid()) return false;
|
||||
|
||||
fd_set readfds;
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(socket->socket_fd, &readfds);
|
||||
|
||||
struct timeval tv;
|
||||
tv.tv_sec = (long)timeout_sec;
|
||||
tv.tv_usec = (long)((timeout_sec - tv.tv_sec) * 1000000);
|
||||
|
||||
int result = ::select((int)socket->socket_fd + 1, &readfds, nullptr, nullptr, &tv);
|
||||
|
||||
if (result > 0 && FD_ISSET(socket->socket_fd, &readfds)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef WITH_SSL
|
||||
extern "C" bool runt_socket_enable_ssl(int socket_id) {
|
||||
auto it = g_sockets.find(socket_id);
|
||||
if (it == g_sockets.end()) return false;
|
||||
|
||||
auto& socket = it->second;
|
||||
if (!socket->isValid() || !socket->isConnected) return false;
|
||||
|
||||
#ifdef _WIN32
|
||||
if (socket->useSSL) return true;
|
||||
|
||||
if (!socket->ssl_cred_handle_) {
|
||||
socket->ssl_cred_handle_ = new CredHandle();
|
||||
memset(socket->ssl_cred_handle_, 0, sizeof(CredHandle));
|
||||
|
||||
SCHANNEL_CRED credentials;
|
||||
memset(&credentials, 0, sizeof(credentials));
|
||||
credentials.dwVersion = SCHANNEL_CRED_VERSION;
|
||||
credentials.grbitEnabledProtocols = SP_PROT_TLS1_2;
|
||||
credentials.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
|
||||
|
||||
TimeStamp expiry;
|
||||
SECURITY_STATUS status = AcquireCredentialsHandleA(
|
||||
nullptr, const_cast<char*>(UNISP_NAME_A), SECPKG_CRED_OUTBOUND,
|
||||
nullptr, &credentials, nullptr, nullptr,
|
||||
reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_), &expiry);
|
||||
|
||||
if (status != SEC_E_OK) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "AcquireCredentialsHandleA failed: 0x%x", status);
|
||||
delete reinterpret_cast<CredHandle*>(socket->ssl_cred_handle_);
|
||||
socket->ssl_cred_handle_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL credentials acquired successfully");
|
||||
}
|
||||
|
||||
SecBufferDesc outbuffer_desc;
|
||||
SecBuffer outbuffers[1];
|
||||
DWORD context_attributes;
|
||||
TimeStamp expiry;
|
||||
|
||||
outbuffers[0].pvBuffer = nullptr;
|
||||
outbuffers[0].BufferType = SECBUFFER_TOKEN;
|
||||
outbuffers[0].cbBuffer = 0;
|
||||
outbuffer_desc.cBuffers = 1;
|
||||
outbuffer_desc.pBuffers = outbuffers;
|
||||
outbuffer_desc.ulVersion = SECBUFFER_VERSION;
|
||||
|
||||
socket->ssl_context_handle_ = new CtxtHandle();
|
||||
memset(socket->ssl_context_handle_, 0, sizeof(CtxtHandle));
|
||||
|
||||
|
||||
// TODO: Use actual hostname
|
||||
SECURITY_STATUS status = InitializeSecurityContextA(
|
||||
reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_),
|
||||
nullptr,
|
||||
const_cast<char*>("localhost"),
|
||||
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT |
|
||||
ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
|
||||
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
|
||||
0, SECURITY_NATIVE_DREP, nullptr, 0,
|
||||
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
||||
&outbuffer_desc, &context_attributes, &expiry);
|
||||
|
||||
if (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_OK) {
|
||||
// send handshake data if available
|
||||
if (outbuffers[0].cbBuffer > 0 && outbuffers[0].pvBuffer) {
|
||||
int sent = send(socket->socket_fd, reinterpret_cast<char*>(outbuffers[0].pvBuffer), outbuffers[0].cbBuffer, 0);
|
||||
if (sent != static_cast<int>(outbuffers[0].cbBuffer)) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send SSL handshake data");
|
||||
FreeContextBuffer(outbuffers[0].pvBuffer);
|
||||
return false;
|
||||
}
|
||||
FreeContextBuffer(outbuffers[0].pvBuffer);
|
||||
}
|
||||
|
||||
socket->ssl_context_initialized_ = true;
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL handshake initiated successfully");
|
||||
} else {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "InitializeSecurityContextA failed: 0x%x", status);
|
||||
delete reinterpret_cast<CtxtHandle*>(socket->ssl_context_handle_);
|
||||
socket->ssl_context_handle_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
if (socket->ssl) return true;
|
||||
socket->ssl_ctx = SSL_CTX_new(TLS_client_method());
|
||||
if (!socket->ssl_ctx) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_CTX_new failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
socket->ssl = SSL_new(socket->ssl_ctx);
|
||||
if (!socket->ssl) {
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_new failed");
|
||||
SSL_CTX_free(socket->ssl_ctx);
|
||||
socket->ssl_ctx = nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
SSL_set_fd(socket->ssl, socket->socket_fd);
|
||||
|
||||
int ret = SSL_connect(socket->ssl);
|
||||
if (ret <= 0) {
|
||||
int ssl_error = SSL_get_error(socket->ssl, ret);
|
||||
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_connect failed: %d", ssl_error);
|
||||
SSL_free(socket->ssl);
|
||||
SSL_CTX_free(socket->ssl_ctx);
|
||||
socket->ssl = nullptr;
|
||||
socket->ssl_ctx = nullptr;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
socket->useSSL = true;
|
||||
kinc_log(KINC_LOG_LEVEL_INFO, "SSL enabled successfully for socket %d", socket_id);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
extern "C" void runt_socket_cleanup() {
|
||||
cleanup_networking();
|
||||
}
|
||||
Reference in New Issue
Block a user