948 lines
30 KiB
C++
948 lines
30 KiB
C++
#include "socket_bridge.h"

#include <kinc/log.h>

#include <cstdio>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "wsock32.lib")
#ifdef WITH_SSL
#ifdef _WIN32
#define SECURITY_WIN32
#include <wincrypt.h>
#include <schannel.h>
#include <security.h>
#include <sspi.h>
#pragma comment(lib, "crypt32.lib")
#pragma comment(lib, "secur32.lib")
#endif
#endif
#else
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#ifdef WITH_SSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#endif
#endif
|
|
|
|
// One bridge-level socket: the OS socket handle, bookkeeping flags and, when
// WITH_SSL is defined, the TLS session layered on top of it.
//
// The class owns raw OS/TLS resources, so copying is disabled: a copy would
// close the descriptor and free the TLS objects twice.
class RunTSocket {
public:
    int id;            // handle exposed through the extern "C" API
    bool isBlocking;   // mode last applied via runt_socket_set_blocking
    bool isConnected;
    bool isBound;
    bool isListening;
    std::string lastError;

#ifdef _WIN32
    SOCKET socket_fd;
#else
    int socket_fd;
#endif

#ifdef WITH_SSL
#ifdef _WIN32
    // Windows SChannel state behind opaque pointers. When non-null they point
    // to heap-allocated CredHandle / CtxtHandle objects.
    void* ssl_cred_handle_;
    void* ssl_context_handle_;
    bool ssl_context_initialized_;
    std::vector<char> ssl_buffer_;  // received ciphertext not yet decrypted
#else
    SSL* ssl;
    SSL_CTX* ssl_ctx;
#endif
    bool useSSL;
#endif

    explicit RunTSocket(int socket_id)
        : id(socket_id), isBlocking(true), isConnected(false),
          isBound(false), isListening(false) {
#ifdef _WIN32
        socket_fd = INVALID_SOCKET;
#else
        socket_fd = -1;
#endif

#ifdef WITH_SSL
#ifdef _WIN32
        ssl_cred_handle_ = nullptr;
        ssl_context_handle_ = nullptr;
        ssl_context_initialized_ = false;
#else
        ssl = nullptr;
        ssl_ctx = nullptr;
#endif
        useSSL = false;
#endif
    }

    RunTSocket(const RunTSocket&) = delete;
    RunTSocket& operator=(const RunTSocket&) = delete;

    ~RunTSocket() {
        close();
    }

    // Releases any TLS state, then shuts down and closes the OS socket.
    // Idempotent: on an already-closed socket it only resets the state flags.
    void close() {
#ifdef WITH_SSL
        releaseTls();
#endif
        if (isValid()) {
#ifdef _WIN32
            shutdown(socket_fd, SD_SEND);
            closesocket(socket_fd);
            socket_fd = INVALID_SOCKET;
#else
            shutdown(socket_fd, SHUT_WR);
            ::close(socket_fd);
            socket_fd = -1;
#endif
        }

        isConnected = false;
        isBound = false;
        isListening = false;
    }

    bool isValid() const {
#ifdef _WIN32
        return socket_fd != INVALID_SOCKET;
#else
        return socket_fd >= 0;
#endif
    }

#ifdef WITH_SSL
private:
    // Frees every TLS resource this socket holds and clears useSSL.
    // Windows: the SChannel context and credentials are released through SSPI
    // before their heap storage is deleted.
    // OpenSSL: close_notify is attempted only while the descriptor is still
    // open, so it cannot be written to a reused fd number.
    void releaseTls() {
#ifdef _WIN32
        if (ssl_context_handle_) {
            CtxtHandle* context = static_cast<CtxtHandle*>(ssl_context_handle_);
            DeleteSecurityContext(context);
            delete context;
            ssl_context_handle_ = nullptr;
        }
        if (ssl_cred_handle_) {
            CredHandle* credentials = static_cast<CredHandle*>(ssl_cred_handle_);
            FreeCredentialsHandle(credentials);
            delete credentials;
            ssl_cred_handle_ = nullptr;
        }
        ssl_context_initialized_ = false;
        ssl_buffer_.clear();
#else
        if (ssl) {
            if (isValid()) {
                SSL_shutdown(ssl);
            }
            SSL_free(ssl);
            ssl = nullptr;
        }
        if (ssl_ctx) {
            SSL_CTX_free(ssl_ctx);
            ssl_ctx = nullptr;
        }
#endif
        useSSL = false;
    }
#endif
};
|
|
|
|
// global socket
|
|
static std::map<int, std::unique_ptr<RunTSocket>> g_sockets;
|
|
static int g_next_socket_id = 1;
|
|
static bool g_winsock_initialized = false;
|
|
|
|
// One-time, process-wide network setup: WSAStartup on Windows and, when
// WITH_SSL is defined, TLS library initialisation. Returns true when the stack
// is ready, including when an earlier call already initialised it.
static bool initialize_networking() {
    if (!g_winsock_initialized) {
#ifdef _WIN32
        WSADATA wsa_data;
        const int startup_rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
        if (startup_rc != 0) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "WSAStartup failed: %d", startup_rc);
            return false;
        }
#endif

#ifdef WITH_SSL
#ifdef _WIN32
        // SChannel needs no global initialisation.
        kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support initialized");
#else
        SSL_library_init();
        SSL_load_error_strings();
        OpenSSL_add_all_algorithms();
        kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support initialized");
#endif
#endif

        g_winsock_initialized = true;
        kinc_log(KINC_LOG_LEVEL_INFO, "Networking initialized successfully");
    }
    return true;
}
|
|
|
|
// Reverses initialize_networking(). It destroys every registered socket, whose
// destructors close the OS handles, then tears down TLS support and, on
// Windows, calls WSACleanup. It is a no-op if networking is not initialised.
static void cleanup_networking() {
    if (g_winsock_initialized) {
        g_sockets.clear();

#ifdef WITH_SSL
#ifdef _WIN32
        kinc_log(KINC_LOG_LEVEL_INFO, "Windows SChannel SSL support cleaned up");
#else
        EVP_cleanup();
        kinc_log(KINC_LOG_LEVEL_INFO, "OpenSSL support cleaned up");
#endif
#endif

#ifdef _WIN32
        WSACleanup();
#endif

        g_winsock_initialized = false;
        kinc_log(KINC_LOG_LEVEL_INFO, "Networking cleaned up");
    }
}
|
|
|
|
// C++ Bridge Function Implementations
|
|
|
|
extern "C" int runt_socket_create() {
|
|
if (!initialize_networking()) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to initialize networking");
|
|
return -1;
|
|
}
|
|
|
|
int socket_id = g_next_socket_id++;
|
|
auto socket = std::make_unique<RunTSocket>(socket_id);
|
|
|
|
#ifdef _WIN32
|
|
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
|
if (socket->socket_fd == INVALID_SOCKET) {
|
|
int error = WSAGetLastError();
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %d", error);
|
|
return -1;
|
|
}
|
|
#else
|
|
socket->socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
|
|
if (socket->socket_fd < 0) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket creation failed: %s", strerror(errno));
|
|
return -1;
|
|
}
|
|
#endif
|
|
|
|
int nodelay = 1;
|
|
setsockopt(socket->socket_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
|
|
|
|
int reuse = 1;
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse));
|
|
|
|
#if defined(__linux__) && defined(SO_REUSEPORT)
|
|
int reuseport = 1;
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_REUSEPORT, (const char*)&reuseport, sizeof(reuseport));
|
|
#endif
|
|
|
|
int keepalive = 1;
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
|
|
|
|
// uWebsockets uses 1mb socket buffer sizes for high throughput
|
|
int sndbuf = 1024 * 1024;
|
|
int rcvbuf = 1024 * 1024;
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
|
|
|
|
#ifdef __APPLE__
|
|
// prevent SIGPIPE crashes on macOS
|
|
int no_sigpipe = 1;
|
|
setsockopt(socket->socket_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
|
|
#endif
|
|
|
|
g_sockets[socket_id] = std::move(socket);
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Socket created successfully: ID %d", socket_id);
|
|
#endif
|
|
return socket_id;
|
|
}
|
|
|
|
extern "C" bool runt_socket_bind(int socket_id, const char* address, int port) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
|
|
return false;
|
|
}
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid()) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for bind: %d", socket_id);
|
|
return false;
|
|
}
|
|
|
|
sockaddr_in addr;
|
|
memset(&addr, 0, sizeof(addr));
|
|
addr.sin_family = AF_INET;
|
|
addr.sin_port = htons(port);
|
|
|
|
if (address && strlen(address) > 0) {
|
|
if (inet_pton(AF_INET, address, &addr.sin_addr) <= 0) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid address: %s", address);
|
|
return false;
|
|
}
|
|
} else {
|
|
addr.sin_addr.s_addr = INADDR_ANY;
|
|
}
|
|
|
|
if (bind(socket->socket_fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
|
|
#ifdef _WIN32
|
|
int error = WSAGetLastError();
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %d", error);
|
|
#else
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Bind failed: %s", strerror(errno));
|
|
#endif
|
|
return false;
|
|
}
|
|
|
|
socket->isBound = true;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Socket bound successfully: %s:%d", address ? address : "0.0.0.0", port);
|
|
#endif
|
|
return true;
|
|
}
|
|
|
|
extern "C" bool runt_socket_listen(int socket_id, int backlog) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return false;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid() || !socket->isBound) return false;
|
|
|
|
if (listen(socket->socket_fd, backlog) < 0) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Listen failed");
|
|
return false;
|
|
}
|
|
|
|
socket->isListening = true;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Socket listening: ID %d", socket_id);
|
|
#endif
|
|
return true;
|
|
}
|
|
|
|
extern "C" int runt_socket_accept(int socket_id) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return -1;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid() || !socket->isListening) return -1;
|
|
|
|
sockaddr_in client_addr;
|
|
socklen_t client_len = sizeof(client_addr);
|
|
|
|
#ifdef _WIN32
|
|
SOCKET client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
|
|
if (client_fd == INVALID_SOCKET) return -1;
|
|
#else
|
|
int client_fd = accept(socket->socket_fd, (sockaddr*)&client_addr, &client_len);
|
|
if (client_fd < 0) return -1;
|
|
#endif
|
|
|
|
int client_id = g_next_socket_id++;
|
|
auto client_socket = std::make_unique<RunTSocket>(client_id);
|
|
client_socket->socket_fd = client_fd;
|
|
client_socket->isConnected = true;
|
|
|
|
int nodelay = 1;
|
|
setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&nodelay, sizeof(nodelay));
|
|
|
|
int keepalive = 1;
|
|
setsockopt(client_fd, SOL_SOCKET, SO_KEEPALIVE, (const char*)&keepalive, sizeof(keepalive));
|
|
|
|
int sndbuf = 1024 * 1024;
|
|
int rcvbuf = 1024 * 1024;
|
|
setsockopt(client_fd, SOL_SOCKET, SO_SNDBUF, (const char*)&sndbuf, sizeof(sndbuf));
|
|
setsockopt(client_fd, SOL_SOCKET, SO_RCVBUF, (const char*)&rcvbuf, sizeof(rcvbuf));
|
|
|
|
#ifdef __APPLE__
|
|
int no_sigpipe = 1;
|
|
setsockopt(client_fd, SOL_SOCKET, SO_NOSIGPIPE, (const char*)&no_sigpipe, sizeof(no_sigpipe));
|
|
#endif
|
|
|
|
g_sockets[client_id] = std::move(client_socket);
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Socket accepted: ID %d", client_id);
|
|
#endif
|
|
return client_id;
|
|
}
|
|
|
|
// Resolves `hostname` (IPv4 only) and connects the socket to the first
// resolved address that accepts. `port` must be in 1..65535. Returns true once
// connected.
//
// NOTE(review): POSIX leaves a socket's state unspecified after a failed
// connect(), so retrying the next resolved address on the same descriptor is
// not portable; a robust version would recreate the socket per attempt. On a
// non-blocking socket connect() reports EINPROGRESS/WSAEWOULDBLOCK, which is
// treated as failure here.
extern "C" bool runt_socket_connect(int socket_id, const char* hostname, int port) {
    auto it = g_sockets.find(socket_id);
    if (it == g_sockets.end()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid socket ID: %d", socket_id);
        return false;
    }

    auto& socket = it->second;
    if (!socket->isValid()) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Socket not valid for connect: %d", socket_id);
        return false;
    }

    // hostname is formatted with %s below; passing null there is undefined behaviour.
    if (hostname == nullptr) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed: hostname is null (socket %d)", socket_id);
        return false;
    }
    if (port <= 0 || port > 65535) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid port for connect: %d", port);
        return false;
    }

#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Connecting to %s:%d", hostname, port);
#endif

    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    char port_str[16];
    snprintf(port_str, sizeof(port_str), "%d", port);

    struct addrinfo* result = nullptr;
    const int status = getaddrinfo(hostname, port_str, &hints, &result);
    if (status != 0) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "getaddrinfo failed for %s: %s", hostname, gai_strerror(status));
        return false;
    }

    bool connected = false;
    // Captured right after each failed connect(), before freeaddrinfo() or
    // anything else can overwrite errno / the WSA error.
    int last_error = 0;
    for (struct addrinfo* rp = result; rp != nullptr && !connected; rp = rp->ai_next) {
        if (connect(socket->socket_fd, rp->ai_addr, static_cast<int>(rp->ai_addrlen)) == 0) {
            connected = true;
        } else {
#ifdef _WIN32
            last_error = WSAGetLastError();
#else
            last_error = errno;
#endif
        }
    }

    freeaddrinfo(result);

    if (!connected) {
#ifdef _WIN32
        kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %d", hostname, port, last_error);
#else
        kinc_log(KINC_LOG_LEVEL_ERROR, "Connect failed to %s:%d - Error: %s", hostname, port, strerror(last_error));
#endif
        return false;
    }

    socket->isConnected = true;
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Connected successfully to %s:%d", hostname, port);
#endif
    return true;
}
|
|
|
|
extern "C" int runt_socket_send(int socket_id, const char* data, int length) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return -1;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid() || !socket->isConnected) return -1;
|
|
|
|
#ifdef WITH_SSL
|
|
if (socket->useSSL) {
|
|
#ifdef _WIN32
|
|
// send implementation adapted from httprequest.cpp
|
|
if (!socket->ssl_context_initialized_) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
|
|
return -1;
|
|
}
|
|
|
|
SecPkgContext_StreamSizes stream_sizes;
|
|
SECURITY_STATUS status = QueryContextAttributesA(
|
|
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
|
SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
|
|
|
|
if (status != SEC_E_OK) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to get stream sizes: 0x%x", status);
|
|
return -1;
|
|
}
|
|
|
|
int total_size = stream_sizes.cbHeader + length + stream_sizes.cbTrailer;
|
|
char* encrypt_buffer = new char[total_size];
|
|
|
|
SecBufferDesc message_desc;
|
|
SecBuffer message_buffers[4];
|
|
|
|
message_desc.ulVersion = SECBUFFER_VERSION;
|
|
message_desc.cBuffers = 4;
|
|
message_desc.pBuffers = message_buffers;
|
|
|
|
message_buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
|
|
message_buffers[0].pvBuffer = encrypt_buffer;
|
|
message_buffers[0].cbBuffer = stream_sizes.cbHeader;
|
|
|
|
message_buffers[1].BufferType = SECBUFFER_DATA;
|
|
message_buffers[1].pvBuffer = encrypt_buffer + stream_sizes.cbHeader;
|
|
message_buffers[1].cbBuffer = length;
|
|
memcpy(message_buffers[1].pvBuffer, data, length);
|
|
|
|
message_buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
|
|
message_buffers[2].pvBuffer = encrypt_buffer + stream_sizes.cbHeader + length;
|
|
message_buffers[2].cbBuffer = stream_sizes.cbTrailer;
|
|
|
|
message_buffers[3].BufferType = SECBUFFER_EMPTY;
|
|
message_buffers[3].pvBuffer = nullptr;
|
|
message_buffers[3].cbBuffer = 0;
|
|
|
|
status = EncryptMessage(
|
|
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
|
0, &message_desc, 0);
|
|
|
|
if (status != SEC_E_OK) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel encrypt failed: 0x%x", status);
|
|
delete[] encrypt_buffer;
|
|
return -1;
|
|
}
|
|
|
|
int total_encrypted = message_buffers[0].cbBuffer + message_buffers[1].cbBuffer + message_buffers[2].cbBuffer;
|
|
int sent = send(socket->socket_fd, encrypt_buffer, total_encrypted, 0);
|
|
|
|
delete[] encrypt_buffer;
|
|
|
|
if (sent != total_encrypted) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send encrypted SSL data");
|
|
return -1;
|
|
}
|
|
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL sent %d bytes (encrypted: %d bytes)", length, total_encrypted);
|
|
#endif
|
|
return length;
|
|
#else
|
|
if (socket->ssl) {
|
|
int sent = SSL_write(socket->ssl, data, length);
|
|
if (sent <= 0) {
|
|
int ssl_error = SSL_get_error(socket->ssl, sent);
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_write failed: %d", ssl_error);
|
|
return -1;
|
|
}
|
|
return sent;
|
|
}
|
|
#endif
|
|
}
|
|
#endif
|
|
|
|
#ifdef _WIN32
|
|
int sent = send(socket->socket_fd, data, length, 0);
|
|
#else
|
|
int sent = send(socket->socket_fd, data, length, MSG_NOSIGNAL);
|
|
#endif
|
|
if (sent < 0) {
|
|
#ifdef _WIN32
|
|
int error = WSAGetLastError();
|
|
if (error == WSAEWOULDBLOCK) {
|
|
return 0;
|
|
}
|
|
// connection reset/aborted we mark as disconnected
|
|
if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
|
|
socket->isConnected = false;
|
|
}
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %d", error);
|
|
#endif
|
|
#else
|
|
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
|
return 0;
|
|
}
|
|
if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
|
|
socket->isConnected = false;
|
|
}
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Send failed: %s", strerror(errno));
|
|
#endif
|
|
#endif
|
|
return -1;
|
|
}
|
|
|
|
#ifdef DEBUG_NETWORK
|
|
if (sent > 10) {
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Sent %d bytes", sent);
|
|
}
|
|
#endif
|
|
return sent;
|
|
}
|
|
|
|
extern "C" int runt_socket_recv(int socket_id, int max_length, char** out_data) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return -1;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid() || !socket->isConnected) return -1;
|
|
|
|
char* buffer = new char[max_length];
|
|
|
|
#ifdef WITH_SSL
|
|
if (socket->useSSL) {
|
|
#ifdef _WIN32
|
|
if (!socket->ssl_context_initialized_) {
|
|
delete[] buffer;
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL context not initialized");
|
|
return -1;
|
|
}
|
|
|
|
while (true) {
|
|
// decrypt for buffered data
|
|
if (!socket->ssl_buffer_.empty()) {
|
|
SecBufferDesc message_desc;
|
|
SecBuffer message_buffers[4];
|
|
|
|
// buffer descriptor for decryption
|
|
message_desc.ulVersion = SECBUFFER_VERSION;
|
|
message_desc.cBuffers = 4;
|
|
message_desc.pBuffers = message_buffers;
|
|
|
|
// input buffer with encrypted data
|
|
message_buffers[0].BufferType = SECBUFFER_DATA;
|
|
message_buffers[0].pvBuffer = socket->ssl_buffer_.data();
|
|
message_buffers[0].cbBuffer = static_cast<unsigned long>(socket->ssl_buffer_.size());
|
|
|
|
// output buffer
|
|
message_buffers[1].BufferType = SECBUFFER_EMPTY;
|
|
message_buffers[1].pvBuffer = nullptr;
|
|
message_buffers[1].cbBuffer = 0;
|
|
|
|
// extra buffer for leftover data
|
|
message_buffers[2].BufferType = SECBUFFER_EMPTY;
|
|
message_buffers[2].pvBuffer = nullptr;
|
|
message_buffers[2].cbBuffer = 0;
|
|
|
|
// stream buffer
|
|
message_buffers[3].BufferType = SECBUFFER_EMPTY;
|
|
message_buffers[3].pvBuffer = nullptr;
|
|
message_buffers[3].cbBuffer = 0;
|
|
|
|
SECURITY_STATUS status = DecryptMessage(
|
|
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
|
&message_desc, 0, nullptr);
|
|
|
|
if (status == SEC_E_OK) {
|
|
SecBuffer* data_buffer = nullptr;
|
|
SecBuffer* extra_buffer = nullptr;
|
|
|
|
// find data and extra buffers
|
|
for (int i = 0; i < 4; i++) {
|
|
if (message_buffers[i].BufferType == SECBUFFER_DATA) {
|
|
data_buffer = &message_buffers[i];
|
|
} else if (message_buffers[i].BufferType == SECBUFFER_EXTRA) {
|
|
extra_buffer = &message_buffers[i];
|
|
}
|
|
}
|
|
|
|
if (data_buffer && data_buffer->cbBuffer > 0) {
|
|
int bytes_to_copy = (max_length < static_cast<int>(data_buffer->cbBuffer)) ? max_length : static_cast<int>(data_buffer->cbBuffer);
|
|
memcpy(buffer, data_buffer->pvBuffer, bytes_to_copy);
|
|
|
|
// move extra data to the beginning of the buffer
|
|
if (extra_buffer && extra_buffer->cbBuffer > 0) {
|
|
std::memmove(socket->ssl_buffer_.data(), extra_buffer->pvBuffer, extra_buffer->cbBuffer);
|
|
socket->ssl_buffer_.resize(extra_buffer->cbBuffer);
|
|
} else {
|
|
socket->ssl_buffer_.clear();
|
|
}
|
|
|
|
*out_data = buffer;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL received %d bytes (decrypted)", bytes_to_copy);
|
|
#endif
|
|
return bytes_to_copy;
|
|
}
|
|
|
|
socket->ssl_buffer_.clear();
|
|
|
|
} else if (status == SEC_E_INCOMPLETE_MESSAGE) {
|
|
// needs more from the socket
|
|
char recv_buffer[8192];
|
|
int received = recv(socket->socket_fd, recv_buffer, sizeof(recv_buffer), 0);
|
|
if (received <= 0) {
|
|
delete[] buffer;
|
|
if (received == 0) {
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL connection closed by server");
|
|
#endif
|
|
return 0;
|
|
} else {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data for SSL read");
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
socket->ssl_buffer_.insert(socket->ssl_buffer_.end(), recv_buffer, recv_buffer + received);
|
|
continue;
|
|
|
|
} else {
|
|
delete[] buffer;
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel decrypt failed: 0x%x", status);
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
char recv_buffer[8192];
|
|
int received = recv(socket->socket_fd, recv_buffer, sizeof(recv_buffer), 0);
|
|
if (received <= 0) {
|
|
delete[] buffer;
|
|
if (received == 0) {
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL connection closed by server (no buffered data)");
|
|
#endif
|
|
return 0;
|
|
} else {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data");
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
socket->ssl_buffer_.insert(socket->ssl_buffer_.end(), recv_buffer, recv_buffer + received);
|
|
}
|
|
#else
|
|
if (socket->ssl) {
|
|
int received = SSL_read(socket->ssl, buffer, max_length);
|
|
if (received <= 0) {
|
|
delete[] buffer;
|
|
int ssl_error = SSL_get_error(socket->ssl, received);
|
|
if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) {
|
|
return 0;
|
|
}
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_read failed: %d", ssl_error);
|
|
return -1;
|
|
}
|
|
*out_data = buffer;
|
|
return received;
|
|
}
|
|
#endif
|
|
}
|
|
#endif
|
|
|
|
int received = recv(socket->socket_fd, buffer, max_length, 0);
|
|
if (received < 0) {
|
|
#ifdef _WIN32
|
|
int error = WSAGetLastError();
|
|
if (error == WSAEWOULDBLOCK) {
|
|
delete[] buffer;
|
|
return 0;
|
|
}
|
|
// WSAECONNRESET (10054) = connection reset by peer
|
|
// WSAECONNABORTED (10053) = connection aborted
|
|
// WSAENETRESET (10052) = Network dropped connection
|
|
// Treat all == disconnected
|
|
if (error == WSAECONNRESET || error == WSAECONNABORTED || error == WSAENETRESET) {
|
|
delete[] buffer;
|
|
socket->isConnected = false;
|
|
return -2;
|
|
}
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %d", error);
|
|
#endif
|
|
#else
|
|
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
|
delete[] buffer;
|
|
return 0;
|
|
}
|
|
if (errno == ECONNRESET || errno == ENOTCONN || errno == EPIPE) {
|
|
delete[] buffer;
|
|
socket->isConnected = false;
|
|
return -2;
|
|
}
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Recv failed: %s", strerror(errno));
|
|
#endif
|
|
#endif
|
|
delete[] buffer;
|
|
return -1;
|
|
}
|
|
|
|
if (received == 0) {
|
|
delete[] buffer;
|
|
socket->isConnected = false;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Connection closed by peer");
|
|
#endif
|
|
return -2; // Distinct value for connection closed
|
|
}
|
|
|
|
*out_data = buffer;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Received %d bytes", received);
|
|
#endif
|
|
return received;
|
|
}
|
|
|
|
// Closes the socket behind `socket_id` and removes it from the registry.
// Unknown ids are ignored.
extern "C" void runt_socket_close(int socket_id) {
    const auto entry = g_sockets.find(socket_id);
    if (entry != g_sockets.end()) {
        entry->second->close();
        g_sockets.erase(entry);
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "Socket closed: ID %d", socket_id);
#endif
    }
}
|
|
|
|
extern "C" bool runt_socket_is_connected(int socket_id) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return false;
|
|
|
|
auto& socket = it->second;
|
|
return socket->isValid() && socket->isConnected;
|
|
}
|
|
|
|
extern "C" void runt_socket_set_blocking(int socket_id, bool blocking) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid()) return;
|
|
|
|
#ifdef _WIN32
|
|
u_long mode = blocking ? 0 : 1;
|
|
ioctlsocket(socket->socket_fd, FIONBIO, &mode);
|
|
#else
|
|
int flags = fcntl(socket->socket_fd, F_GETFL, 0);
|
|
if (blocking) {
|
|
flags &= ~O_NONBLOCK;
|
|
} else {
|
|
flags |= O_NONBLOCK;
|
|
}
|
|
fcntl(socket->socket_fd, F_SETFL, flags);
|
|
#endif
|
|
|
|
socket->isBlocking = blocking;
|
|
#ifdef DEBUG_NETWORK
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "Socket blocking mode set: %s", blocking ? "blocking" : "non-blocking");
|
|
#endif
|
|
}
|
|
|
|
extern "C" bool runt_socket_select(int socket_id, double timeout_sec) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return false;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid()) return false;
|
|
|
|
fd_set readfds;
|
|
FD_ZERO(&readfds);
|
|
FD_SET(socket->socket_fd, &readfds);
|
|
|
|
struct timeval tv;
|
|
tv.tv_sec = (long)timeout_sec;
|
|
tv.tv_usec = (long)((timeout_sec - tv.tv_sec) * 1000000);
|
|
|
|
int result = ::select((int)socket->socket_fd + 1, &readfds, nullptr, nullptr, &tv);
|
|
|
|
if (result > 0 && FD_ISSET(socket->socket_fd, &readfds)) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
#ifdef WITH_SSL
|
|
extern "C" bool runt_socket_enable_ssl(int socket_id) {
|
|
auto it = g_sockets.find(socket_id);
|
|
if (it == g_sockets.end()) return false;
|
|
|
|
auto& socket = it->second;
|
|
if (!socket->isValid() || !socket->isConnected) return false;
|
|
|
|
#ifdef _WIN32
|
|
if (socket->useSSL) return true;
|
|
|
|
if (!socket->ssl_cred_handle_) {
|
|
socket->ssl_cred_handle_ = new CredHandle();
|
|
memset(socket->ssl_cred_handle_, 0, sizeof(CredHandle));
|
|
|
|
SCHANNEL_CRED credentials;
|
|
memset(&credentials, 0, sizeof(credentials));
|
|
credentials.dwVersion = SCHANNEL_CRED_VERSION;
|
|
credentials.grbitEnabledProtocols = SP_PROT_TLS1_2;
|
|
credentials.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
|
|
|
|
TimeStamp expiry;
|
|
SECURITY_STATUS status = AcquireCredentialsHandleA(
|
|
nullptr, const_cast<char*>(UNISP_NAME_A), SECPKG_CRED_OUTBOUND,
|
|
nullptr, &credentials, nullptr, nullptr,
|
|
reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_), &expiry);
|
|
|
|
if (status != SEC_E_OK) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "AcquireCredentialsHandleA failed: 0x%x", status);
|
|
delete reinterpret_cast<CredHandle*>(socket->ssl_cred_handle_);
|
|
socket->ssl_cred_handle_ = nullptr;
|
|
return false;
|
|
}
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL credentials acquired successfully");
|
|
}
|
|
|
|
SecBufferDesc outbuffer_desc;
|
|
SecBuffer outbuffers[1];
|
|
DWORD context_attributes;
|
|
TimeStamp expiry;
|
|
|
|
outbuffers[0].pvBuffer = nullptr;
|
|
outbuffers[0].BufferType = SECBUFFER_TOKEN;
|
|
outbuffers[0].cbBuffer = 0;
|
|
outbuffer_desc.cBuffers = 1;
|
|
outbuffer_desc.pBuffers = outbuffers;
|
|
outbuffer_desc.ulVersion = SECBUFFER_VERSION;
|
|
|
|
socket->ssl_context_handle_ = new CtxtHandle();
|
|
memset(socket->ssl_context_handle_, 0, sizeof(CtxtHandle));
|
|
|
|
|
|
// TODO: Use actual hostname
|
|
SECURITY_STATUS status = InitializeSecurityContextA(
|
|
reinterpret_cast<PCredHandle>(socket->ssl_cred_handle_),
|
|
nullptr,
|
|
const_cast<char*>("localhost"),
|
|
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT |
|
|
ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
|
|
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
|
|
0, SECURITY_NATIVE_DREP, nullptr, 0,
|
|
reinterpret_cast<PCtxtHandle>(socket->ssl_context_handle_),
|
|
&outbuffer_desc, &context_attributes, &expiry);
|
|
|
|
if (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_OK) {
|
|
// send handshake data if available
|
|
if (outbuffers[0].cbBuffer > 0 && outbuffers[0].pvBuffer) {
|
|
int sent = send(socket->socket_fd, reinterpret_cast<char*>(outbuffers[0].pvBuffer), outbuffers[0].cbBuffer, 0);
|
|
if (sent != static_cast<int>(outbuffers[0].cbBuffer)) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send SSL handshake data");
|
|
FreeContextBuffer(outbuffers[0].pvBuffer);
|
|
return false;
|
|
}
|
|
FreeContextBuffer(outbuffers[0].pvBuffer);
|
|
}
|
|
|
|
socket->ssl_context_initialized_ = true;
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL handshake initiated successfully");
|
|
} else {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "InitializeSecurityContextA failed: 0x%x", status);
|
|
delete reinterpret_cast<CtxtHandle*>(socket->ssl_context_handle_);
|
|
socket->ssl_context_handle_ = nullptr;
|
|
return false;
|
|
}
|
|
#else
|
|
if (socket->ssl) return true;
|
|
socket->ssl_ctx = SSL_CTX_new(TLS_client_method());
|
|
if (!socket->ssl_ctx) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_CTX_new failed");
|
|
return false;
|
|
}
|
|
|
|
socket->ssl = SSL_new(socket->ssl_ctx);
|
|
if (!socket->ssl) {
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_new failed");
|
|
SSL_CTX_free(socket->ssl_ctx);
|
|
socket->ssl_ctx = nullptr;
|
|
return false;
|
|
}
|
|
|
|
SSL_set_fd(socket->ssl, socket->socket_fd);
|
|
|
|
int ret = SSL_connect(socket->ssl);
|
|
if (ret <= 0) {
|
|
int ssl_error = SSL_get_error(socket->ssl, ret);
|
|
kinc_log(KINC_LOG_LEVEL_ERROR, "SSL_connect failed: %d", ssl_error);
|
|
SSL_free(socket->ssl);
|
|
SSL_CTX_free(socket->ssl_ctx);
|
|
socket->ssl = nullptr;
|
|
socket->ssl_ctx = nullptr;
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
socket->useSSL = true;
|
|
kinc_log(KINC_LOG_LEVEL_INFO, "SSL enabled successfully for socket %d", socket_id);
|
|
return true;
|
|
}
|
|
#endif
|
|
|
|
extern "C" void runt_socket_cleanup() {
|
|
cleanup_networking();
|
|
}
|