// Files: LNXRNT/Sources/connection_pool.cpp
// 2026-02-20 23:40:15 -08:00
//
// 182 lines, 4.9 KiB, C++

#include "connection_pool.h"
#include "socket_optimization.h"
#include <kinc/log.h>
#ifdef _WIN32
#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
#else
#include <unistd.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#endif
// Global ConnectionPool instance, defined at namespace scope in this
// translation unit. Its destructor (below) closes any sockets still pooled.
ConnectionPool g_connection_pool;
/// Tears the pool down: every socket still queued under any key is closed.
/// The pool mutex is held for the whole teardown.
ConnectionPool::~ConnectionPool() {
    std::lock_guard<std::mutex> guard(pool_mutex_);
    for (auto& entry : pool_) {
        auto& idle = entry.second;
        // Drain the queue front-to-back, closing each socket as it is removed.
        for (; !idle.empty(); idle.pop()) {
#ifdef _WIN32
            closesocket(idle.front().socket);
#else
            close(idle.front().socket);
#endif
        }
    }
}
/// Returns a connected TCP socket for host:port, reusing a pooled one if possible.
///
/// Pooled sockets stored under the (host, port, is_ssl) key are tried first;
/// each one that fails the liveness probe is closed and the next is tried.
/// If none is usable, the host is resolved and a new IPv4/TCP connection is
/// opened, trying every resolved address in order.
///
/// @param host    Host name or address to connect to.
/// @param port    TCP port; must be within 0..65535.
/// @param is_ssl  Only used to select the pool key; no TLS setup happens here.
/// @return A connected socket that is no longer tracked by the pool, or
///         INVALID_SOCKET on failure (the reason is logged).
socket_t ConnectionPool::acquire_connection(const std::string& host, int port, bool is_ssl) {
    const auto close_socket = [](socket_t s) {
#ifdef _WIN32
        closesocket(s);
#else
        close(s);
#endif
    };

    const std::string key = make_key(host, port, is_ssl);
    {
        std::lock_guard<std::mutex> lock(pool_mutex_);
        auto it = pool_.find(key);
        if (it != pool_.end()) {
            auto& idle = it->second;
            // Keep draining until a live socket turns up, so dead sockets do
            // not linger in the queue and force an unnecessary new connection.
            while (!idle.empty()) {
                const socket_t candidate = idle.front().socket;
                idle.pop();
                if (is_socket_alive(candidate)) {
                    kinc_log(KINC_LOG_LEVEL_INFO, "Reusing pooled connection for %s:%d", host.c_str(), port);
                    return candidate;
                }
                close_socket(candidate);
            }
        }
    }

    // htons() would silently truncate an out-of-range port to 16 bits and
    // connect to the wrong service, so reject it explicitly.
    if (port < 0 || port > 65535) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Invalid port %d for host %s", port, host.c_str());
        return INVALID_SOCKET;
    }

    // getaddrinfo() replaces gethostbyname(): the latter returns a pointer to
    // shared static storage, which is unsafe here because resolution runs
    // outside pool_mutex_ and may happen concurrently on several threads.
    addrinfo hints{};
    hints.ai_family = AF_INET;  // same address family the pool has always used
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;

    addrinfo* results = nullptr;
    const std::string service = std::to_string(port);
    if (getaddrinfo(host.c_str(), service.c_str(), &hints, &results) != 0 || results == nullptr) {
        kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to resolve hostname %s", host.c_str());
        return INVALID_SOCKET;
    }

    socket_t connected = INVALID_SOCKET;
    bool socket_created = false;
    for (const addrinfo* addr = results; addr != nullptr; addr = addr->ai_next) {
        // A fresh socket per attempt: a socket's state after a failed
        // connect() is not something to rely on.
        const socket_t sock = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
        if (sock == INVALID_SOCKET) {
            continue;
        }
        socket_created = true;
        SocketOptimization::optimizeHttp(static_cast<int>(sock));
        if (::connect(sock, addr->ai_addr, static_cast<int>(addr->ai_addrlen)) != SOCKET_ERROR) {
            connected = sock;
            break;
        }
        close_socket(sock);
    }
    freeaddrinfo(results);

    if (connected == INVALID_SOCKET) {
        if (!socket_created) {
            kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to create socket for %s:%d", host.c_str(), port);
        } else {
            kinc_log(KINC_LOG_LEVEL_ERROR, "Failed to connect to %s:%d", host.c_str(), port);
        }
        return INVALID_SOCKET;
    }

    kinc_log(KINC_LOG_LEVEL_INFO, "Created new connection for %s:%d", host.c_str(), port);
    return connected;
}
void ConnectionPool::release_connection(const std::string& host, int port, socket_t socket, bool is_ssl) {
if (socket == INVALID_SOCKET) return;
std::string key = make_key(host, port, is_ssl);
std::lock_guard<std::mutex> lock(pool_mutex_);
auto& queue = pool_[key];
if (queue.size() >= MAX_CONNECTIONS_PER_HOST) {
#ifdef _WIN32
closesocket(socket);
#else
close(socket);
#endif
return;
}
PooledConnection conn;
conn.socket = socket;
conn.last_used = std::chrono::steady_clock::now();
conn.is_ssl = is_ssl;
queue.push(conn);
kinc_log(KINC_LOG_LEVEL_INFO, "Pooled connection for %s:%d", host.c_str(), port);
}
/// Removes and closes pooled sockets that are too old or no longer alive.
///
/// An entry survives only if it was last used less than
/// CONNECTION_TIMEOUT_SECONDS (whole seconds) ago and passes the liveness
/// probe. Surviving entries keep their relative order. The pool mutex is held
/// for the whole sweep.
void ConnectionPool::cleanup_expired() {
    std::lock_guard<std::mutex> guard(pool_mutex_);
    const auto sweep_time = std::chrono::steady_clock::now();

    for (auto& entry : pool_) {
        auto& idle = entry.second;
        // Rotate through the queue exactly once: each entry is taken from the
        // front and, if still good, appended at the back again.
        for (size_t remaining = idle.size(); remaining > 0; --remaining) {
            auto candidate = idle.front();
            idle.pop();

            const auto idle_for =
                std::chrono::duration_cast<std::chrono::seconds>(sweep_time - candidate.last_used);
            // The liveness probe only runs for entries that have not timed out.
            const bool keep =
                idle_for.count() < CONNECTION_TIMEOUT_SECONDS && is_socket_alive(candidate.socket);
            if (keep) {
                idle.push(candidate);
                continue;
            }
#ifdef _WIN32
            closesocket(candidate.socket);
#else
            close(candidate.socket);
#endif
        }
    }
}
size_t ConnectionPool::get_pool_size() const {
std::lock_guard<std::mutex> lock(pool_mutex_);
size_t total = 0;
for (const auto& host_pair : pool_) {
total += host_pair.second.size();
}
return total;
}
/// Builds the pool lookup key: "<host>:<port>", with ":ssl" appended when
/// is_ssl is true.
std::string ConnectionPool::make_key(const std::string& host, int port, bool is_ssl) const {
    std::string key(host);
    key += ":";
    key += std::to_string(port);
    if (is_ssl) {
        key += ":ssl";
    }
    return key;
}
/// Probes whether a pooled socket still looks usable, without consuming data.
///
/// Peeks one byte from the socket:
///   - > 0 : data is pending, the connection is open            -> alive
///   - 0   : the peer performed an orderly shutdown             -> dead
///   - < 0 : alive only if the error is "would block" (nothing
///           to read yet); any other error                      -> dead
///
/// @param socket  Socket to probe.
/// @return true if the socket appears to be still connected.
bool ConnectionPool::is_socket_alive(socket_t socket) const {
    char probe = 0;

#if !defined(_WIN32) && defined(MSG_DONTWAIT)
    // Never block here: callers hold pool_mutex_ while probing, and an idle
    // blocking socket would otherwise stall inside recv() indefinitely.
    const int peek_flags = MSG_PEEK | MSG_DONTWAIT;
#else
    // NOTE(review): without MSG_DONTWAIT this relies on the socket already
    // being in non-blocking mode - confirm against SocketOptimization.
    const int peek_flags = MSG_PEEK;
#endif

    const auto received = ::recv(socket, &probe, 1, peek_flags);
    if (received > 0) {
        return true;
    }
    if (received == 0) {
        // A zero-length read on a stream socket means the peer closed it.
        return false;
    }

#ifdef _WIN32
    return WSAGetLastError() == WSAEWOULDBLOCK;
#else
    return errno == EWOULDBLOCK || errno == EAGAIN;
#endif
}