Files
LNXRNT/Sources/websocket.cpp

2865 lines
120 KiB
C++
Raw Normal View History

2026-02-20 23:40:15 -08:00
#ifdef WITH_NETWORKING
#include "websocket.h"
#include "socket_optimization.h"
#include <kinc/log.h>
#include <string.h>
#include <emmintrin.h> // SSE2
#include <unordered_map>
#include <thread>
#include <atomic>
#include <regex>
#include <algorithm>
#ifdef _WIN32
#define NOMINMAX
#include <winsock2.h>
#include <ws2tcpip.h>
#include <windows.h>
#pragma comment(lib, "ws2_32.lib")
#else
#define SOCKET int
#define INVALID_SOCKET -1
#define SOCKET_ERROR -1
#define closesocket close
#endif
#include <vector>
#include <chrono>
#include <queue>
#include <mutex>
#include "ring_buffer.h"
#ifdef WITH_SSL
#ifdef _WIN32
#define SECURITY_WIN32
#include <wincrypt.h>
#include <schannel.h>
#include <security.h>
#include <sspi.h>
#pragma comment(lib, "crypt32.lib")
#pragma comment(lib, "secur32.lib")
#else
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <openssl/pem.h>
#include <openssl/sha.h>
#include <openssl/rand.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#endif
#endif
using namespace v8;
namespace WebSocketWrapper {
// global state
// Live clients keyed by the id returned from createWebSocketConnection(); the map owns each client.
// NOTE(review): nothing visible here locks this map — presumably it is only touched from one thread; confirm.
static std::unordered_map<int, std::unique_ptr<WebSocketClient>> active_websockets;
// Source of connection ids; starts at 1 and is incremented for every new connection.
static std::atomic<int> next_websocket_id{1};
// Flag flipped by initialize() and cleanup() so that the start-up / shutdown work runs once per cycle.
static std::atomic<bool> winsock_initialized{false};
// Performs the one-time networking start-up for the WebSocket layer.
// On Windows this calls WSAStartup requesting Winsock 2.2; on other platforms no
// process-wide socket initialisation is required, so only the flag is set.
// Safe to call repeatedly: only the call that flips winsock_initialized does any work.
void initialize() {
    if (winsock_initialized.exchange(true)) {
        return; // already initialised by an earlier call
    }
#ifdef _WIN32
    WSADATA wsaData;
    const int result = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (result != 0) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WSAStartup failed: %d", result);
        // Start-up did not happen: clear the flag so cleanup() does not issue an
        // unmatched WSACleanup and a later initialize() call can retry.
        winsock_initialized = false;
        return;
    }
#endif
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket client support initialized successfully");
#endif
}
// Destroys every registered WebSocket client and, if initialize() had run,
// undoes the platform start-up (WSACleanup on Windows; nothing to undo elsewhere).
void cleanup() {
    // Dropping the map entries destroys the owned clients.
    active_websockets.clear();
    if (winsock_initialized.exchange(false)) {
#ifdef _WIN32
        WSACleanup();
#endif
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket cleanup complete");
#endif
    }
}
void processEvents() {
try {
for (auto& [id, client] : active_websockets) {
if (client) {
client->processEvents();
}
}
} catch (...) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CRASH in global processEvents");
#endif
}
}
// Creates a WebSocketClient for `url`, registers it in active_websockets and
// returns the id under which it was stored.
int createWebSocketConnection(v8::Isolate* isolate, const std::string& url) {
    const int id = next_websocket_id.fetch_add(1);
    // The right-hand side is evaluated first (C++17), so the client is fully
    // constructed before the map slot for `id` is created.
    active_websockets[id] = std::make_unique<WebSocketClient>(isolate, getGlobalContext(), url);
    return id;
}
// Minimal blocking WebSocket client over a plain TCP socket (no TLS).
// connect() performs the TCP connect and the HTTP upgrade; sendText() writes a
// single masked text frame. The handshake uses a fixed Sec-WebSocket-Key and the
// frames use a fixed masking key, so this class is only suitable for simple/testing use.
class SimpleWebSocketClient {
private:
    SOCKET socket_;
    std::thread* thread_;
    std::atomic<bool> connected_{false};
    std::atomic<bool> should_stop_{false};
    std::string host_;
    int port_;
    std::string path_;

public:
    // Parses `url` into host/port/path; no network activity happens here.
    SimpleWebSocketClient(const std::string& url) : socket_(INVALID_SOCKET), thread_(nullptr) {
        parseUrl(url);
    }

    ~SimpleWebSocketClient() {
        disconnect();
    }

    // Opens the TCP connection and performs the WebSocket upgrade.
    // Returns true only when the server answered with "101 Switching Protocols";
    // on every failure the socket is closed again and false is returned.
    bool connect() {
        socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
        if (socket_ == INVALID_SOCKET) {
            return false;
        }
        SocketOptimization::optimizeWebSocket(static_cast<int>(socket_));

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_port = htons(static_cast<unsigned short>(port_));
        if (inet_pton(AF_INET, host_.c_str(), &addr.sin_addr) != 1) {
            // host_ is not a dotted-quad literal: resolve it instead of connecting
            // to the zeroed address that a failed inet_pton leaves behind.
            addrinfo hints = {};
            hints.ai_family = AF_INET;
            hints.ai_socktype = SOCK_STREAM;
            addrinfo* resolved = nullptr;
            if (getaddrinfo(host_.c_str(), nullptr, &hints, &resolved) != 0 || resolved == nullptr) {
                return failConnect();
            }
            addr.sin_addr = reinterpret_cast<sockaddr_in*>(resolved->ai_addr)->sin_addr;
            freeaddrinfo(resolved);
        }

        if (::connect(socket_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == SOCKET_ERROR) {
            return failConnect();
        }

        const std::string handshake =
            "GET " + path_ + " HTTP/1.1\r\n"
            "Host: " + host_ + ":" + std::to_string(port_) + "\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
            "Sec-WebSocket-Version: 13\r\n\r\n";
        if (send(socket_, handshake.c_str(), static_cast<int>(handshake.length()), 0) == SOCKET_ERROR) {
            return failConnect();
        }

        // A single recv is assumed to contain the status line of the response.
        char buffer[SocketOptimization::SMALL_BUFFER_SIZE];
        const int bytes = recv(socket_, buffer, sizeof(buffer) - 1, 0);
        if (bytes > 0) {
            buffer[bytes] = '\0';
            if (strstr(buffer, "101 Switching Protocols")) {
                connected_ = true;
                return true;
            }
        }
        return failConnect();
    }

    // Closes the socket, joins and frees the worker thread (if one exists) and
    // marks the client as disconnected. Safe to call more than once.
    void disconnect() {
        should_stop_ = true;
        if (socket_ != INVALID_SOCKET) {
            closesocket(socket_);
            socket_ = INVALID_SOCKET;
        }
        if (thread_) {
            if (thread_->joinable()) {
                thread_->join();
            }
            // Free the thread object even when it was not joinable.
            delete thread_;
            thread_ = nullptr;
        }
        connected_ = false;
    }

    bool isConnected() const {
        return connected_;
    }

    // Sends `data` as one final (FIN) masked text frame. Does nothing when not
    // connected. All three RFC 6455 payload-length encodings are supported
    // (7-bit, 16-bit and 64-bit).
    void sendText(const std::string& data) {
        if (!connected_ || socket_ == INVALID_SOCKET) return;

        const uint64_t length = data.length();
        // send() takes an int byte count; refuse frames that cannot be expressed.
        if (length > 0x7FFFFFF0ull) return;

        // simplified (fixed) masking key
        static const uint8_t mask[4] = {0x12, 0x34, 0x56, 0x78};

        std::vector<uint8_t> frame;
        frame.reserve(static_cast<size_t>(length) + 14); // max header: 2 + 8 length + 4 mask
        frame.push_back(0x81); // FIN + text opcode
        if (length < 126) {
            frame.push_back(static_cast<uint8_t>(0x80 | length)); // MASK + 7-bit length
        } else if (length <= 0xFFFF) {
            frame.push_back(static_cast<uint8_t>(0x80 | 126)); // MASK + 16-bit length follows
            frame.push_back(static_cast<uint8_t>((length >> 8) & 0xFF));
            frame.push_back(static_cast<uint8_t>(length & 0xFF));
        } else {
            frame.push_back(static_cast<uint8_t>(0x80 | 127)); // MASK + 64-bit length follows
            for (int shift = 56; shift >= 0; shift -= 8) {
                frame.push_back(static_cast<uint8_t>((length >> shift) & 0xFF));
            }
        }
        frame.insert(frame.end(), mask, mask + 4);
        for (size_t i = 0; i < data.length(); i++) {
            frame.push_back(static_cast<uint8_t>(static_cast<uint8_t>(data[i]) ^ mask[i % 4]));
        }
        send(socket_, reinterpret_cast<const char*>(frame.data()), static_cast<int>(frame.size()), 0);
    }

private:
    // Closes the half-set-up socket and reports failure to connect()'s caller.
    bool failConnect() {
        closesocket(socket_);
        socket_ = INVALID_SOCKET;
        return false;
    }

    // Splits a ws:// or wss:// URL into host, port and path. The port defaults
    // to 80 and the path to "/"; an unparseable URL falls back to localhost:80/.
    void parseUrl(const std::string& url) {
        std::regex url_regex(R"(^wss?://([^:/]+)(?::(\d+))?(/.*)?$)");
        std::smatch matches;
        if (std::regex_match(url, matches, url_regex)) {
            host_ = matches[1].str();
            port_ = matches[2].matched ? std::stoi(matches[2].str()) : 80;
            path_ = matches[3].matched ? matches[3].str() : "/";
        } else {
            host_ = "localhost";
            port_ = 80;
            path_ = "/";
        }
    }
};
// Creates a client for `url` and starts connecting in the background.
// The URL is parsed synchronously; the TCP/TLS connect, the WebSocket handshake
// and the receive loop then run as one task on GlobalThreadPool. Results are
// reported through the event queue (open / error events).
// An unparseable URL leaves the client CLOSED and queues an error event.
// NOTE(review): the queued task captures `this`; nothing visible here makes the
// destructor wait for that task — confirm the client outlives it.
WebSocketClient::WebSocketClient(Isolate* isolate, Global<Context>* global_context, const std::string& url)
    : isolate_(isolate), global_context_(global_context), url_(url), ready_state_(CONNECTING),
      is_ssl_(false), buffered_amount_(0), ws_(nullptr), ssl_ws_(nullptr)
#ifdef WITH_SSL
#ifdef _WIN32
    , ssl_cred_handle_(nullptr), ssl_context_handle_(nullptr), ssl_context_initialized_(false), ssl_socket_(-1)
#else
    , ssl_ctx_(nullptr), ssl_(nullptr), ssl_initialized_(false)
#endif
#endif
{
#ifdef WITH_SSL
#ifdef _WIN32
    memset(&ssl_cred_handle_, 0, sizeof(ssl_cred_handle_));
    memset(&ssl_context_handle_, 0, sizeof(ssl_context_handle_));
#endif
#endif
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "Creating WebSocket connection to: %s", url.c_str());
#endif
    std::string host, path;
    int port = 0;
    if (!parseUrl(url, host, port, path)) {
        // Without a valid URL there is no host/port to connect to; previously the
        // connect task ran with an uninitialised port.
        ready_state_ = CLOSED;
        handleError("Invalid WebSocket URL: " + url);
        return;
    }
    is_ssl_ = (url.find("wss://") == 0);
    ready_state_ = CONNECTING;
    // start connection using thread pool
    GlobalThreadPool::getInstance().enqueue([this, host, port, path]() {
        try {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket connecting to %s:%d%s (SSL: %s)",
                     host.c_str(), port, path.c_str(), is_ssl_ ? "yes" : "no");
#endif
            if (connectToServer(host, port, path)) {
                // connectToServer() has already switched the state to OPEN and queued
                // the open event; queuing a second one here made onopen fire twice.
                messageLoop();
            } else {
                ready_state_ = CLOSED;
                handleError("Connection failed");
            }
        } catch (...) {
            ready_state_ = CLOSED;
        }
    });
}
// Closes the connection if it is still open and, in SSL builds, releases the TLS state.
// NOTE(review): connectToServer() tears TLS down through cleanupWebSocketSSL();
// confirm cleanupWSL() is the intended call here.
WebSocketClient::~WebSocketClient() {
    const bool stillOpen = (ready_state_ == OPEN);
    if (stillOpen) {
        close();
    }
#ifdef WITH_SSL
    cleanupWSL();
#endif
}
// Splits a ws:// or wss:// URL into its components.
//   url  - URL to parse, e.g. "wss://example.com:8443/chat"
//   host - receives the host name or address
//   port - receives the explicit port, or 443 for wss / 80 for ws when none is given
//   path - receives the path (including any query), or "/" when none is given
// Returns false, leaving the outputs and is_ssl_ untouched, when the URL does
// not match or the port is outside 1..65535. On success is_ssl_ is set from the scheme.
bool WebSocketClient::parseUrl(const std::string& url, std::string& host, int& port, std::string& path) {
    static const std::regex url_regex(R"(^(wss?)://([^:/]+)(?::(\d+))?(/.*)?$)");
    std::smatch matches;
    if (!std::regex_match(url, matches, url_regex)) {
        return false;
    }
    const bool secure = (matches[1].str() == "wss");

    int parsedPort = secure ? 443 : 80;
    if (matches[3].matched) {
        // Accumulate by hand: the regex guarantees digits only, and bailing out as
        // soon as the value exceeds 65535 avoids both the exception std::stoi
        // throws on overflow and the silent truncation htons() would apply later.
        long value = 0;
        for (const char digit : matches[3].str()) {
            value = value * 10 + (digit - '0');
            if (value > 65535) {
                return false;
            }
        }
        if (value == 0) {
            return false;
        }
        parsedPort = static_cast<int>(value);
    }

    host = matches[2].str();
    port = parsedPort;
    path = matches[4].matched ? matches[4].str() : "/";
    is_ssl_ = secure;
    return true;
}
// Sends `data` as a WebSocket frame built by createWebSocketFrame().
// Requires the connection to be OPEN and a socket to be present; otherwise an
// error event is queued and nothing is sent. A failed send (result <= 0) or an
// exception is also reported through handleError(). buffered_amount_ is raised
// for the duration of the call and reset to 0 on every exit path.
void WebSocketClient::send(const std::string& data) {
    if (ready_state_ != OPEN) {
        handleError("WebSocket is not open");
        return;
    }
    if (ws_ == nullptr) {
        handleError("WebSocket: No socket available for sending");
        return;
    }
    try {
        buffered_amount_ += static_cast<int>(data.length());
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket sending: %s", data.c_str());
#endif
        // create and send WebSocket frame
        std::vector<uint8_t> frame = createWebSocketFrame(data);
        const int frameBytes = static_cast<int>(frame.size());
        // TLS path when a secure context exists, raw socket otherwise.
        const int sendResult = (is_ssl_ && ssl_context_initialized_)
            ? webSocketSSLSend(frame.data(), frameBytes)
            : ::send(static_cast<int>(reinterpret_cast<intptr_t>(ws_)),
                     reinterpret_cast<const char*>(frame.data()), frameBytes, 0);
        if (sendResult > 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Successfully sent %d bytes to server", sendResult);
#endif
            buffered_amount_ = 0;
        } else {
            buffered_amount_ = 0;
            handleError("WebSocket: Failed to send message");
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to send %d bytes to server", (int)frame.size());
#endif
        }
    } catch (const std::exception& e) {
        buffered_amount_ = 0;
        handleError(std::string("Send error: ") + e.what());
    }
}
// Sends `data` as a binary WebSocket frame built by createWebSocketBinaryFrame().
// Same preconditions and error reporting as send(): the connection must be OPEN
// with a socket present, failures are reported through handleError(), and
// buffered_amount_ is reset to 0 on every exit path.
void WebSocketClient::sendBinary(const std::string& data) {
    if (ready_state_ != OPEN) {
        handleError("WebSocket is not open");
        return;
    }
    if (ws_ == nullptr) {
        handleError("WebSocket: No socket available for sending");
        return;
    }
    try {
        buffered_amount_ += static_cast<int>(data.length());
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket sending binary: %d bytes", (int)data.length());
#endif
        // binary frame with opcode 0x02
        std::vector<uint8_t> frame = createWebSocketBinaryFrame(data);
        const int frameBytes = static_cast<int>(frame.size());
        // TLS path when a secure context exists, raw socket otherwise.
        const int sendResult = (is_ssl_ && ssl_context_initialized_)
            ? webSocketSSLSend(frame.data(), frameBytes)
            : ::send(static_cast<int>(reinterpret_cast<intptr_t>(ws_)),
                     reinterpret_cast<const char*>(frame.data()), frameBytes, 0);
        if (sendResult > 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Successfully sent %d binary bytes to server", sendResult);
#endif
            buffered_amount_ = 0;
        } else {
            buffered_amount_ = 0;
            handleError("WebSocket: Failed to send binary message");
        }
    } catch (const std::exception& e) {
        buffered_amount_ = 0;
        handleError(std::string("Send binary error: ") + e.what());
    }
}
// Marks the connection as closed and queues a close event with `code` / `reason`.
// Calling it while already CLOSING or CLOSED is a no-op.
// As written, this neither sends a close frame nor closes the socket itself.
void WebSocketClient::close(int code, const std::string& reason) {
    const bool alreadyShuttingDown = (ready_state_ == CLOSING) || (ready_state_ == CLOSED);
    if (!alreadyShuttingDown) {
        // The state passes through CLOSING before settling on CLOSED.
        ready_state_ = CLOSING;
        ready_state_ = CLOSED;
        handleClose(code, reason);
    }
}
// Stores `callback` as the open handler, replacing whatever on_open_ held before.
void WebSocketClient::setOnOpen(Local<Function> callback) {
on_open_.Reset(isolate_, callback);
}
// Stores `callback` as the message handler, replacing whatever on_message_ held before.
void WebSocketClient::setOnMessage(Local<Function> callback) {
on_message_.Reset(isolate_, callback);
}
// Stores `callback` as the error handler, replacing whatever on_error_ held before.
void WebSocketClient::setOnError(Local<Function> callback) {
on_error_.Reset(isolate_, callback);
}
// Stores `callback` as the close handler, replacing whatever on_close_ held before.
void WebSocketClient::setOnClose(Local<Function> callback) {
on_close_.Reset(isolate_, callback);
}
// Event handlers intended to be called from the worker thread: each one only
// queues an event, which processEvents() later turns into a JS callback.
// handleOpen queues an EVENT_OPEN.
void WebSocketClient::handleOpen() {
event_queue_.push(WebSocketEvent(EVENT_OPEN));
}
// Queues an EVENT_MESSAGE carrying a copy of `message`.
void WebSocketClient::handleMessage(const std::string& message) {
// push directly to the event queue (described as lock free by the original author; its implementation is not visible here)
event_queue_.push(WebSocketEvent(EVENT_MESSAGE, message));
}
// Queues an EVENT_ERROR carrying `error` (and logs it in DEBUG_NETWORK builds).
// Never throws: any exception raised while logging or queuing is swallowed.
void WebSocketClient::handleError(const std::string& error) {
try {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket error: %s", error.c_str());
#endif
event_queue_.push(WebSocketEvent(EVENT_ERROR, error));
} catch (...) {
// best effort: a failure while reporting an error is deliberately dropped
}
}
// Sets the state to CLOSED and queues an EVENT_CLOSE with the given code and reason.
void WebSocketClient::handleClose(int code, const std::string& reason) {
ready_state_ = CLOSED;
event_queue_.push(WebSocketEvent(EVENT_CLOSE, code, reason));
}
// process events on main thread with lock free queue and batching
void WebSocketClient::processEvents() {
try {
// collect for batched processing and pre allocate
std::vector<WebSocketEvent> pending_events;
pending_events.reserve(64);
WebSocketEvent event;
while (event_queue_.try_pop(event)) {
pending_events.push_back(std::move(event));
}
if (pending_events.empty()) return;
// message events for batch processing
std::vector<std::string> message_batch;
message_batch.reserve(pending_events.size());
// process non message events and collect messages
for (auto& evt : pending_events) {
switch (evt.type) {
case EVENT_OPEN:
processOpenEvent();
break;
case EVENT_MESSAGE:
message_batch.push_back(std::move(evt.data));
break;
case EVENT_ERROR:
processErrorEvent(evt.data);
break;
case EVENT_CLOSE:
processCloseEvent(evt.code, evt.reason);
break;
}
}
// batch process all messages
if (!message_batch.empty()) {
processMessageBatch(message_batch);
}
} catch (...) {
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CRASH in processEvents");
}
}
void WebSocketClient::processMessageBatch(const std::vector<std::string>& messages) {
if (on_message_.IsEmpty() || messages.empty()) return;
Locker locker{isolate_};
Isolate::Scope isolate_scope(isolate_);
HandleScope handle_scope(isolate_);
Local<Context> context = Local<Context>::New(isolate_, *global_context_);
Context::Scope context_scope(context);
Local<Function> func = on_message_.Get(isolate_);
Local<String> dataKey = String::NewFromUtf8(isolate_, "data").ToLocalChecked();
for (const auto& message : messages) {
Local<Object> event = Object::New(isolate_);
Local<ArrayBuffer> arrayBuffer = ArrayBuffer::New(isolate_, message.length());
std::shared_ptr<BackingStore> backing = arrayBuffer->GetBackingStore();
memcpy(backing->Data(), message.c_str(), message.length());
event->Set(context, dataKey, arrayBuffer);
Local<Value> argv[1] = { event };
TryCatch try_catch(isolate_);
func->Call(context, context->Global(), 1, argv);
}
}
// called on main thread
void WebSocketClient::processOpenEvent() {
if (!on_open_.IsEmpty()) {
Local<Value> argv[1] = {};
callCallback(on_open_, 0, argv);
}
}
void WebSocketClient::processMessageEvent(const std::string& message) {
if (!on_message_.IsEmpty()) {
Locker locker{isolate_};
Isolate::Scope isolate_scope(isolate_);
HandleScope handle_scope(isolate_);
Local<Context> context = Local<Context>::New(isolate_, *global_context_);
Context::Scope context_scope(context);
Local<Object> event = Object::New(isolate_);
Local<ArrayBuffer> arrayBuffer = ArrayBuffer::New(isolate_, message.length());
std::shared_ptr<BackingStore> backing = arrayBuffer->GetBackingStore();
memcpy(backing->Data(), message.c_str(), message.length());
event->Set(context, String::NewFromUtf8(isolate_, "data").ToLocalChecked(), arrayBuffer);
Local<Value> argv[1] = { event };
callCallback(on_message_, 1, argv);
}
}
// Delivers an error to on_error_ as an event object with a "message" property.
// Does nothing when no handler is registered. Exceptions are caught and logged.
// The step-by-step trace output is only compiled into DEBUG_NETWORK builds,
// matching how the rest of this file treats diagnostic logging.
void WebSocketClient::processErrorEvent(const std::string& error) {
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent START");
#endif
    try {
        if (!on_error_.IsEmpty()) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent getting locker");
#endif
            Locker locker{isolate_};
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent setting up V8");
#endif
            Isolate::Scope isolate_scope(isolate_);
            HandleScope handle_scope(isolate_);
            Local<Context> context = Local<Context>::New(isolate_, *global_context_);
            Context::Scope context_scope(context);
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent creating event object");
#endif
            Local<Object> event = Object::New(isolate_);
            event->Set(context, String::NewFromUtf8(isolate_, "message").ToLocalChecked(),
                       String::NewFromUtf8(isolate_, error.c_str()).ToLocalChecked());
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent calling callback");
#endif
            Local<Value> argv[1] = { event };
            callCallback(on_error_, 1, argv);
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: processErrorEvent DONE");
#endif
        }
    } catch (...) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CRASH in processErrorEvent");
    }
}
void WebSocketClient::processCloseEvent(int code, const std::string& reason) {
if (!on_close_.IsEmpty()) {
Locker locker{isolate_};
Isolate::Scope isolate_scope(isolate_);
HandleScope handle_scope(isolate_);
Local<Context> context = Local<Context>::New(isolate_, *global_context_);
Context::Scope context_scope(context);
Local<Object> event = Object::New(isolate_);
event->Set(context, String::NewFromUtf8(isolate_, "code").ToLocalChecked(),
Number::New(isolate_, code));
event->Set(context, String::NewFromUtf8(isolate_, "reason").ToLocalChecked(),
String::NewFromUtf8(isolate_, reason.c_str()).ToLocalChecked());
Local<Value> argv[1] = { event };
callCallback(on_close_, 1, argv);
}
}
// Invokes the JS function held in `callback` with `argc` arguments from `argv`,
// using the global object as receiver. Takes the isolate lock and enters the
// stored context itself, so it may be called with or without those already held.
// A JS exception thrown by the callback is contained by the TryCatch and, in
// DEBUG_NETWORK builds, logged. Does nothing when `callback` is empty.
void WebSocketClient::callCallback(Global<Function>& callback, int argc, Local<Value> argv[]) {
    if (callback.IsEmpty()) return;
    Locker locker{isolate_};
    Isolate::Scope isolate_scope(isolate_);
    HandleScope handle_scope(isolate_);
    Local<Context> context = Local<Context>::New(isolate_, *global_context_);
    Context::Scope context_scope(context);
    TryCatch try_catch(isolate_);
    Local<Function> func = callback.Get(isolate_);
    Local<Value> result;
    // Call() yields an empty MaybeLocal when the callback threw, so the failure
    // must be detected on the unsuccessful path (the old check sat inside the
    // success branch and could never fire).
    const bool completed = func->Call(context, context->Global(), argc, argv).ToLocal(&result);
    if (!completed || try_catch.HasCaught()) {
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket callback error");
#endif
    }
}
// Reads and decrypts TLS application data (Schannel) into `buffer`.
//   buffer     - destination for the decrypted bytes
//   bufferSize - capacity of `buffer` in bytes
// Returns:
//   > 0  number of decrypted bytes copied into `buffer`
//     0  nothing to deliver yet (incomplete TLS record, WSAEWOULDBLOCK, or a
//        successful decrypt that produced no data buffer)
//    -1  TLS not initialised, select()/recv() failure, 5 s timeout, peer closed,
//        or a decrypt error (ssl_buffer_ is cleared in that case)
// Flow: (1) try to decrypt what is already in ssl_buffer_; (2) otherwise wait up
// to 5 seconds for the socket, recv() more ciphertext, append it to ssl_buffer_
// and try to decrypt again.
// NOTE(review): at most `bufferSize` decrypted bytes are copied; any remainder of
// that record is dropped because ssl_buffer_ is then cleared or overwritten.
// NOTE(review): only buffers[3] is checked for SECBUFFER_EXTRA — confirm Schannel
// always reports leftover ciphertext in that slot.
// NOTE(review): several log messages mention the "handshake", but nothing in this
// function restricts it to the handshake phase.
int WebSocketClient::webSocketSSLReceive(char* buffer, int bufferSize) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: webSocketSSLReceive called with bufferSize=%d", bufferSize);
#endif
if (!ssl_context_initialized_) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL not initialized for receive");
#endif
return -1;
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL context initialized, starting receive process");
#endif
// first try to decrypt any existing buffered data
if (!ssl_buffer_.empty()) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Attempting to decrypt existing buffer data (%d bytes)", (int)ssl_buffer_.size());
#endif
SecBuffer buffers[4];
SecBufferDesc message;
// set up decrypt buffers with existing data
// slot 0 carries the ciphertext; DecryptMessage works in place and rewrites the slot types
buffers[0].pvBuffer = ssl_buffer_.data();
buffers[0].cbBuffer = static_cast<unsigned long>(ssl_buffer_.size());
buffers[0].BufferType = SECBUFFER_DATA;
buffers[1].BufferType = SECBUFFER_EMPTY;
buffers[2].BufferType = SECBUFFER_EMPTY;
buffers[3].BufferType = SECBUFFER_EMPTY;
message.ulVersion = SECBUFFER_VERSION;
message.cBuffers = 4;
message.pBuffers = buffers;
SECURITY_STATUS status = DecryptMessage(reinterpret_cast<PCtxtHandle>(ssl_context_handle_), &message, 0, nullptr);
if (status == SEC_E_OK) {
// find decrypted data buffer
for (int i = 0; i < 4; i++) {
if (buffers[i].BufferType == SECBUFFER_DATA && buffers[i].cbBuffer > 0) {
// copy no more than the caller's buffer can hold
int decrypted_size = (std::min)(static_cast<int>(buffers[i].cbBuffer), bufferSize);
memcpy(buffer, buffers[i].pvBuffer, decrypted_size);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Successfully decrypted %d bytes from existing buffer", decrypted_size);
#endif
// remove processed data from buffer and keep extra data
if (buffers[3].BufferType == SECBUFFER_EXTRA && buffers[3].cbBuffer > 0) {
// move the undecrypted tail to the front so the next call starts with it
memmove(ssl_buffer_.data(), (char*)buffers[3].pvBuffer, buffers[3].cbBuffer);
ssl_buffer_.resize(buffers[3].cbBuffer);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Kept %d bytes of extra data in buffer", buffers[3].cbBuffer);
#endif
} else {
ssl_buffer_.clear();
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Cleared SSL buffer after successful decryption");
#endif
}
return decrypted_size;
}
}
// SEC_E_OK without a non-empty data buffer: fall through and read from the socket
} else if (status != SEC_E_INCOMPLETE_MESSAGE) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL decrypt failed with status 0x%x, clearing buffer", status);
#endif
ssl_buffer_.clear();
return -1;
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Incomplete SSL message, need more data (current buffer: %d bytes)", (int)ssl_buffer_.size());
#endif
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Reading more SSL data from socket %d with proper polling", ssl_socket_);
#endif
// wait (up to 5 seconds) until ssl_socket_ becomes readable
fd_set readfds;
struct timeval timeout;
FD_ZERO(&readfds);
FD_SET(ssl_socket_, &readfds);
timeout.tv_sec = 5;
timeout.tv_usec = 0;
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Waiting for WebSocket handshake response on socket %d with 5s timeout", ssl_socket_);
#endif
// first argument is 0 — presumably relying on Winsock ignoring nfds; would need ssl_socket_ + 1 on POSIX
int select_result = select(0, &readfds, nullptr, nullptr, &timeout);
if (select_result == SOCKET_ERROR) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: select() failed: %d", WSAGetLastError());
#endif
return -1;
} else if (select_result == 0) {
// timeout is reported to the caller as an error, not as "no data yet"
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Timeout waiting for WebSocket handshake response (5s) - server may not support WebSocket protocol");
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: This could indicate: 1) Server doesn't support WebSocket, 2) Network issue, 3) Server overloaded");
#endif
return -1;
} else if (!FD_ISSET(ssl_socket_, &readfds)) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Socket not ready for SSL reading after select");
#endif
return -1;
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL data available, reading from socket %d", ssl_socket_);
#endif
// pull one chunk of raw ciphertext from the socket
char temp_buffer[SocketOptimization::SMALL_BUFFER_SIZE];
int raw_bytes = recv(ssl_socket_, temp_buffer, sizeof(temp_buffer), 0);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: recv returned %d bytes of raw SSL data", raw_bytes);
#endif
if (raw_bytes <= 0) {
if (raw_bytes == 0) {
// recv() == 0: the peer closed the connection
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL connection closed by server during handshake");
#endif
} else {
int recv_error = WSAGetLastError();
if (recv_error == WSAEWOULDBLOCK) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: No SSL data available right now (WSAEWOULDBLOCK) - normal for non-blocking sockets");
#endif
// not an error: report "nothing yet" so the caller can retry
return 0;
} else {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: recv failed with SSL error %d", recv_error);
#endif
}
}
return -1;
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Successfully received %d bytes of raw SSL data, adding to buffer (current size: %d)", raw_bytes, (int)ssl_buffer_.size());
#endif
// add new data to SSL buffer
// old_size is only used by the debug log below
size_t old_size = ssl_buffer_.size();
ssl_buffer_.insert(ssl_buffer_.end(), temp_buffer, temp_buffer + raw_bytes);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL buffer updated: %d -> %d bytes", (int)old_size, (int)ssl_buffer_.size());
#endif
// try to decrypt the accumulated data and set up decrypt buffers with all accumulated data
SecBuffer buffers[4];
SecBufferDesc message;
buffers[0].pvBuffer = ssl_buffer_.data();
buffers[0].cbBuffer = static_cast<unsigned long>(ssl_buffer_.size());
buffers[0].BufferType = SECBUFFER_DATA;
buffers[1].BufferType = SECBUFFER_EMPTY;
buffers[2].BufferType = SECBUFFER_EMPTY;
buffers[3].BufferType = SECBUFFER_EMPTY;
message.ulVersion = SECBUFFER_VERSION;
message.cBuffers = 4;
message.pBuffers = buffers;
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Attempting SSL decryption of %d bytes", (int)ssl_buffer_.size());
#endif
SECURITY_STATUS status = DecryptMessage(reinterpret_cast<PCtxtHandle>(ssl_context_handle_), &message, 0, nullptr);
if (status == SEC_E_OK) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL decryption successful");
#endif
for (int i = 0; i < 4; i++) {
if (buffers[i].BufferType == SECBUFFER_DATA && buffers[i].cbBuffer > 0) {
int decrypted_size = (std::min)(static_cast<int>(buffers[i].cbBuffer), bufferSize);
memcpy(buffer, buffers[i].pvBuffer, decrypted_size);
// NUL-terminate only when there is room; unlike the first decrypt path above, which never terminates
if (decrypted_size < bufferSize) {
buffer[decrypted_size] = '\0';
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Successfully decrypted %d bytes of SSL handshake response", decrypted_size);
#endif
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Handshake response preview: %.200s...", buffer);
#endif
if (buffers[3].BufferType == SECBUFFER_EXTRA && buffers[3].cbBuffer > 0) {
// keep the not-yet-decrypted tail for the next call
memmove(ssl_buffer_.data(), (char*)buffers[3].pvBuffer, buffers[3].cbBuffer);
ssl_buffer_.resize(buffers[3].cbBuffer);
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Kept %d bytes of extra SSL data", buffers[3].cbBuffer);
#endif
} else {
ssl_buffer_.clear();
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Cleared SSL buffer after successful handshake decryption");
#endif
}
return decrypted_size;
}
}
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL decrypt succeeded but no handshake data found");
#endif
// decrypt succeeded but yielded no application data; ssl_buffer_ is left as it is
return 0;
} else if (status == SEC_E_INCOMPLETE_MESSAGE) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Incomplete SSL message, need more handshake data (buffer size: %d)", (int)ssl_buffer_.size());
#endif
return 0; // need more data so we keep buffer intact and try again
} else {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL decrypt failed with handshake status 0x%x, clearing buffer", status);
#endif
ssl_buffer_.clear();
return -1;
}
}
// Establishes the full connection: TCP connect (IPv4, with DNS fallback), TLS
// set-up for wss:// URLs, then the WebSocket opening handshake.
//   host - dotted-quad address or host name
//   port - TCP port to connect to
//   path - request path for the upgrade request
// On success the socket is stored in ws_, the state becomes OPEN, an open event
// is queued and true is returned. On any failure the socket is closed, ws_ is
// reset where it had been set, and false is returned.
bool WebSocketClient::connectToServer(const std::string& host, int port, const std::string& path) {
    try {
        int sock = static_cast<int>(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
        if (sock == -1) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to create socket");
#endif
            return false;
        }

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) != 1) {
            // try DNS resolution
            struct addrinfo hints = {};
            struct addrinfo* result = nullptr;
            hints.ai_family = AF_INET;
            hints.ai_socktype = SOCK_STREAM;
            if (getaddrinfo(host.c_str(), nullptr, &hints, &result) != 0 || result == nullptr) {
                // closesocket maps to close() on non-Windows builds (see the macro at the top of the file)
                closesocket(sock);
#ifdef DEBUG_NETWORK
                kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to resolve host %s", host.c_str());
#endif
                return false;
            }
            addr.sin_addr = reinterpret_cast<sockaddr_in*>(result->ai_addr)->sin_addr;
            freeaddrinfo(result);
        }

        if (connect(sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1) {
            closesocket(sock);
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to connect to %s:%d", host.c_str(), port);
#endif
            return false;
        }
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: TCP connected to %s:%d", host.c_str(), port);
#endif
        ws_ = reinterpret_cast<void*>(static_cast<intptr_t>(sock));

        // A wss:// URL requires TLS on whatever port it names. Restricting the TLS
        // set-up to port 443 meant wss://host:8443 was spoken in plaintext.
        if (is_ssl_) {
            if (!initializeWebSocketSSL(sock, host)) {
                closesocket(sock);
                ws_ = nullptr;
                return false;
            }
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL/TLS handshake completed");
#endif
        }

        if (!performWebSocketHandshake(sock, host, port, path)) {
            if (is_ssl_) cleanupWebSocketSSL();
            closesocket(sock);
            ws_ = nullptr;
            return false;
        }
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Handshake completed successfully");
#endif
        ready_state_ = OPEN;
        fireOpenEvent();
        return true;
    } catch (const std::exception& e) {
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Connection error: %s", e.what());
#endif
        return false;
    }
}
// RFC 6455 opening handshake.
// Sends the HTTP upgrade request over `sock` (or through the TLS layer when one
// is active), reads the response with a single receive, and validates both the
// "HTTP/1.1 101" status and the Sec-WebSocket-Accept value derived from the key
// that was sent.
//   sock - connected socket
//   host - used for the Host and Origin headers
//   port - appended to Host unless it is the default for the scheme
//   path - request target
// Returns true only when the server accepted the upgrade with a matching key.
bool WebSocketClient::performWebSocketHandshake(int sock, const std::string& host, int port, const std::string& path) {
    try {
        std::string webSocketKey = generateWebSocketKey();
        std::string hostHeader = host;
        // dont include port for standard ports (80 for HTTP, 443 for HTTPS)
        if (!((is_ssl_ && port == 443) || (!is_ssl_ && port == 80))) {
            hostHeader += ":" + std::to_string(port);
        }
        std::string handshakeRequest =
            "GET " + path + " HTTP/1.1\r\n"
            "Host: " + hostHeader + "\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Key: " + webSocketKey + "\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "Origin: https://" + host + "\r\n"
            "User-Agent: RunT/1.0\r\n"
            "\r\n";
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Sending handshake request");
#endif
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Handshake request content:\n%s", handshakeRequest.c_str());
#endif
        int sendResult;
        if (is_ssl_ && ssl_context_initialized_) {
            sendResult = webSocketSSLSend(handshakeRequest.c_str(), static_cast<int>(handshakeRequest.length()));
        } else {
            sendResult = ::send(sock, handshakeRequest.c_str(), static_cast<int>(handshakeRequest.length()), 0);
        }
        if (sendResult <= 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to send handshake request");
#endif
            return false;
        }

        // One byte is held back so the response can be NUL-terminated below.
        char buffer[SocketOptimization::SMALL_BUFFER_SIZE];
        int bytesReceived;
        if (is_ssl_ && ssl_context_initialized_) {
            bytesReceived = webSocketSSLReceive(buffer, sizeof(buffer) - 1);
        } else {
            bytesReceived = recv(sock, buffer, sizeof(buffer) - 1, 0);
        }
        if (bytesReceived <= 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to receive handshake response");
#endif
            return false;
        }
        buffer[bytesReceived] = '\0';
        std::string response(buffer);
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Received handshake response");
#endif
        if (response.find("HTTP/1.1 101") != 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Invalid handshake response status");
#endif
            return false;
        }

        std::string expectedAccept = base64Encode(sha1Hash(webSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));

        // Header names are case-insensitive, so search a lowercased copy. The cast
        // to unsigned char keeps tolower() defined for bytes >= 0x80.
        std::string responseLower = response;
        std::transform(responseLower.begin(), responseLower.end(), responseLower.begin(),
                       [](unsigned char c) { return static_cast<char>(::tolower(c)); });

        // Anchor on the preceding CRLF so the name only matches at the start of a
        // header line, never inside another header's value.
        const std::string acceptHeader = "\r\nsec-websocket-accept:";
        const size_t headerPos = responseLower.find(acceptHeader);
        if (headerPos == std::string::npos) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Missing Sec-WebSocket-Accept header");
#endif
            return false;
        }

        // Whitespace around the field value is optional, so trim it on both sides
        // instead of assuming exactly one space after the colon.
        size_t valueBegin = headerPos + acceptHeader.length();
        size_t valueEnd = response.find("\r\n", valueBegin);
        if (valueEnd == std::string::npos) {
            valueEnd = response.length(); // header line cut off at the end of the buffer
        }
        while (valueBegin < valueEnd && (response[valueBegin] == ' ' || response[valueBegin] == '\t')) {
            ++valueBegin;
        }
        while (valueEnd > valueBegin && (response[valueEnd - 1] == ' ' || response[valueEnd - 1] == '\t')) {
            --valueEnd;
        }
        const std::string actualAccept = response.substr(valueBegin, valueEnd - valueBegin);
        if (actualAccept != expectedAccept) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Invalid Sec-WebSocket-Accept value");
#endif
            return false;
        }
        return true;
    } catch (const std::exception& e) {
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Handshake error: %s", e.what());
#endif
        return false;
    }
}
// The fire*Event helpers hand events over through event_queue_ instead of
// calling into V8 directly; processEvents() performs the actual callbacks.
// fireOpenEvent queues an EVENT_OPEN (and logs in DEBUG_NETWORK builds).
void WebSocketClient::fireOpenEvent() {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Connection opened");
#endif
event_queue_.push(WebSocketEvent(EVENT_OPEN));
}
// Queues an EVENT_MESSAGE carrying a copy of `message`.
// `binary` is accepted but not used in this function: the queued event does not
// record whether the message was text or binary.
void WebSocketClient::fireMessageEvent(const std::string& message, bool binary) {
#ifdef DEBUG_NETWORK
// Long messages are logged as a byte count plus their first 100 characters
// (the original author's stated reason: avoiding a stack overflow in kinc_log).
if (message.length() > 100) {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Message received (%zu bytes): %.100s...", message.length(), message.c_str());
} else {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Message received: %s", message.c_str());
}
#endif
event_queue_.push(WebSocketEvent(EVENT_MESSAGE, message));
}
// Queues an EVENT_ERROR carrying `error` (and logs it in DEBUG_NETWORK builds).
void WebSocketClient::fireErrorEvent(const std::string& error) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Error: %s", error.c_str());
#endif
event_queue_.push(WebSocketEvent(EVENT_ERROR, error));
}
// Sets the state to CLOSED and queues an EVENT_CLOSE with `code` and `reason`.
// The socket itself is not closed here.
void WebSocketClient::fireCloseEvent(int code, const std::string& reason) {
#ifdef DEBUG_NETWORK
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Connection closed: %d - %s", code, reason.c_str());
#endif
ready_state_ = CLOSED;
event_queue_.push(WebSocketEvent(EVENT_CLOSE, code, reason));
}
// Handles one fully received WebSocket frame on behalf of messageLoop().
//   opcode  - frame opcode (0x0 continuation, 0x1 text, 0x2 binary, 0x8 close,
//             0x9 ping, 0xA pong)
//   fin     - FIN bit of the frame
//   payload - unmasked payload bytes
// Data frames with FIN set are queued as message events (always flagged binary);
// without FIN the payload is stashed in partial_frame_buffer_. Close frames queue
// a close event, pings are answered with a pong, pongs are ignored and any other
// opcode is logged.
// NOTE(review): a final continuation frame is delivered using only its own
// payload, and each non-final frame overwrites the stash rather than appending —
// presumably the caller reassembles fragmented messages; confirm.
void WebSocketClient::processFrame(uint8_t opcode, bool fin, const std::vector<uint8_t>& payload) {
    const bool isDataFrame = (opcode == 0x0) || (opcode == 0x1) || (opcode == 0x2);
    if (isDataFrame) {
        if (!fin) {
            partial_frame_buffer_ = payload;
            expecting_continuation_ = true;
            continuation_opcode_ = opcode;
            return;
        }
        std::string message(payload.begin(), payload.end());
        // force binary because we are fancy
        fireMessageEvent(message, true);
        return;
    }

    if (opcode == 0x8) { // close frame
        // Optional body: 2-byte big-endian status code followed by a reason text.
        const bool hasStatus = payload.size() >= 2;
        const uint16_t statusCode = hasStatus
            ? static_cast<uint16_t>((payload[0] << 8) | payload[1])
            : static_cast<uint16_t>(1000);
        std::string statusText;
        if (payload.size() > 2) {
            statusText = std::string(payload.begin() + 2, payload.end());
        }
        fireCloseEvent(statusCode, statusText);
        return;
    }

    if (opcode == 0x9) { // ping frame: answer with a pong built from the same payload
        std::vector<uint8_t> reply = createPongFrame(payload);
        const int replyBytes = static_cast<int>(reply.size());
        if (is_ssl_ && ssl_context_initialized_) {
            webSocketSSLSend(reply.data(), replyBytes);
        } else {
            ::send(static_cast<int>(reinterpret_cast<intptr_t>(ws_)),
                   reinterpret_cast<const char*>(reply.data()), replyBytes, 0);
        }
        return;
    }

    if (opcode == 0xA) { // pong frame: nothing to do
        return;
    }

    kinc_log(KINC_LOG_LEVEL_WARNING, "WebSocket: Unknown opcode 0x%X", opcode);
}
// following uWebSockets
/// Blocking receive loop: reads from the socket (plain or TLS) into a ring buffer, parses
/// WebSocket frames out of it and hands complete messages to processFrame().
///
/// Runs while ready_state_ == OPEN. It stops when the socket read returns <= 0, when the ring
/// buffer is full while a frame header is still expected, or when ready_state_ changes.
///
/// Two receive modes are used:
///  - wantsHead == true:  frames are parsed directly out of the ring buffer.
///  - wantsHead == false: a frame whose payload exceeds half the ring buffer is being streamed
///                        into fragmentBuffer, `remainingBytes` at a time.
///
/// Payload bytes streamed in the second mode are copied as-is (no unmasking is applied here).
void WebSocketClient::messageLoop() {
    if (ws_ == nullptr) {
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: No socket available for message loop");
#endif
        return;
    }
    int sock = static_cast<int>(reinterpret_cast<intptr_t>(ws_));
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Starting message loop");
#endif
    // 512KB heap allocated shared receive
    const size_t LIBUS_RECV_BUFFER_LENGTH = 524288;
    // Cap for reserve() calls driven by a peer-declared length: reserve() is only a hint, and an
    // unchecked 64-bit length from the wire could otherwise make it throw.
    const size_t MAX_RESERVE_HINT = 16u * 1024u * 1024u;
    RingBuffer ringBuf(LIBUS_RECV_BUFFER_LENGTH);
    // fragmentBuffer for accumulating message fragments grows dynamically on heap when message spans multiple recv() calls
    std::string fragmentBuffer;
    // state machine
    bool wantsHead = true;       // true = expecting frame header, false = streaming a large payload
    uint64_t remainingBytes = 0; // bytes remaining for current large payload
    uint8_t currentOpcode = 0;   // opcode of current message being assembled
    bool currentFin = false;     // FIN flag of current large frame

    // Moves as much of the pending large payload as is buffered into fragmentBuffer. Once the
    // payload is complete the loop returns to header parsing; the message is delivered only if the
    // frame carried FIN, otherwise fragmentBuffer is kept for the continuation frames that follow.
    auto drainLargePayload = [&]() {
        const size_t available = ringBuf.size();
        const size_t toRead = (available < remainingBytes) ? available : static_cast<size_t>(remainingBytes);
        if (toRead > 0) {
            std::vector<uint8_t> chunk(toRead);
            ringBuf.read(chunk.data(), toRead);
            fragmentBuffer.append(reinterpret_cast<char*>(chunk.data()), toRead);
            remainingBytes -= toRead;
        }
        if (remainingBytes != 0) {
            return;
        }
        wantsHead = true;
        if (currentFin) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Large message complete (%zu bytes)", fragmentBuffer.size());
#endif
            std::vector<uint8_t> payload(fragmentBuffer.begin(), fragmentBuffer.end());
            processFrame(currentOpcode, currentFin, payload);
            fragmentBuffer.clear();
        }
    };

    while (ready_state_ == OPEN) {
        size_t contiguousSpace;
        uint8_t* writePtr = ringBuf.getWritePtr(&contiguousSpace);
        if (contiguousSpace == 0) {
            if (!wantsHead && ringBuf.size() > 0) {
                // Ring buffer is full of large-payload bytes: drain them to make room.
                drainLargePayload();
                continue;
            }
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_WARNING, "WebSocket: Buffer full (%zu bytes)", LIBUS_RECV_BUFFER_LENGTH);
#endif
            break;
        }
        int bytesReceived;
        if (is_ssl_ && ssl_context_initialized_) {
            bytesReceived = webSocketSSLReceive(reinterpret_cast<char*>(writePtr), static_cast<int>(contiguousSpace));
        } else {
            bytesReceived = recv(sock, reinterpret_cast<char*>(writePtr), static_cast<int>(contiguousSpace), 0);
        }
        if (bytesReceived <= 0) {
#ifdef DEBUG_NETWORK
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Connection closed by server");
#endif
            break;
        }
        ringBuf.commitWrite(bytesReceived);
        if (!wantsHead) {
            drainLargePayload();
            if (!wantsHead) {
                continue; // payload still incomplete, read more
            }
            // Payload finished: fall through so frames that arrived in the same read are parsed
            // now instead of waiting behind another blocking receive.
        }
        while (ringBuf.size() >= 2) {
            WsFrameInfo frameInfo = {}; // Zero-initialize for safe partial-parse checking
            WsFrameResult result = parseWsFrameHeader(ringBuf, frameInfo);
            if (result == WsFrameResult::WS_INCOMPLETE) {
                // can mean many things, we follow uWebsockets
                if (frameInfo.header_size > 0 &&
                    ringBuf.size() >= frameInfo.header_size &&
                    frameInfo.payload_length > LIBUS_RECV_BUFFER_LENGTH / 2) {
#ifdef DEBUG_NETWORK
                    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Large payload (%llu bytes), using fragmentBuffer",
                             (unsigned long long)frameInfo.payload_length);
#endif
                    const size_t reserveHint = (frameInfo.payload_length < MAX_RESERVE_HINT)
                                                   ? static_cast<size_t>(frameInfo.payload_length)
                                                   : MAX_RESERVE_HINT;
                    fragmentBuffer.reserve(fragmentBuffer.size() + reserveHint);
                    // A large continuation frame (opcode 0) keeps the opcode of the message it continues.
                    currentOpcode = frameInfo.opcode ? frameInfo.opcode : currentOpcode;
                    currentFin = frameInfo.fin;
                    remainingBytes = frameInfo.payload_length;
                    wantsHead = false;
                    ringBuf.consume(frameInfo.header_size);
                    break;
                }
                break;
            }
            if (result == WsFrameResult::WS_ERROR) {
                // so we resync and skip to next valid frame header
                ringBuf.consume(1);
                continue;
            }
            // Control frames (opcode bit 3 set) may arrive between the fragments of a message and
            // must never be merged into the message body.
            const bool isControl = (frameInfo.opcode & 0x08) != 0;
            if (!isControl && (!frameInfo.fin || fragmentBuffer.length() > 0)) {
                // first fragment reserves, subsequent appends
                if (fragmentBuffer.length() == 0) {
                    fragmentBuffer.reserve((size_t)frameInfo.payload_length);
                    currentOpcode = frameInfo.opcode ? frameInfo.opcode : currentOpcode;
                }
                std::vector<uint8_t> payload((size_t)frameInfo.payload_length);
                extractWsPayload(ringBuf, frameInfo, payload.data());
                fragmentBuffer.append(reinterpret_cast<char*>(payload.data()), payload.size());
                if (frameInfo.fin) {
                    std::vector<uint8_t> fullPayload(fragmentBuffer.begin(), fragmentBuffer.end());
                    processFrame(currentOpcode, true, fullPayload);
                    fragmentBuffer.clear();
                }
            } else {
                std::vector<uint8_t> payload((size_t)frameInfo.payload_length);
                extractWsPayload(ringBuf, frameInfo, payload.data());
                processFrame(frameInfo.opcode, frameInfo.fin, payload);
            }
        }
    }
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Message loop completed");
#endif
}
// frame parsing (RFC 6455)
/// Parses and dispatches one WebSocket frame from `buffer` starting at `offset`.
/// @param buffer received bytes
/// @param offset in/out read position; advanced past the frame on success, restored to its
///               entry value when the frame is incomplete, and advanced by one byte when the
///               RSV bits are non-zero (resync attempt)
/// @return true when the caller may keep parsing at the new offset; false when more data is
///         needed or a close frame was processed
///
/// NOTE(review): a non-final data frame replaces partial_frame_buffer_ and a final continuation
/// frame is delivered on its own — fragment reassembly appears to be expected elsewhere; confirm.
bool WebSocketClient::parseWebSocketFrame(const std::vector<uint8_t>& buffer, size_t& offset) {
    const size_t original_offset = offset;
    // All bounds checks compare against the bytes left so they cannot wrap around.
    if (offset > buffer.size() || buffer.size() - offset < 2) {
        return false;
    }
    const uint8_t byte1 = buffer[offset++];
    const uint8_t byte2 = buffer[offset++];
    const bool fin = (byte1 & 0x80) != 0;
    const uint8_t rsv = (byte1 >> 4) & 0x07; // RSV1, RSV2, RSV3
    const uint8_t opcode = byte1 & 0x0F;
    const bool masked = (byte2 & 0x80) != 0;
    uint64_t payload_length = byte2 & 0x7F;
    // RFC 6455: RSV bits must be 0 unless an extension is negotiated
    if (rsv != 0) {
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_WARNING, "WebSocket: Non-zero RSV bits (%d) - possible frame corruption", rsv);
#endif
        offset = original_offset + 1;
        return true;
    }
    if (payload_length == 126) {
        // 16-bit big-endian extended length
        if (buffer.size() - offset < 2) {
            offset = original_offset;
            return false;
        }
        payload_length = (static_cast<uint64_t>(buffer[offset]) << 8) | buffer[offset + 1];
        offset += 2;
    } else if (payload_length == 127) {
        // 64-bit big-endian extended length
        if (buffer.size() - offset < 8) {
            offset = original_offset;
            return false;
        }
        payload_length = 0;
        for (int i = 0; i < 8; i++) {
            payload_length = (payload_length << 8) | buffer[offset + i];
        }
        offset += 8;
    }
    // TODO: confirm we shouldnt send masked frames
    uint32_t mask_key = 0;
    if (masked) {
        if (buffer.size() - offset < 4) {
            offset = original_offset;
            return false;
        }
        // Widen before shifting: shifting a promoted int left by 24 is undefined when bit 7 is set.
        mask_key = (static_cast<uint32_t>(buffer[offset]) << 24) |
                   (static_cast<uint32_t>(buffer[offset + 1]) << 16) |
                   (static_cast<uint32_t>(buffer[offset + 2]) << 8) |
                   static_cast<uint32_t>(buffer[offset + 3]);
        offset += 4;
    }
    // `offset + payload_length` could wrap for a huge declared length, so compare against the
    // bytes that are actually left instead.
    if (payload_length > buffer.size() - offset) {
        offset = original_offset;
        return false;
    }
    const size_t payload_size = static_cast<size_t>(payload_length);
    const uint8_t* payload_begin = buffer.data() + offset;
    std::vector<uint8_t> payload(payload_begin, payload_begin + payload_size);
    offset += payload_size;
    if (masked) {
        maskData(payload, mask_key);
    }
    switch (opcode) {
    case 0x0: // continuation frame
    case 0x1: // text frame
    case 0x2: // binary frame
        if (fin) {
            std::string message(payload.begin(), payload.end());
            fireMessageEvent(message, true); // force binary mode
        } else {
            partial_frame_buffer_ = payload;
            expecting_continuation_ = true;
            continuation_opcode_ = opcode;
        }
        break;
    case 0x8: // close frame
    {
        uint16_t close_code = 1000;
        std::string close_reason;
        if (payload.size() >= 2) {
            close_code = (payload[0] << 8) | payload[1];
            if (payload.size() > 2) {
                close_reason = std::string(payload.begin() + 2, payload.end());
            }
        }
        fireCloseEvent(close_code, close_reason);
        return false;
    }
    case 0x9: // ping frame: answer with a pong carrying the same payload
    {
        std::vector<uint8_t> pongFrame = createPongFrame(payload);
        int sock = static_cast<int>(reinterpret_cast<intptr_t>(ws_));
        int sendResult;
        if (is_ssl_ && ssl_context_initialized_) {
            sendResult = webSocketSSLSend(pongFrame.data(), static_cast<int>(pongFrame.size()));
        } else {
            sendResult = ::send(sock, reinterpret_cast<const char*>(pongFrame.data()), static_cast<int>(pongFrame.size()), 0);
        }
#ifdef DEBUG_NETWORK
        if (sendResult > 0) {
            kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Sent pong response");
        }
#endif
        (void)sendResult; // only inspected in debug builds
    }
    break;
    case 0xA: // pong frame
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Received pong frame");
#endif
        break;
    case 0x3: case 0x4: case 0x5: case 0x6: case 0x7: // non-control frames
    case 0xB: case 0xC: case 0xD: case 0xE: case 0xF: // control frames
        // RFC 6455: reserved opcodes we skip the frame but continue processing
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_WARNING, "WebSocket: Received reserved opcode %d, skipping frame", opcode);
#endif
        break;
    // No default: opcode is masked with 0x0F and every one of the 16 values is handled above.
    }
    return true;
}
// RFC 6455 frame creation
/// Builds a single, final (FIN=1), masked client frame.
/// @param message payload bytes
/// @param opcode  4-bit frame opcode placed in the first header byte
/// @return header + 4-byte masking key + masked payload
std::vector<uint8_t> WebSocketClient::createWebSocketFrame(const std::string& message, uint8_t opcode) {
    const uint64_t payload_length = message.length();
    std::vector<uint8_t> frame;
    // Largest header is 2 + 8 (extended length) + 4 (mask key) bytes.
    frame.reserve(message.size() + 14);
    // first byte: FIN (1) + RSV (000) + Opcode (4 bits)
    frame.push_back(static_cast<uint8_t>(0x80 | opcode));
    // Assemble the key one byte at a time: `rand() << 16` overflows a signed int where RAND_MAX is
    // 2^31-1 and leaves bits 15 and 31 permanently zero where RAND_MAX is 32767.
    // NOTE(review): RFC 6455 section 5.3 asks for an unpredictable masking key; rand() is not a
    // strong entropy source — consider a better generator.
    uint32_t mask_key = 0;
    for (int i = 0; i < 4; i++) {
        mask_key = (mask_key << 8) | static_cast<uint32_t>(rand() & 0xFF);
    }
    if (payload_length < 126) {
        frame.push_back(static_cast<uint8_t>(0x80 | payload_length)); // MASK (1) + length
    } else if (payload_length < 65536) {
        frame.push_back(0x80 | 126); // MASK (1) + 126
        frame.push_back(static_cast<uint8_t>(payload_length >> 8));
        frame.push_back(static_cast<uint8_t>(payload_length & 0xFF));
    } else {
        frame.push_back(0x80 | 127); // MASK (1) + 127
        for (int i = 7; i >= 0; i--) {
            frame.push_back(static_cast<uint8_t>((payload_length >> (i * 8)) & 0xFF));
        }
    }
    // Masking key, most significant byte first (same byte order maskData() uses).
    frame.push_back(static_cast<uint8_t>(mask_key >> 24));
    frame.push_back(static_cast<uint8_t>(mask_key >> 16));
    frame.push_back(static_cast<uint8_t>(mask_key >> 8));
    frame.push_back(static_cast<uint8_t>(mask_key));
    std::vector<uint8_t> payload(message.begin(), message.end());
    maskData(payload, mask_key);
    frame.insert(frame.end(), payload.begin(), payload.end());
    return frame;
}
/// Wraps `data` in a single frame with the binary opcode (0x02).
std::vector<uint8_t> WebSocketClient::createWebSocketBinaryFrame(const std::string& data) {
    const uint8_t binary_opcode = 0x02;
    return createWebSocketFrame(data, binary_opcode);
}
/// Builds a close frame (opcode 0x8) whose payload is the big-endian status code followed by
/// the reason text.
/// @param code   close status code
/// @param reason close reason; truncated to at most 123 bytes because a control frame payload
///               may not exceed 125 bytes (RFC 6455 section 5.5) and 2 are used by the code
std::vector<uint8_t> WebSocketClient::createCloseFrame(uint16_t code, const std::string& reason) {
    const size_t max_reason_bytes = 123;
    size_t reason_length = reason.size();
    if (reason_length > max_reason_bytes) {
        reason_length = max_reason_bytes;
        // Do not cut a multi-byte UTF-8 sequence in half: back up while the first dropped byte
        // is a continuation byte (10xxxxxx).
        while (reason_length > 0 &&
               (static_cast<unsigned char>(reason[reason_length]) & 0xC0) == 0x80) {
            --reason_length;
        }
    }
    std::string payload;
    payload.reserve(2 + reason_length);
    payload.push_back(static_cast<char>(static_cast<uint8_t>(code >> 8)));
    payload.push_back(static_cast<char>(static_cast<uint8_t>(code & 0xFF)));
    payload.append(reason, 0, reason_length);
    return createWebSocketFrame(payload, 0x8);
}
/// Builds a pong frame (opcode 0xA) that echoes `payload`.
std::vector<uint8_t> WebSocketClient::createPongFrame(const std::vector<uint8_t>& payload) {
    const uint8_t pong_opcode = 0xA;
    return createWebSocketFrame(std::string(payload.begin(), payload.end()), pong_opcode);
}
/// Produces a handshake key: 16 bytes drawn from rand(), base64-encoded.
std::string WebSocketClient::generateWebSocketKey() {
    const size_t nonce_size = 16;
    std::vector<uint8_t> nonce;
    nonce.reserve(nonce_size);
    while (nonce.size() < nonce_size) {
        nonce.push_back(static_cast<uint8_t>(rand() & 0xFF));
    }
    return base64Encode(nonce);
}
/// Encodes `data` with the standard base64 alphabet, padding the output with '='.
/// @param data bytes to encode (may be empty, which yields an empty string)
/// @return encoded text, always a multiple of 4 characters long
std::string WebSocketClient::base64Encode(const std::vector<uint8_t>& data) {
    static const char* chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    std::string result;
    // Every started group of 3 input bytes yields exactly 4 output characters.
    result.reserve(((data.size() + 2) / 3) * 4);
    for (size_t i = 0; i < data.size(); i += 3) {
        const bool has_second = i + 1 < data.size();
        const bool has_third = i + 2 < data.size();
        // Pack up to three bytes into a 24-bit group; missing bytes stay zero.
        uint32_t group = static_cast<uint32_t>(data[i]) << 16;
        if (has_second) group |= static_cast<uint32_t>(data[i + 1]) << 8;
        if (has_third) group |= data[i + 2];
        result += chars[(group >> 18) & 0x3F];
        result += chars[(group >> 12) & 0x3F];
        result += has_second ? chars[(group >> 6) & 0x3F] : '=';
        result += has_third ? chars[group & 0x3F] : '=';
    }
    return result;
}
// SHA1 hash: Windows CryptoAPI on Windows, OpenSSL elsewhere when built WITH_SSL
/// Computes the SHA-1 digest of `data`.
/// @param data bytes to hash
/// @return 20-byte digest; on failure (or when no implementation is compiled in) the bytes not
///         written by the backend remain zero
std::vector<uint8_t> WebSocketClient::sha1Hash(const std::string& data) {
    std::vector<uint8_t> result(20, 0);
#ifdef _WIN32
    HCRYPTPROV hProv = 0;
    HCRYPTHASH hHash = 0;
    DWORD hashLen = 20;
    // The CryptoAPI calls are plain C functions and report failure through their return value,
    // so every step is checked and the handles acquired so far are released before returning.
    if (!CryptAcquireContext(&hProv, NULL, NULL, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CryptAcquireContext failed");
        return result;
    }
    if (!CryptCreateHash(hProv, CALG_SHA1, 0, 0, &hHash)) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CryptCreateHash failed");
        CryptReleaseContext(hProv, 0);
        return result;
    }
    if (!CryptHashData(hHash, reinterpret_cast<const BYTE*>(data.c_str()), static_cast<DWORD>(data.length()), 0)) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CryptHashData failed");
        CryptDestroyHash(hHash);
        CryptReleaseContext(hProv, 0);
        return result;
    }
    if (!CryptGetHashParam(hHash, HP_HASHVAL, result.data(), &hashLen, 0)) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: CryptGetHashParam failed");
        CryptDestroyHash(hHash);
        CryptReleaseContext(hProv, 0);
        return result;
    }
    CryptDestroyHash(hHash);
    CryptReleaseContext(hProv, 0);
#ifdef DEBUG_NETWORK
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SHA1 hash computed successfully");
#endif
    return result;
#elif defined(WITH_SSL)
    // This file includes <openssl/sha.h> for non-Windows WITH_SSL builds; the one-shot SHA1()
    // writes SHA_DIGEST_LENGTH (20) bytes into the output buffer.
    SHA1(reinterpret_cast<const unsigned char*>(data.data()), data.size(), result.data());
    return result;
#else
    // TODO: SHA1
    kinc_log(KINC_LOG_LEVEL_WARNING, "WebSocket: SHA1 not implemented for this platform");
    return result;
#endif
}
// masking with SIMD optimization
/// XORs `data` in place with the 4-byte masking key, repeating the key every four bytes.
/// @param data    bytes to mask or unmask (the operation is its own inverse)
/// @param maskKey key whose most significant byte is applied to data[0]
void WebSocketClient::maskData(std::vector<uint8_t>& data, uint32_t maskKey) {
    const uint8_t key_bytes[4] = {
        static_cast<uint8_t>(maskKey >> 24),
        static_cast<uint8_t>(maskKey >> 16),
        static_cast<uint8_t>(maskKey >> 8),
        static_cast<uint8_t>(maskKey)
    };
    uint8_t* const bytes = data.data();
    const size_t total = data.size();
    size_t done = 0; // number of bytes already masked
#ifdef _WIN32
    if (total >= 16) {
        // _mm_set_epi8 lists lanes from byte 15 down to byte 0, so this lays the key out as
        // key[0], key[1], key[2], key[3] repeated four times in memory order.
        const __m128i key_vec = _mm_set_epi8(
            key_bytes[3], key_bytes[2], key_bytes[1], key_bytes[0],
            key_bytes[3], key_bytes[2], key_bytes[1], key_bytes[0],
            key_bytes[3], key_bytes[2], key_bytes[1], key_bytes[0],
            key_bytes[3], key_bytes[2], key_bytes[1], key_bytes[0]
        );
        // Process whole 16-byte blocks with unaligned loads/stores.
        const size_t vector_end = total - (total % 16);
        for (; done < vector_end; done += 16) {
            __m128i* lane = (__m128i*)(bytes + done);
            _mm_storeu_si128(lane, _mm_xor_si128(_mm_loadu_si128(lane), key_vec));
        }
    }
#endif
    // Scalar path: the whole buffer when the vector path did not run, otherwise just the tail.
    // `done` is a multiple of 16 here, so `done % 4` continues the key rotation correctly.
    for (; done < total; ++done) {
        bytes[done] ^= key_bytes[done % 4];
    }
}
/// Sets up TLS on an already-connected socket.
/// @param sock connected socket descriptor
/// @param host server name handed to the handshake
/// @return true when TLS is (or already was) initialized; false on failure or when TLS support
///         is not available in this build/platform
bool WebSocketClient::initializeWebSocketSSL(int sock, const std::string& host) {
#if !defined(WITH_SSL)
    kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL not compiled in");
    return false;
#elif defined(_WIN32)
    if (!ssl_context_initialized_) {
        ssl_socket_ = sock;
        const bool context_ready = initWSL();
        if (!context_ready) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to initialize SSL context");
            return false;
        }
        const bool handshake_ok = performWSLHandshake(sock, host);
        if (!handshake_ok) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL handshake failed");
            // Release whatever the failed attempt allocated.
            cleanupWSL();
            return false;
        }
        ssl_context_initialized_ = true;
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL/TLS initialized successfully");
#endif
    }
    return true;
#else
    if (!ssl_initialized_) {
        // TODO: OpenSSL
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL not supported on this platform");
        return false;
    }
    return true;
#endif
}
/// Tears down TLS state if it was initialized; does nothing otherwise (and nothing at all in
/// builds without WITH_SSL).
void WebSocketClient::cleanupWebSocketSSL() {
#if defined(WITH_SSL) && defined(_WIN32)
    if (!ssl_context_initialized_) {
        return;
    }
    cleanupWSL();
    ssl_context_initialized_ = false;
    ssl_socket_ = -1;
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL/TLS cleaned up");
#elif defined(WITH_SSL)
    if (!ssl_initialized_) {
        return;
    }
    // TODO: OpenSSL cleanup
    ssl_initialized_ = false;
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL/TLS cleaned up");
#endif
}
/// Sends `len` bytes from an untyped buffer.
/// On Windows WITH_SSL builds with an initialized TLS context the call is forwarded with the
/// buffer cast to `const char*`; in every other configuration the bytes go out unencrypted
/// through ::send() on ssl_socket_.
/// @return whatever the underlying send call returns
int WebSocketClient::webSocketSSLSend(const void* data, int len) {
    const char* bytes = reinterpret_cast<const char*>(data);
#if defined(WITH_SSL) && defined(_WIN32)
    if (ssl_context_initialized_) {
        // NOTE(review): presumably resolves to a separate (const char*, int) overload defined
        // elsewhere; if no such overload exists this call would recurse — confirm.
        return webSocketSSLSend(bytes, len);
    }
#endif
    return ::send(ssl_socket_, bytes, len, 0);
}
#ifdef WITH_SSL
/// Prepares the TLS backend: acquires outbound SChannel credentials on Windows, creates an
/// OpenSSL client context elsewhere. Returns immediately when already initialized.
/// @return true on success (or when already initialized), false otherwise
bool WebSocketClient::initWSL() {
#ifdef _WIN32
    if (ssl_context_initialized_) {
        return true;
    }
    ssl_cred_handle_ = new CredHandle();
    memset(ssl_cred_handle_, 0, sizeof(CredHandle));

    SCHANNEL_CRED schannel_cred = {0};
    schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
    schannel_cred.grbitEnabledProtocols = SP_PROT_TLS1_2 | SP_PROT_TLS1_1;
    schannel_cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
    schannel_cred.dwMinimumCipherStrength = 0;
    schannel_cred.dwMaximumCipherStrength = 0;

    const SECURITY_STATUS acquire_status = AcquireCredentialsHandleA(
        NULL, const_cast<LPSTR>(UNISP_NAME_A), SECPKG_CRED_OUTBOUND, NULL, &schannel_cred, NULL, NULL,
        reinterpret_cast<PCredHandle>(ssl_cred_handle_), NULL);

    if (acquire_status == SEC_E_OK) {
        kinc_log(KINC_LOG_LEVEL_INFO, "Successfully acquired WebSocket SSL credentials");
        // The flag is raised here, before any handshake has run; cleanupWSL() only frees the
        // handles while it is set.
        ssl_context_initialized_ = true;
        return true;
    }

    // Failure: translate the common status codes for the log, then drop the unused handle.
    const char* failure_text = "Unknown credential error";
    switch (acquire_status) {
    case SEC_E_SECPKG_NOT_FOUND: failure_text = "SEC_E_SECPKG_NOT_FOUND - Security package not found"; break;
    case SEC_E_NOT_OWNER: failure_text = "SEC_E_NOT_OWNER - Not owner"; break;
    case SEC_E_CANNOT_INSTALL: failure_text = "SEC_E_CANNOT_INSTALL - Cannot install"; break;
    case SEC_E_INVALID_TOKEN: failure_text = "SEC_E_INVALID_TOKEN - Invalid token"; break;
    case SEC_E_LOGON_DENIED: failure_text = "SEC_E_LOGON_DENIED - Logon denied"; break;
    }
    kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to acquire WebSocket SSL credentials: 0x%x (%s)", acquire_status, failure_text);
    if (ssl_cred_handle_) {
        delete reinterpret_cast<CredHandle*>(ssl_cred_handle_);
        ssl_cred_handle_ = nullptr;
    }
    return false;
#else
    if (ssl_initialized_) {
        return true;
    }
    SSL_library_init();
    SSL_load_error_strings();
    OpenSSL_add_all_algorithms();

    ssl_ctx_ = SSL_CTX_new(TLS_client_method());
    if (ssl_ctx_ == nullptr) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to create WebSocket SSL context");
        return false;
    }
    // Disable the legacy SSL protocol versions and verify the peer against the default trust store.
    SSL_CTX_set_options(ssl_ctx_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
    SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, NULL);
    SSL_CTX_set_default_verify_paths(ssl_ctx_);
    ssl_initialized_ = true;
    return true;
#endif
}
/// Releases the TLS backend state.
/// Windows: while the initialized flag is set, deletes the security context and credentials
/// handles; the receive buffer and stored socket are always reset.
/// Other platforms: frees the OpenSSL connection and context objects if present.
void WebSocketClient::cleanupWSL() {
#ifdef _WIN32
    if (ssl_context_initialized_) {
        if (ssl_context_handle_) {
            CtxtHandle* context = reinterpret_cast<CtxtHandle*>(ssl_context_handle_);
            DeleteSecurityContext(context); // release the SSPI context first...
            delete context;                 // ...then the heap storage that held the handle
            ssl_context_handle_ = nullptr;
        }
        if (ssl_cred_handle_) {
            CredHandle* credentials = reinterpret_cast<CredHandle*>(ssl_cred_handle_);
            FreeCredentialsHandle(credentials);
            delete credentials;
            ssl_cred_handle_ = nullptr;
        }
        ssl_context_initialized_ = false;
    }
    ssl_buffer_.clear();
    ssl_socket_ = -1;
#else
    if (ssl_ != nullptr) {
        SSL_free(ssl_);
        ssl_ = nullptr;
    }
    if (ssl_ctx_ != nullptr) {
        SSL_CTX_free(ssl_ctx_);
        ssl_ctx_ = nullptr;
    }
    ssl_initialized_ = false;
#endif
}
bool WebSocketClient::performWSLHandshake(int socket, const std::string& host) {
if (!initWSL()) {
return false;
}
ssl_socket_ = socket;
#ifdef _WIN32
kinc_log(KINC_LOG_LEVEL_INFO, "Performing SChannel WebSocket SSL handshake...");
SecBufferDesc outbuffer_desc, inbuffer_desc;
SecBuffer outbuffers[1], inbuffers[2];
DWORD context_attributes;
TimeStamp expiry;
outbuffers[0].pvBuffer = nullptr;
outbuffers[0].BufferType = SECBUFFER_TOKEN;
outbuffers[0].cbBuffer = 0;
outbuffer_desc.cBuffers = 1;
outbuffer_desc.pBuffers = outbuffers;
outbuffer_desc.ulVersion = SECBUFFER_VERSION;
inbuffers[0].pvBuffer = nullptr;
inbuffers[0].BufferType = SECBUFFER_TOKEN;
inbuffers[0].cbBuffer = 0;
inbuffers[1].pvBuffer = nullptr;
inbuffers[1].BufferType = SECBUFFER_EMPTY;
inbuffers[1].cbBuffer = 0;
inbuffer_desc.cBuffers = 2;
inbuffer_desc.pBuffers = inbuffers;
inbuffer_desc.ulVersion = SECBUFFER_VERSION;
SECURITY_STATUS status = SEC_I_CONTINUE_NEEDED;
bool first_call = true;
std::vector<char> handshake_buffer;
while (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
if (first_call) {
// the initial call / first call should use NULL for phContext and allocate CtxtHandle structure for the output
ssl_context_handle_ = new CtxtHandle();
memset(ssl_context_handle_, 0, sizeof(CtxtHandle));
kinc_log(KINC_LOG_LEVEL_INFO, "Calling InitializeSecurityContextA for WebSocket first handshake...");
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket buffer before call: cbBuffer=%d, pvBuffer=%p", outbuffers[0].cbBuffer, outbuffers[0].pvBuffer);
status = InitializeSecurityContextA(
reinterpret_cast<PCredHandle>(ssl_cred_handle_),
nullptr,
const_cast<char*>(host.c_str()),
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT |
ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
0, SECURITY_NATIVE_DREP, nullptr, 0,
reinterpret_cast<PCtxtHandle>(ssl_context_handle_),
&outbuffer_desc, &context_attributes, &expiry);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket InitializeSecurityContextA result: status=0x%x", status);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket buffer after call: cbBuffer=%d, pvBuffer=%p", outbuffers[0].cbBuffer, outbuffers[0].pvBuffer);
if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) {
const char* init_error_desc = "Unknown initialization error";
switch (status) {
case SEC_E_INVALID_HANDLE: init_error_desc = "SEC_E_INVALID_HANDLE - Invalid handle"; break;
case SEC_E_INVALID_TOKEN: init_error_desc = "SEC_E_INVALID_TOKEN - Invalid token"; break;
case SEC_E_LOGON_DENIED: init_error_desc = "SEC_E_LOGON_DENIED - Logon denied"; break;
case SEC_E_TARGET_UNKNOWN: init_error_desc = "SEC_E_TARGET_UNKNOWN - Target unknown"; break;
case SEC_E_INTERNAL_ERROR: init_error_desc = "SEC_E_INTERNAL_ERROR - Internal error"; break;
case SEC_E_SECPKG_NOT_FOUND: init_error_desc = "SEC_E_SECPKG_NOT_FOUND - Security package not found"; break;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket InitializeSecurityContextA failed: 0x%x (%s)", status, init_error_desc);
if (ssl_context_handle_) {
delete reinterpret_cast<CtxtHandle*>(ssl_context_handle_);
ssl_context_handle_ = nullptr;
}
return false;
}
first_call = false;
} else {
outbuffers[0].pvBuffer = nullptr;
outbuffers[0].BufferType = SECBUFFER_TOKEN;
outbuffers[0].cbBuffer = 0;
if (!handshake_buffer.empty()) {
inbuffers[0].pvBuffer = handshake_buffer.data();
inbuffers[0].cbBuffer = static_cast<unsigned long>(handshake_buffer.size());
inbuffers[0].BufferType = SECBUFFER_TOKEN;
inbuffers[1].pvBuffer = nullptr;
inbuffers[1].BufferType = SECBUFFER_EMPTY;
inbuffers[1].cbBuffer = 0;
}
kinc_log(KINC_LOG_LEVEL_INFO, "Calling InitializeSecurityContextA for WebSocket continuation handshake...");
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket input buffer: cbBuffer=%d, pvBuffer=%p", inbuffers[0].cbBuffer, inbuffers[0].pvBuffer);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket context handle: %p", ssl_context_handle_);
status = InitializeSecurityContextA(
reinterpret_cast<PCredHandle>(ssl_cred_handle_),
reinterpret_cast<PCtxtHandle>(ssl_context_handle_),
const_cast<char*>(host.c_str()),
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT |
ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR |
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
0, SECURITY_NATIVE_DREP, &inbuffer_desc, 0,
reinterpret_cast<PCtxtHandle>(ssl_context_handle_),
&outbuffer_desc, &context_attributes, &expiry);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket continuation InitializeSecurityContextA result: status=0x%x", status);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket output buffer after continuation: cbBuffer=%d, pvBuffer=%p", outbuffers[0].cbBuffer, outbuffers[0].pvBuffer);
// handle for some data consumed, check for extra data in the buffer and move to front
if (status == SEC_E_OK || status == SEC_I_CONTINUE_NEEDED) {
//
if (inbuffers[1].BufferType == SECBUFFER_EXTRA && inbuffers[1].cbBuffer > 0) {
size_t extra_size = inbuffers[1].cbBuffer;
size_t consumed = handshake_buffer.size() - extra_size;
std::memmove(handshake_buffer.data(), handshake_buffer.data() + consumed, extra_size);
handshake_buffer.resize(extra_size);
} else {
handshake_buffer.clear();
}
}
}
if (status == SEC_I_CONTINUE_NEEDED && outbuffers[0].cbBuffer > 0 && outbuffers[0].pvBuffer != nullptr) {
// validate buffer before sending
if (outbuffers[0].cbBuffer < 0 || outbuffers[0].cbBuffer > 65536) {
kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid WebSocket SSL handshake buffer size: %d bytes (corrupted data)", outbuffers[0].cbBuffer);
if (outbuffers[0].pvBuffer) FreeContextBuffer(outbuffers[0].pvBuffer);
return false;
}
kinc_log(KINC_LOG_LEVEL_INFO, "Sending WebSocket SSL handshake data: %d bytes", outbuffers[0].cbBuffer);
int sent = ::send(socket, static_cast<char*>(outbuffers[0].pvBuffer),
outbuffers[0].cbBuffer, 0);
if (sent <= 0) {
int error = WSAGetLastError();
const char* error_desc = "Unknown socket error";
switch (error) {
case WSAENOTCONN: error_desc = "WSAENOTCONN - Socket is not connected"; break;
case WSAECONNRESET: error_desc = "WSAECONNRESET - Connection reset by peer"; break;
case WSAEWOULDBLOCK: error_desc = "WSAEWOULDBLOCK - Resource temporarily unavailable"; break;
case WSAENETDOWN: error_desc = "WSAENETDOWN - Network is down"; break;
case WSAEFAULT: error_desc = "WSAEFAULT - Bad address"; break;
case WSAEINVAL: error_desc = "WSAEINVAL - Invalid argument"; break;
case WSAENOTSOCK: error_desc = "WSAENOTSOCK - Socket operation on non-socket"; break;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to send WebSocket SSL handshake data: sent=%d, error=%d (%s)", sent, error, error_desc);
if (outbuffers[0].pvBuffer) FreeContextBuffer(outbuffers[0].pvBuffer);
return false;
}
kinc_log(KINC_LOG_LEVEL_INFO, "Successfully sent WebSocket SSL handshake data: %d bytes", sent);
FreeContextBuffer(outbuffers[0].pvBuffer);
outbuffers[0].pvBuffer = nullptr;
outbuffers[0].cbBuffer = 0;
}
if (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
char recv_buffer[SocketOptimization::SMALL_BUFFER_SIZE];
kinc_log(KINC_LOG_LEVEL_INFO, "Waiting to receive WebSocket SSL handshake response...");
int received = ::recv(socket, recv_buffer, sizeof(recv_buffer), 0);
if (received <= 0) {
int error = WSAGetLastError();
const char* error_desc = "Unknown socket error";
switch (error) {
case WSAENOTCONN: error_desc = "WSAENOTCONN - Socket is not connected"; break;
case WSAECONNRESET: error_desc = "WSAECONNRESET - Connection reset by peer"; break;
case WSAEWOULDBLOCK: error_desc = "WSAEWOULDBLOCK - Resource temporarily unavailable"; break;
case WSAENETDOWN: error_desc = "WSAENETDOWN - Network is down"; break;
case WSAEFAULT: error_desc = "WSAEFAULT - Bad address"; break;
case WSAEINVAL: error_desc = "WSAEINVAL - Invalid argument"; break;
case WSAENOTSOCK: error_desc = "WSAENOTSOCK - Socket operation on non-socket"; break;
case WSAECONNABORTED: error_desc = "WSAECONNABORTED - Connection aborted"; break;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive WebSocket SSL handshake response: received=%d, error=%d (%s)", received, error, error_desc);
return false;
}
kinc_log(KINC_LOG_LEVEL_INFO, "Received WebSocket SSL handshake response: %d bytes", received);
handshake_buffer.insert(handshake_buffer.end(), recv_buffer, recv_buffer + received);
}
}
if (status == SEC_E_OK) {
ssl_context_initialized_ = true;
kinc_log(KINC_LOG_LEVEL_INFO, "SChannel WebSocket SSL handshake completed successfully");
return true;
} else {
const char* error_desc = "Unknown error";
switch (status) {
case SEC_E_LOGON_DENIED: error_desc = "SEC_E_LOGON_DENIED - The logon attempt failed"; break;
case SEC_E_INVALID_TOKEN: error_desc = "SEC_E_INVALID_TOKEN - Invalid security token"; break;
case SEC_E_INVALID_HANDLE: error_desc = "SEC_E_INVALID_HANDLE - Invalid handle"; break;
case SEC_E_INTERNAL_ERROR: error_desc = "SEC_E_INTERNAL_ERROR - Internal error"; break;
case SEC_E_NO_CREDENTIALS: error_desc = "SEC_E_NO_CREDENTIALS - No credentials available"; break;
case SEC_E_WRONG_PRINCIPAL: error_desc = "SEC_E_WRONG_PRINCIPAL - Wrong principal"; break;
case SEC_E_CERT_EXPIRED: error_desc = "SEC_E_CERT_EXPIRED - Certificate expired"; break;
case SEC_E_CERT_UNKNOWN: error_desc = "SEC_E_CERT_UNKNOWN - Certificate unknown"; break;
case SEC_E_UNTRUSTED_ROOT: error_desc = "SEC_E_UNTRUSTED_ROOT - Untrusted root certificate"; break;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "SChannel WebSocket SSL handshake failed with status: 0x%x (%s)", status, error_desc);
return false;
}
#else
ssl_ = SSL_new(ssl_ctx_);
if (!ssl_) {
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to create WebSocket SSL structure");
return false;
}
if (SSL_set_fd(ssl_, socket) != 1) {
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to set WebSocket SSL socket");
SSL_free(ssl_);
ssl_ = nullptr;
return false;
}
SSL_set_tlsext_host_name(ssl_, host.c_str());
int result = SSL_connect(ssl_);
if (result != 1) {
int error = SSL_get_error(ssl_, result);
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket SSL handshake failed: %d", error);
SSL_free(ssl_);
ssl_ = nullptr;
return false;
}
ssl_initialized_ = true;
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket OpenSSL handshake completed successfully");
return true;
#endif
}
int WebSocketClient::wslRead(char* buffer, int length) {
#ifdef _WIN32
if (!ssl_context_initialized_) {
return -1;
}
while (true) {
if (!ssl_buffer_.empty()) {
SecBufferDesc message_desc;
SecBuffer message_buffers[4];
message_desc.ulVersion = SECBUFFER_VERSION;
message_desc.cBuffers = 4;
message_desc.pBuffers = message_buffers;
message_buffers[0].BufferType = SECBUFFER_DATA;
message_buffers[0].pvBuffer = ssl_buffer_.data();
message_buffers[0].cbBuffer = static_cast<unsigned long>(ssl_buffer_.size());
message_buffers[1].BufferType = SECBUFFER_EMPTY;
message_buffers[1].pvBuffer = nullptr;
message_buffers[1].cbBuffer = 0;
message_buffers[2].BufferType = SECBUFFER_EMPTY;
message_buffers[2].pvBuffer = nullptr;
message_buffers[2].cbBuffer = 0;
message_buffers[3].BufferType = SECBUFFER_EMPTY;
message_buffers[3].pvBuffer = nullptr;
message_buffers[3].cbBuffer = 0;
SECURITY_STATUS status = DecryptMessage(
reinterpret_cast<PCtxtHandle>(ssl_context_handle_),
&message_desc, 0, nullptr);
if (status == SEC_E_OK) {
SecBuffer* data_buffer = nullptr;
SecBuffer* extra_buffer = nullptr;
for (int i = 0; i < 4; i++) {
if (message_buffers[i].BufferType == SECBUFFER_DATA) {
data_buffer = &message_buffers[i];
} else if (message_buffers[i].BufferType == SECBUFFER_EXTRA) {
extra_buffer = &message_buffers[i];
}
}
if (data_buffer && data_buffer->cbBuffer > 0) {
int bytes_to_copy = (std::min)(length, static_cast<int>(data_buffer->cbBuffer));
memcpy(buffer, data_buffer->pvBuffer, bytes_to_copy);
if (extra_buffer && extra_buffer->cbBuffer > 0) {
std::memmove(ssl_buffer_.data(), extra_buffer->pvBuffer, extra_buffer->cbBuffer);
ssl_buffer_.resize(extra_buffer->cbBuffer);
} else {
ssl_buffer_.clear();
}
return bytes_to_copy;
}
ssl_buffer_.clear();
} else if (status == SEC_E_INCOMPLETE_MESSAGE) {
// Need more encrypted data, read from socket
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL incomplete message detected (status=0x%x, SEC_E_INCOMPLETE_MESSAGE=0x%x), reading more data (buffer size: %zu)", status, SEC_E_INCOMPLETE_MESSAGE, ssl_buffer_.size());
char recv_buffer[SocketOptimization::SMALL_BUFFER_SIZE];
int received = ::recv(ssl_socket_, recv_buffer, sizeof(recv_buffer), 0);
if (received <= 0) {
if (received == 0) {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL connection closed by server during read");
} else {
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to receive encrypted data: %d (WSAError: %d)", received, WSAGetLastError());
}
return -1;
}
size_t old_size = ssl_buffer_.size();
ssl_buffer_.insert(ssl_buffer_.end(), recv_buffer, recv_buffer + received);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Added %d bytes to SSL buffer (total: %zu -> %zu)", received, old_size, ssl_buffer_.size());
continue;
} else if (status == 0x90317) {
// incomplete message error
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Detected incomplete message error (0x90317), buffer size: %zu", ssl_buffer_.size());
// log the actual buffer content as hex to see what we received
if (ssl_buffer_.size() > 0) {
std::string hex_preview;
for (size_t i = 0; i < (std::min)(static_cast<size_t>(31), ssl_buffer_.size()); i++) {
char hex_byte[4];
sprintf(hex_byte, "%02X ", (unsigned char)ssl_buffer_[i]);
hex_preview += hex_byte;
}
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL buffer hex (first %d bytes): %s",
static_cast<int>((std::min)(static_cast<size_t>(31), ssl_buffer_.size())), hex_preview.c_str());
// check for ssl alert record type 0x15
if (ssl_buffer_.size() >= 5 && ssl_buffer_[0] == 0x15) {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Detected SSL Alert record (0x15) - server rejected handshake");
if (ssl_buffer_.size() >= 7) {
uint8_t alert_level = ssl_buffer_[5];
uint8_t alert_desc = ssl_buffer_[6];
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL Alert - Level: %d, Description: %d", alert_level, alert_desc);
}
} else if (ssl_buffer_.size() >= 1) {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL record type: 0x%02X", (unsigned char)ssl_buffer_[0]);
}
}
// try a smaller read first to see if server has more data
char recv_buffer[SocketOptimization::SMALL_BUFFER_SIZE / 8];
int received = ::recv(ssl_socket_, recv_buffer, sizeof(recv_buffer), 0);
if (received <= 0) {
if (received == 0) {
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL connection closed by server - may be complete response");
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Attempting to process existing buffer as complete response");
return -1;
} else {
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: Failed to receive encrypted data: %d (WSAError: %d)", received, WSAGetLastError());
}
return -1;
}
size_t old_size = ssl_buffer_.size();
ssl_buffer_.insert(ssl_buffer_.end(), recv_buffer, recv_buffer + received);
kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: Added %d bytes to SSL buffer (total: %zu -> %zu)", received, old_size, ssl_buffer_.size());
continue;
} else {
const char* status_desc = "Unknown SSL error";
switch (status) {
case SEC_E_INVALID_TOKEN: status_desc = "SEC_E_INVALID_TOKEN - Invalid token"; break;
case SEC_E_INVALID_HANDLE: status_desc = "SEC_E_INVALID_HANDLE - Invalid handle"; break;
case SEC_E_MESSAGE_ALTERED: status_desc = "SEC_E_MESSAGE_ALTERED - Message altered"; break;
case SEC_E_OUT_OF_SEQUENCE: status_desc = "SEC_E_OUT_OF_SEQUENCE - Out of sequence"; break;
case SEC_E_NO_AUTHENTICATING_AUTHORITY: status_desc = "SEC_E_NO_AUTHENTICATING_AUTHORITY - No auth authority"; break;
case 0x90317: status_desc = "0x90317 - Incomplete message (should be handled above)"; break;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SChannel decrypt failed: 0x%x (%s) - buffer size: %zu, SEC_E_INCOMPLETE_MESSAGE=0x%x", status, status_desc, ssl_buffer_.size(), SEC_E_INCOMPLETE_MESSAGE);
return -1;
}
}
char recv_buffer[SocketOptimization::SMALL_BUFFER_SIZE];
int received = ::recv(ssl_socket_, recv_buffer, sizeof(recv_buffer), 0);
if (received <= 0) {
if (received == 0) {
return 0;
}
kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to receive encrypted data for WebSocket: %d", received);
return -1;
}
ssl_buffer_.insert(ssl_buffer_.end(), recv_buffer, recv_buffer + received);
// retry decryption
}
#else
if (!ssl_) {
return -1;
}
int bytes_read = SSL_read(ssl_, buffer, length);
if (bytes_read <= 0) {
int error = SSL_get_error(ssl_, bytes_read);
if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
return 0;
}
return -1;
}
return bytes_read;
#endif
}
// Sends `length` bytes from `buffer` on ssl_socket_, encrypting them when a TLS
// session is available.
//
// Build variants:
//  - WITH_SSL on _WIN32 (SChannel): if ssl_context_initialized_ is false the bytes
//    are handed to ::send() unencrypted. Otherwise the payload is split into pieces
//    of at most cbMaximumMessage bytes, each piece is encrypted with EncryptMessage
//    and the resulting record is written to the socket until every byte of it has
//    been accepted.
//  - WITH_SSL elsewhere (OpenSSL): SSL_write() on ssl_.
//  - without WITH_SSL: plain ::send().
//
// Returns:
//  - SChannel path: `length` once all records were sent, -1 on any failure
//    (invalid arguments, attribute query, encryption or socket error).
//  - OpenSSL path: bytes written, 0 when SSL_get_error reports WANT_READ or
//    WANT_WRITE, -1 for any other error or when ssl_ is null.
//  - plain paths: whatever ::send() returns.
int WebSocketClient::webSocketSSLSend(const char* buffer, int length) {
#ifdef WITH_SSL
#ifdef _WIN32
    kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: webSocketSSLSend(const char*) called, length=%d, ssl_initialized=%d", length, ssl_context_initialized_ ? 1 : 0);
    if (!ssl_context_initialized_) {
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL not initialized, using plain send");
        return ::send(ssl_socket_, buffer, length, 0);
    }
    // a negative length would be converted to a huge size_t by memcpy below
    if (length < 0 || (length > 0 && buffer == nullptr)) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL send called with invalid arguments (length=%d)", length);
        return -1;
    }
    SecPkgContext_StreamSizes stream_sizes;
    SECURITY_STATUS status = QueryContextAttributesA(
        static_cast<PCtxtHandle>(ssl_context_handle_),
        SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
    if (status != SEC_E_OK) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to get WebSocket stream sizes: 0x%x", status);
        return -1;
    }
    // EncryptMessage only accepts up to cbMaximumMessage plaintext bytes per call,
    // so larger payloads are sent as several records
    const size_t max_chunk = static_cast<size_t>(stream_sizes.cbMaximumMessage);
    if (max_chunk == 0) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SChannel reported a zero maximum message size");
        return -1;
    }
    const size_t total = static_cast<size_t>(length);
    const size_t header_size = static_cast<size_t>(stream_sizes.cbHeader);
    const size_t trailer_size = static_cast<size_t>(stream_sizes.cbTrailer);
    // one buffer sized for the largest record, reused for every chunk and
    // released automatically on every return path
    std::vector<char> encrypt_buffer(header_size + (std::min)(total, max_chunk) + trailer_size);
    size_t offset = 0;
    do {
        const size_t chunk = (std::min)(total - offset, max_chunk);
        char* const header = encrypt_buffer.data();
        char* const payload = header + header_size;
        char* const trailer = payload + chunk;
        if (chunk > 0) {
            memcpy(payload, buffer + offset, chunk);
        }
        SecBuffer message_buffers[4];
        message_buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
        message_buffers[0].pvBuffer = header;
        message_buffers[0].cbBuffer = stream_sizes.cbHeader;
        message_buffers[1].BufferType = SECBUFFER_DATA;
        message_buffers[1].pvBuffer = payload;
        message_buffers[1].cbBuffer = static_cast<unsigned long>(chunk);
        message_buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
        message_buffers[2].pvBuffer = trailer;
        message_buffers[2].cbBuffer = stream_sizes.cbTrailer;
        message_buffers[3].BufferType = SECBUFFER_EMPTY;
        message_buffers[3].pvBuffer = nullptr;
        message_buffers[3].cbBuffer = 0;
        SecBufferDesc message_desc;
        message_desc.ulVersion = SECBUFFER_VERSION;
        message_desc.cBuffers = 4;
        message_desc.pBuffers = message_buffers;
        status = EncryptMessage(
            static_cast<PCtxtHandle>(ssl_context_handle_),
            0, &message_desc, 0);
        if (status != SEC_E_OK) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket SChannel encrypt failed: 0x%x", status);
            return -1;
        }
        // header, data and trailer are contiguous, so the record is the sum of the
        // sizes EncryptMessage wrote back into the buffer descriptors
        const int encrypted_size = static_cast<int>(message_buffers[0].cbBuffer + message_buffers[1].cbBuffer + message_buffers[2].cbBuffer);
        // ::send may accept fewer bytes than requested; a partially written record
        // would corrupt the TLS stream, so keep sending until it is complete
        int sent_total = 0;
        while (sent_total < encrypted_size) {
            const int bytes_sent = ::send(ssl_socket_, encrypt_buffer.data() + sent_total, encrypted_size - sent_total, 0);
            if (bytes_sent <= 0) {
                kinc_log(KINC_LOG_LEVEL_ERROR, "WebSocket: SSL send failed: %d", WSAGetLastError());
                return -1;
            }
            sent_total += bytes_sent;
        }
#ifdef DEBUG_NETWORK
        kinc_log(KINC_LOG_LEVEL_INFO, "WebSocket: SSL send - encrypted %d bytes, sent %d bytes", encrypted_size, sent_total);
#endif
        offset += chunk;
    } while (offset < total);
    return length;
#else
    if (!ssl_) {
        return -1;
    }
    int bytes_written = SSL_write(ssl_, buffer, length);
    if (bytes_written <= 0) {
        int error = SSL_get_error(ssl_, bytes_written);
        // the operation could not complete right now; report "nothing written"
        if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
            return 0;
        }
        return -1;
    }
    return bytes_written;
#endif
#else
    return ::send(ssl_socket_, buffer, length, 0);
#endif
}
// Encodes the bytes of `data` as standard base64 (alphabet A-Z a-z 0-9 + /),
// padding the output with '=' so its length is a multiple of four.
std::string WebSocketClient::base64Encode(const std::string& data) {
    static constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const unsigned char* src = reinterpret_cast<const unsigned char*>(data.c_str());
    const int total = static_cast<int>(data.length());
    std::string encoded;
    int pos = 0;
    // every complete group of three input bytes becomes four output symbols
    for (; total - pos >= 3; pos += 3) {
        const unsigned int group = (static_cast<unsigned int>(src[pos]) << 16) |
                                   (static_cast<unsigned int>(src[pos + 1]) << 8) |
                                   static_cast<unsigned int>(src[pos + 2]);
        encoded += kAlphabet[(group >> 18) & 0x3f];
        encoded += kAlphabet[(group >> 12) & 0x3f];
        encoded += kAlphabet[(group >> 6) & 0x3f];
        encoded += kAlphabet[group & 0x3f];
    }
    // one or two leftover bytes: zero-fill the group, emit leftover + 1 symbols,
    // then pad with '='
    const int leftover = total - pos;
    if (leftover > 0) {
        unsigned int group = static_cast<unsigned int>(src[pos]) << 16;
        if (leftover == 2) {
            group |= static_cast<unsigned int>(src[pos + 1]) << 8;
        }
        encoded += kAlphabet[(group >> 18) & 0x3f];
        encoded += kAlphabet[(group >> 12) & 0x3f];
        if (leftover == 2) {
            encoded += kAlphabet[(group >> 6) & 0x3f];
        }
        for (int pad = leftover; pad < 3; ++pad) {
            encoded += '=';
        }
    }
    return encoded;
}
#endif
}
// JS binding: creates a WebSocketClient for the URL string in args[0] and stores
// it in WebSocketWrapper::active_websockets under a freshly allocated id.
// Throws a JS Error when the URL argument is missing or not a string.
// Returns the new id to script, or -1 if constructing the client threw.
void runt_websocket_create(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    const bool has_url = args.Length() >= 1 && args[0]->IsString();
    if (!has_url) {
        isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "URL required").ToLocalChecked()));
        return;
    }
    String::Utf8Value url(isolate, args[0]);
    kinc_log(KINC_LOG_LEVEL_INFO, "[WebSocket Client] Creating connection to: %s", *url);
    // stays -1 unless the client was created and registered
    int result = -1;
    try {
        const int new_id = WebSocketWrapper::next_websocket_id++;
        auto connection = std::make_unique<WebSocketWrapper::WebSocketClient>(isolate, WebSocketWrapper::getGlobalContext(), *url);
        WebSocketWrapper::active_websockets[new_id] = std::move(connection);
        kinc_log(KINC_LOG_LEVEL_INFO, "[WebSocket Client] Connection created successfully with ID: %d", new_id);
        result = new_id;
    } catch (const std::exception& e) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "[WebSocket Client] Failed to create connection: %s", e.what());
    } catch (...) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "[WebSocket Client] Unknown error creating connection");
    }
    args.GetReturnValue().Set(Number::New(isolate, result));
}
void runt_websocket_send(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and data required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
auto it = WebSocketWrapper::active_websockets.find(id);
if (it == WebSocketWrapper::active_websockets.end()) {
return;
}
if (args[1]->IsArrayBuffer()) {
Local<ArrayBuffer> buffer = args[1].As<ArrayBuffer>();
void* data = buffer->GetBackingStore()->Data();
size_t length = buffer->ByteLength();
std::string binaryData(static_cast<const char*>(data), length);
it->second->sendBinary(binaryData);
} else if (args[1]->IsArrayBufferView()) {
Local<ArrayBufferView> view = args[1].As<ArrayBufferView>();
Local<ArrayBuffer> buffer = view->Buffer();
void* data = static_cast<char*>(buffer->GetBackingStore()->Data()) + view->ByteOffset();
size_t length = view->ByteLength();
std::string binaryData(static_cast<const char*>(data), length);
it->second->sendBinary(binaryData);
} else if (args[1]->IsString()) {
String::Utf8Value data(isolate, args[1]);
it->second->send(*data);
} else {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "Data must be String or ArrayBuffer").ToLocalChecked()));
}
}
void runt_websocket_send_binary(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and binary data required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
auto it = WebSocketWrapper::active_websockets.find(id);
if (it == WebSocketWrapper::active_websockets.end()) {
return;
}
if (args[1]->IsArrayBuffer()) {
Local<ArrayBuffer> buffer = args[1].As<ArrayBuffer>();
void* data = buffer->GetBackingStore()->Data();
size_t length = buffer->ByteLength();
std::string binaryData(static_cast<const char*>(data), length);
it->second->sendBinary(binaryData);
} else if (args[1]->IsArrayBufferView()) {
Local<ArrayBufferView> view = args[1].As<ArrayBufferView>();
Local<ArrayBuffer> buffer = view->Buffer();
void* data = static_cast<char*>(buffer->GetBackingStore()->Data()) + view->ByteOffset();
size_t length = view->ByteLength();
std::string binaryData(static_cast<const char*>(data), length);
it->second->sendBinary(binaryData);
} else {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "Binary data must be ArrayBuffer or ArrayBufferView").ToLocalChecked()));
}
}
void runt_websocket_close(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 1 || !args[0]->IsNumber()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
int code = 1000;
std::string reason = "";
if (args.Length() > 1 && args[1]->IsNumber()) {
code = args[1]->Int32Value(isolate->GetCurrentContext()).FromJust();
}
if (args.Length() > 2 && args[2]->IsString()) {
String::Utf8Value reasonStr(isolate, args[2]);
reason = *reasonStr;
}
auto it = WebSocketWrapper::active_websockets.find(id);
if (it != WebSocketWrapper::active_websockets.end()) {
it->second->close(code, reason);
// erase here causes iterator invalidation when called from callback
// closed websockets will be cleaned up naturally
// WebSocketWrapper::active_websockets.erase(it);
}
}
// JS binding: returns getReadyState() of the websocket whose id is args[0].
// An id that is not in active_websockets reports WebSocketWrapper::CLOSED.
// Throws a JS Error when the id is missing or not a number.
void runt_websocket_get_ready_state(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    if (!(args.Length() >= 1 && args[0]->IsNumber())) {
        isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID required").ToLocalChecked()));
        return;
    }
    const int socket_id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
    auto entry = WebSocketWrapper::active_websockets.find(socket_id);
    if (entry == WebSocketWrapper::active_websockets.end()) {
        args.GetReturnValue().Set(Number::New(isolate, WebSocketWrapper::CLOSED));
        return;
    }
    args.GetReturnValue().Set(Number::New(isolate, entry->second->getReadyState()));
}
// JS binding: returns getUrl() of the websocket whose id is args[0] as a JS string.
// An id that is not in active_websockets yields an empty string.
// Throws a JS Error when the id is missing or not a number.
void runt_websocket_get_url(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    if (!(args.Length() >= 1 && args[0]->IsNumber())) {
        isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID required").ToLocalChecked()));
        return;
    }
    const int socket_id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
    auto entry = WebSocketWrapper::active_websockets.find(socket_id);
    if (entry == WebSocketWrapper::active_websockets.end()) {
        args.GetReturnValue().Set(String::NewFromUtf8(isolate, "").ToLocalChecked());
        return;
    }
    args.GetReturnValue().Set(String::NewFromUtf8(isolate, entry->second->getUrl().c_str()).ToLocalChecked());
}
// JS binding: returns getBufferedAmount() of the websocket whose id is args[0].
// An id that is not in active_websockets reports 0.
// Throws a JS Error when the id is missing or not a number.
void runt_websocket_get_buffered_amount(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    if (!(args.Length() >= 1 && args[0]->IsNumber())) {
        isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID required").ToLocalChecked()));
        return;
    }
    const int socket_id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
    auto entry = WebSocketWrapper::active_websockets.find(socket_id);
    if (entry == WebSocketWrapper::active_websockets.end()) {
        args.GetReturnValue().Set(Number::New(isolate, 0));
        return;
    }
    args.GetReturnValue().Set(Number::New(isolate, entry->second->getBufferedAmount()));
}
void runt_websocket_set_onopen(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber() || !args[1]->IsFunction()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and callback required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
Local<Function> callback = Local<Function>::Cast(args[1]);
auto it = WebSocketWrapper::active_websockets.find(id);
if (it != WebSocketWrapper::active_websockets.end()) {
it->second->setOnOpen(callback);
}
}
void runt_websocket_set_onmessage(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber() || !args[1]->IsFunction()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and callback required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
Local<Function> callback = Local<Function>::Cast(args[1]);
auto it = WebSocketWrapper::active_websockets.find(id);
if (it != WebSocketWrapper::active_websockets.end()) {
it->second->setOnMessage(callback);
}
}
void runt_websocket_set_onerror(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber() || !args[1]->IsFunction()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and callback required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
Local<Function> callback = Local<Function>::Cast(args[1]);
auto it = WebSocketWrapper::active_websockets.find(id);
if (it != WebSocketWrapper::active_websockets.end()) {
it->second->setOnError(callback);
}
}
void runt_websocket_set_onclose(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
if (args.Length() < 2 || !args[0]->IsNumber() || !args[1]->IsFunction()) {
isolate->ThrowException(Exception::Error(String::NewFromUtf8(isolate, "ID and callback required").ToLocalChecked()));
return;
}
int id = args[0]->Int32Value(isolate->GetCurrentContext()).FromJust();
Local<Function> callback = Local<Function>::Cast(args[1]);
auto it = WebSocketWrapper::active_websockets.find(id);
if (it != WebSocketWrapper::active_websockets.end()) {
it->second->setOnClose(callback);
}
}
// Native WebSocket class implementation for V8
static void WebSocketConstructor(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
Local<Context> context = isolate->GetCurrentContext();
if (!args.IsConstructCall()) {
isolate->ThrowException(Exception::TypeError(
String::NewFromUtf8(isolate, "WebSocket constructor requires 'new'").ToLocalChecked()));
return;
}
if (args.Length() < 1 || !args[0]->IsString()) {
isolate->ThrowException(Exception::TypeError(
String::NewFromUtf8(isolate, "WebSocket constructor requires a URL").ToLocalChecked()));
return;
}
String::Utf8Value url(isolate, args[0]);
int wsId = WebSocketWrapper::createWebSocketConnection(isolate, *url);
Local<Object> instance = args.This();
instance->Set(context, String::NewFromUtf8(isolate, "_id").ToLocalChecked(), Integer::New(isolate, wsId));
instance->Set(context, String::NewFromUtf8(isolate, "url").ToLocalChecked(), args[0]);
instance->Set(context, String::NewFromUtf8(isolate, "readyState").ToLocalChecked(), Integer::New(isolate, 0));
instance->Set(context, String::NewFromUtf8(isolate, "protocol").ToLocalChecked(), String::NewFromUtf8(isolate, "").ToLocalChecked());
instance->Set(context, String::NewFromUtf8(isolate, "extensions").ToLocalChecked(), String::NewFromUtf8(isolate, "").ToLocalChecked());
instance->Set(context, String::NewFromUtf8(isolate, "binaryType").ToLocalChecked(), String::NewFromUtf8(isolate, "blob").ToLocalChecked());
instance->Set(context, String::NewFromUtf8(isolate, "bufferedAmount").ToLocalChecked(), Integer::New(isolate, 0));
instance->Set(context, String::NewFromUtf8(isolate, "onopen").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "onmessage").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "onerror").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "onclose").ToLocalChecked(), Null(isolate));
args.GetReturnValue().Set(instance);
}
// Implements WebSocket.prototype.send(data).
// Reads `readyState` and `_id` from the receiver object. Throws a JS Error when
// readyState is not 1, or when `_id` is not present in active_websockets.
// ArrayBuffer / ArrayBufferView arguments are copied and sent with sendBinary();
// every other value is converted to a string and sent with send().
// Calling with no argument is a no-op.
//
// `readyState` and `_id` are ordinary script-visible properties, so reading or
// converting them can run user code that throws. Those failures leave the JS
// exception pending and return, instead of aborting the process.
static void WebSocketSend(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();
    Local<Object> self = args.Holder();

    Local<Value> readyStateVal;
    if (!self->Get(context, String::NewFromUtf8(isolate, "readyState").ToLocalChecked()).ToLocal(&readyStateVal)) {
        return;
    }
    int32_t readyState = 0;
    if (!readyStateVal->Int32Value(context).To(&readyState)) {
        return;
    }
    if (readyState == 0) {
        isolate->ThrowException(Exception::Error(
            String::NewFromUtf8(isolate, "WebSocket is not open: readyState 0 (CONNECTING)").ToLocalChecked()));
        return;
    }
    if (readyState != 1) {
        isolate->ThrowException(Exception::Error(
            String::NewFromUtf8(isolate, "WebSocket is not open").ToLocalChecked()));
        return;
    }

    Local<Value> idVal;
    if (!self->Get(context, String::NewFromUtf8(isolate, "_id").ToLocalChecked()).ToLocal(&idVal)) {
        return;
    }
    int32_t wsId = 0;
    if (!idVal->Int32Value(context).To(&wsId)) {
        return;
    }
    auto it = WebSocketWrapper::active_websockets.find(wsId);
    if (it == WebSocketWrapper::active_websockets.end()) {
        isolate->ThrowException(Exception::Error(
            String::NewFromUtf8(isolate, "WebSocket connection not found").ToLocalChecked()));
        return;
    }
    if (args.Length() < 1) return;

    if (args[0]->IsArrayBuffer()) {
        Local<ArrayBuffer> ab = args[0].As<ArrayBuffer>();
        std::string data(static_cast<char*>(ab->GetBackingStore()->Data()), ab->ByteLength());
        it->second->sendBinary(data);
    } else if (args[0]->IsArrayBufferView()) {
        Local<ArrayBufferView> view = args[0].As<ArrayBufferView>();
        Local<ArrayBuffer> ab = view->Buffer();
        size_t offset = view->ByteOffset();
        size_t length = view->ByteLength();
        std::string data(static_cast<char*>(ab->GetBackingStore()->Data()) + offset, length);
        it->second->sendBinary(data);
    } else {
        String::Utf8Value str(isolate, args[0]);
        // string conversion fails for e.g. a Symbol or an object whose toString()
        // throws; *str is then null and must not be forwarded to send()
        if (*str == nullptr) {
            return;
        }
        it->second->send(*str);
    }
}
// Implements WebSocket.prototype.close([code[, reason]]).
// Does nothing when the receiver's `readyState` is already 2 or 3. Otherwise
// sets `readyState` to 2 and, if `_id` is present in active_websockets, calls
// close(code, reason) on that connection. `code` defaults to 1000 and `reason`
// to an empty string; they are taken from args[0] / args[1] only when those are
// a number / a string respectively.
//
// `readyState` and `_id` are ordinary script-visible properties, so reading,
// converting or writing them can run user code that throws. Those failures leave
// the JS exception pending and return, instead of aborting the process.
static void WebSocketClose(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();
    Local<Object> self = args.Holder();

    Local<Value> readyStateVal;
    if (!self->Get(context, String::NewFromUtf8(isolate, "readyState").ToLocalChecked()).ToLocal(&readyStateVal)) {
        return;
    }
    int32_t readyState = 0;
    if (!readyStateVal->Int32Value(context).To(&readyState)) {
        return;
    }
    if (readyState == 2 || readyState == 3) return;

    if (self->Set(context, String::NewFromUtf8(isolate, "readyState").ToLocalChecked(), Integer::New(isolate, 2)).IsNothing()) {
        return;
    }

    Local<Value> idVal;
    if (!self->Get(context, String::NewFromUtf8(isolate, "_id").ToLocalChecked()).ToLocal(&idVal)) {
        return;
    }
    int32_t wsId = 0;
    if (!idVal->Int32Value(context).To(&wsId)) {
        return;
    }

    int code = 1000;
    std::string reason = "";
    if (args.Length() >= 1 && args[0]->IsNumber()) {
        code = args[0]->Int32Value(context).FromJust();
    }
    if (args.Length() >= 2 && args[1]->IsString()) {
        String::Utf8Value r(isolate, args[1]);
        reason = *r;
    }
    auto it = WebSocketWrapper::active_websockets.find(wsId);
    if (it != WebSocketWrapper::active_websockets.end()) {
        it->second->close(code, reason);
    }
}
// Registers the native `WebSocket` class on the given global object template.
// The class is backed by WebSocketConstructor, exposes `send` and `close` on its
// prototype, and carries the CONNECTING/OPEN/CLOSING/CLOSED state constants as
// properties of the constructor function.
//
// The constants are declared on the function template rather than on a function
// instantiated from it: an instantiated function belongs to a single context, so
// constants set on it do not appear on the `WebSocket` function of a context
// created later from `global`, and instantiating requires an entered context.
void createWebSocketClass(Isolate* isolate, Local<ObjectTemplate>& global) {
    Local<FunctionTemplate> wsTpl = FunctionTemplate::New(isolate, WebSocketConstructor);
    wsTpl->SetClassName(String::NewFromUtf8(isolate, "WebSocket").ToLocalChecked());
    wsTpl->InstanceTemplate()->SetInternalFieldCount(1);

    Local<ObjectTemplate> proto = wsTpl->PrototypeTemplate();
    proto->Set(isolate, "send", FunctionTemplate::New(isolate, WebSocketSend));
    proto->Set(isolate, "close", FunctionTemplate::New(isolate, WebSocketClose));

    wsTpl->Set(isolate, "CONNECTING", Integer::New(isolate, 0));
    wsTpl->Set(isolate, "OPEN", Integer::New(isolate, 1));
    wsTpl->Set(isolate, "CLOSING", Integer::New(isolate, 2));
    wsTpl->Set(isolate, "CLOSED", Integer::New(isolate, 3));

    global->Set(isolate, "WebSocket", wsTpl);
}
static void EventConstructor(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
Local<Context> context = isolate->GetCurrentContext();
Local<Object> instance = args.This();
Local<String> type = args.Length() > 0 && args[0]->IsString()
? args[0].As<String>()
: String::NewFromUtf8(isolate, "").ToLocalChecked();
instance->Set(context, String::NewFromUtf8(isolate, "type").ToLocalChecked(), type);
instance->Set(context, String::NewFromUtf8(isolate, "target").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "currentTarget").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "timeStamp").ToLocalChecked(), Number::New(isolate, static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count())));
args.GetReturnValue().Set(instance);
}
// Constructor callback for the script-visible `MessageEvent` class.
// Sets `type` (args[0] when it is a string, otherwise "message"), null `target`
// and `currentTarget`, and `timeStamp` as system-clock milliseconds since the epoch.
// When args[1] is an object, `data` and `origin` are copied from it (an undefined
// origin becomes ""); otherwise `data` is null and `origin` is "".
//
// The init object comes from script, so reading its properties can invoke a
// getter that throws. In that case the JS exception is left pending and the
// constructor returns, instead of aborting the process.
static void MessageEventConstructor(const FunctionCallbackInfo<Value>& args) {
    Isolate* isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();
    Local<Object> instance = args.This();
    Local<String> type = args.Length() > 0 && args[0]->IsString()
        ? args[0].As<String>()
        : String::NewFromUtf8(isolate, "message").ToLocalChecked();
    instance->Set(context, String::NewFromUtf8(isolate, "type").ToLocalChecked(), type);
    instance->Set(context, String::NewFromUtf8(isolate, "target").ToLocalChecked(), Null(isolate));
    instance->Set(context, String::NewFromUtf8(isolate, "currentTarget").ToLocalChecked(), Null(isolate));
    instance->Set(context, String::NewFromUtf8(isolate, "timeStamp").ToLocalChecked(), Number::New(isolate, static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count())));
    if (args.Length() > 1 && args[1]->IsObject()) {
        Local<Object> init = args[1].As<Object>();
        Local<Value> data;
        if (!init->Get(context, String::NewFromUtf8(isolate, "data").ToLocalChecked()).ToLocal(&data)) {
            return;
        }
        instance->Set(context, String::NewFromUtf8(isolate, "data").ToLocalChecked(), data);
        Local<Value> origin;
        if (!init->Get(context, String::NewFromUtf8(isolate, "origin").ToLocalChecked()).ToLocal(&origin)) {
            return;
        }
        instance->Set(context, String::NewFromUtf8(isolate, "origin").ToLocalChecked(), origin->IsUndefined() ? String::NewFromUtf8(isolate, "").ToLocalChecked().As<Value>() : origin);
    } else {
        instance->Set(context, String::NewFromUtf8(isolate, "data").ToLocalChecked(), Null(isolate));
        instance->Set(context, String::NewFromUtf8(isolate, "origin").ToLocalChecked(), String::NewFromUtf8(isolate, "").ToLocalChecked());
    }
    args.GetReturnValue().Set(instance);
}
static void CloseEventConstructor(const FunctionCallbackInfo<Value>& args) {
Isolate* isolate = args.GetIsolate();
HandleScope scope(isolate);
Local<Context> context = isolate->GetCurrentContext();
Local<Object> instance = args.This();
Local<String> type = args.Length() > 0 && args[0]->IsString()
? args[0].As<String>()
: String::NewFromUtf8(isolate, "close").ToLocalChecked();
instance->Set(context, String::NewFromUtf8(isolate, "type").ToLocalChecked(), type);
instance->Set(context, String::NewFromUtf8(isolate, "target").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "currentTarget").ToLocalChecked(), Null(isolate));
instance->Set(context, String::NewFromUtf8(isolate, "timeStamp").ToLocalChecked(), Number::New(isolate, static_cast<double>(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count())));
int code = 1000;
std::string reason = "";
bool wasClean = false;
if (args.Length() > 1 && args[1]->IsObject()) {
Local<Object> init = args[1].As<Object>();
Local<Value> codeVal = init->Get(context, String::NewFromUtf8(isolate, "code").ToLocalChecked()).ToLocalChecked();
if (codeVal->IsNumber()) code = codeVal->Int32Value(context).FromJust();
Local<Value> reasonVal = init->Get(context, String::NewFromUtf8(isolate, "reason").ToLocalChecked()).ToLocalChecked();
if (reasonVal->IsString()) {
String::Utf8Value r(isolate, reasonVal);
reason = *r;
}
Local<Value> cleanVal = init->Get(context, String::NewFromUtf8(isolate, "wasClean").ToLocalChecked()).ToLocalChecked();
if (cleanVal->IsBoolean()) wasClean = cleanVal->BooleanValue(isolate);
}
instance->Set(context, String::NewFromUtf8(isolate, "code").ToLocalChecked(), Integer::New(isolate, code));
instance->Set(context, String::NewFromUtf8(isolate, "reason").ToLocalChecked(), String::NewFromUtf8(isolate, reason.c_str()).ToLocalChecked());
instance->Set(context, String::NewFromUtf8(isolate, "wasClean").ToLocalChecked(), Boolean::New(isolate, wasClean));
args.GetReturnValue().Set(instance);
}
// Registers the `Event`, `MessageEvent` and `CloseEvent` classes on the given
// global object template, each backed by its constructor callback in this file.
void createWebSocketEventClasses(Isolate* isolate, Local<ObjectTemplate>& global) {
    // builds one named function template and attaches it to `global`
    auto expose = [&](const char* class_name, FunctionCallback constructor) {
        Local<FunctionTemplate> tpl = FunctionTemplate::New(isolate, constructor);
        tpl->SetClassName(String::NewFromUtf8(isolate, class_name).ToLocalChecked());
        global->Set(isolate, class_name, tpl);
    };
    expose("Event", EventConstructor);
    expose("MessageEvent", MessageEventConstructor);
    expose("CloseEvent", CloseEventConstructor);
}
#endif