You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

546 lines
13 KiB
C++

/***************************************************************************
* Copyright (C) 2013 by gempa GmbH *
* *
* All Rights Reserved. *
* *
* NOTICE: All information contained herein is, and remains *
* the property of gempa GmbH and its suppliers, if any. The intellectual *
* and technical concepts contained herein are proprietary to gempa GmbH *
* and its suppliers. *
* Dissemination of this information or reproduction of this material *
* is strictly forbidden unless prior written permission is obtained *
* from gempa GmbH. *
***************************************************************************/
#include <gempa/caps/socket.h>
#include <gempa/caps/log.h>
#include <fcntl.h>
#include <sys/types.h>
#ifndef WIN32
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#else
#ifndef SHUT_RDWR
#define SHUT_RDWR SD_BOTH
#endif
#define _WIN32_WINNT 0x0501 // Older versions do not support getaddrinfo
#include <io.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#endif
#include <cerrno>
#include <cstring>
#include <sstream>
using namespace std;
namespace {

// Renders a 16 bit unsigned value (used for port numbers) as decimal text
// by streaming it into a string buffer.
inline std::string toString(unsigned short value) {
	std::ostringstream out;
	out << value;
	return out.str();
}

}
namespace Gempa {
namespace CAPS {
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Creates a socket object without a descriptor: the descriptor is marked
// invalid (-1) and the traffic counters and the stored timeout are zeroed.
// On Windows the Winsock library is started in addition.
Socket::Socket() {
	_fd = -1;

	_bytesReceived = 0;
	_bytesSent = 0;

	_timeOutUsecs = 0;
	_timeOutSecs = 0;

#ifdef WIN32
	// Request Winsock 2.0; the result of the startup call is not evaluated
	WSADATA wsaData;
	WSAStartup(MAKEWORD(2,0), &wsaData);
#endif
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Releases the descriptor if one is still open. On Windows the Winsock
// startup performed by the constructor is balanced with a cleanup call.
Socket::~Socket() {
	this->close();

#ifdef WIN32
	WSACleanup();
#endif
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Translates a status code into a short, statically allocated description.
// Every code without a dedicated text, including Error itself, is reported
// as "error".
const char *Socket::toString(Status stat) {
	switch ( stat ) {
		case Success:              return "success";
		case AllocationError:      return "allocation error";
		case ReuseAdressError:     return "reusing address failed";
		case BindError:            return "bind error";
		case ListenError:          return "listen error";
		case AcceptError:          return "accept error";
		case ConnectError:         return "connect error";
		case AddrInfoError:        return "address info error";
		case Timeout:              return "timeout";
		case InvalidSocket:        return "invalid socket";
		case InvalidPort:          return "invalid port";
		case InvalidAddressFamily: return "invalid address family";
		case InvalidAddress:       return "invalid address";
		case InvalidHostname:      return "invalid hostname";
		case Error:
		default:
			break;
	}

	// Generic fallback shared by Error and all unknown codes
	return "error";
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Shuts down both directions of the connection. The descriptor itself stays
// open; nothing happens if the socket has no descriptor. The result of the
// system call is not evaluated.
void Socket::shutdown() {
	if ( _fd != -1 ) {
		::shutdown(_fd, SHUT_RDWR);
	}
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
void Socket::close() {
if ( _fd != -1 ) {
//CAPS_DEBUG("[socket] close %lX with fd = %d", (long int)this, _fd);
int fd = _fd;
_fd = -1;
::close(fd);
}
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Stores the timeout and, if a descriptor exists already, applies it
// immediately.
// @param secs  Seconds part of the timeout
// @param usecs Microseconds part of the timeout
// @return Success if only stored, otherwise the result of
//         applySocketTimeout()
Socket::Status Socket::setSocketTimeout(int secs, int usecs) {
	_timeOutSecs = secs;
	_timeOutUsecs = usecs;

	// Without a descriptor there is nothing to configure yet
	if ( _fd == -1 )
		return Success;

	return applySocketTimeout(_timeOutSecs, _timeOutUsecs);
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Configures the receive and the send timeout of the current descriptor.
// @param secs  Seconds part of the timeout. For negative values no timeval
//              is passed at all (NULL option with length 0).
// @param usecs Microseconds part of the timeout
// @return InvalidSocket without a descriptor, Error if one of the two
//         setsockopt calls fails, Success otherwise
// NOTE(review): whether setsockopt accepts a NULL option value for the
// negative case is platform dependent - confirm the intended behaviour.
Socket::Status Socket::applySocketTimeout(int secs, int usecs) {
	if ( _fd == -1 )
		return InvalidSocket;

	struct timeval tv;
	void *value = NULL;
	int valueLen = 0;

	if ( secs >= 0 ) {
		tv.tv_sec = secs;
		tv.tv_usec = usecs;
		value = &tv;
		valueLen = sizeof(tv);
	}

	CAPS_DEBUG("set socket timeout to %d.%ds", secs, usecs);

	// Receive timeout first, then send timeout; stop at the first failure
	const int options[] = { SO_RCVTIMEO, SO_SNDTIMEO };
	for ( size_t i = 0; i < sizeof(options) / sizeof(options[0]); ++i ) {
		if ( setsockopt(_fd, SOL_SOCKET, options[i], value, valueLen) )
			return Error;
	}

	return Success;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Switches the descriptor between blocking and non-blocking mode.
// @param nb true enables non-blocking mode, false restores blocking mode
// @return Device::InvalidDevice if the socket has no descriptor,
//         Device::Error if the mode could not be read or changed,
//         Device::Success otherwise
Socket::Device::Status Socket::setNonBlocking(bool nb) {
	if ( !isValid() )
		return Device::InvalidDevice;

#ifndef WIN32
	int flags = fcntl(_fd, F_GETFL, 0);
	// A failed query yields -1; modifying and writing that value back would
	// set arbitrary status flags on the descriptor
	if ( flags == -1 )
		return Device::Error;

	if ( nb )
		flags |= O_NONBLOCK;
	else
		flags &= ~O_NONBLOCK;

	if ( fcntl(_fd, F_SETFL, flags) == -1 )
		return Device::Error;
#else
	u_long arg = nb?1:0;
	if ( ioctlsocket(_fd, FIONBIO, &arg) != 0 )
		return Device::Error;
#endif

	return Device::Success;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Resolves the host name (IPv4 only), creates a TCP socket and connects it.
// A descriptor left over from a previous connection is closed first.
// @param hostname Host name or address string passed to getaddrinfo
// @param port     TCP port
// @return AddrInfoError if the name cannot be resolved, AllocationError if
//         no socket could be created, Timeout (non-Windows only) or
//         ConnectError if connecting fails, Success otherwise. A connect
//         that is still in progress is reported as Success.
// NOTE(review): the timeout stored by setSocketTimeout() is not applied to
// the new descriptor here - confirm callers apply it after connecting.
Socket::Status Socket::connect(const std::string &hostname, uint16_t port) {
	if ( _fd != -1 ) {
		//CAPS_WARNING("closing stale socket");
		close();
	}

	struct addrinfo hints;
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_INET;
	hints.ai_socktype = SOCK_STREAM;

	struct addrinfo *res = NULL;
	std::string strPort = ::toString(port);

	int ret = getaddrinfo(hostname.c_str(), strPort.c_str(), &hints, &res);
	if ( ret ) {
		// getaddrinfo reports failures through its return value; errno is
		// only meaningful for EAI_SYSTEM
#ifndef WIN32
		CAPS_DEBUG("Test3 Socket::connect(%s:%d): %s",
		           hostname.c_str(), port,
		           ret == EAI_SYSTEM ? strerror(errno) : gai_strerror(ret));
#else
		CAPS_DEBUG("Test3 Socket::connect(%s:%d): %s",
		           hostname.c_str(), port, gai_strerror(ret));
#endif
		return AddrInfoError;
	}

	_fd = socket(PF_INET, SOCK_STREAM, 0);
	if ( _fd < 0 ) {
		/*CAPS_DEBUG("Socket::connect(%s:%d): %s",
		           hostname.c_str(), port, strerror(errno));*/
		_fd = -1;
		freeaddrinfo(res);
		return AllocationError;
	}

	// Connect using the resolver result directly rather than a copy in a
	// fixed size sockaddr, and fetch the error code right after the call
	// because freeaddrinfo() and close() may overwrite it
#ifndef WIN32
	int rc = ::connect(_fd, res->ai_addr, res->ai_addrlen);
	int err = errno;
	freeaddrinfo(res);

	if ( rc == -1 && err != EINPROGRESS ) {
		/*CAPS_DEBUG("Socket::connect(%s:%d): %s",
		           hostname.c_str(), port, strerror(err));*/
		close();
		return err == ETIMEDOUT?Timeout:ConnectError;
	}
#else
	int rc = ::connect(_fd, res->ai_addr, static_cast<int>(res->ai_addrlen));
	int err = WSAGetLastError();
	freeaddrinfo(res);

	if ( rc == SOCKET_ERROR && err != WSAEINPROGRESS && err != WSAEWOULDBLOCK ) {
		CAPS_DEBUG("Socket::connect(%s:%d): %s",
		           hostname.c_str(), port, gai_strerror(err));
		close();
		return ConnectError;
	}
#endif

	return Success;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Sends a zero terminated string (without the terminator) through write().
// @param data Zero terminated text to send
// @return The result of write()
int Socket::send(const char *data) {
	const size_t length = strlen(data);
	return write(data, length);
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Sends up to len bytes and adds the number of bytes actually sent to the
// sent-bytes counter.
// @param data Buffer to send
// @param len  Number of bytes in the buffer
// @return The result of the underlying send call
int Socket::write(const char *data, int len) {
	// MSG_NOSIGNAL is only passed where neither MACOSX nor WIN32 is defined
#if !defined(MACOSX) && !defined(WIN32)
	const int flags = MSG_NOSIGNAL;
#else
	const int flags = 0;
#endif

	const int sent = (int)::send(_fd, data, len, flags);
	if ( sent > 0 )
		_bytesSent += sent;

	return sent;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Receives up to len bytes and adds the number of bytes actually received
// to the received-bytes counter.
// @param data Destination buffer
// @param len  Capacity of the destination buffer in bytes
// @return The result of the underlying recv call
int Socket::read(char *data, int len) {
	const int received = (int)::recv(_fd, data, len, 0);
	if ( received <= 0 )
		return received;

	_bytesReceived += received;
	return received;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Performs no work in this implementation and always returns 1.
int Socket::flush() {
	return 1;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Tells whether the socket currently holds a descriptor (-1 marks "none").
bool Socket::isValid() {
	if ( _fd == -1 )
		return false;

	return true;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
#if !defined(CAPS_FEATURES_SSL) || CAPS_FEATURES_SSL
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Creates an SSL socket without SSL object and without SSL context.
SSLSocket::SSLSocket()
: _ssl(NULL)
, _ctx(NULL) {}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Creates an SSL socket that uses the given SSL context. The context is
// released with SSL_CTX_free by cleanUp(), which runs in the destructor and
// at the start of connect().
SSLSocket::SSLSocket(SSL_CTX *ctx)
: _ssl(NULL)
, _ctx(ctx) {}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Closes the connection first and then frees the SSL object and context.
SSLSocket::~SSLSocket() {
	this->close();
	this->cleanUp();
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Writes up to len bytes through the SSL connection and adds the number of
// bytes written to the sent-bytes counter.
// @param data Buffer to send
// @param len  Number of bytes in the buffer
// @return Number of bytes written (> 0) on success.
//         -1 with errno EBADF if no SSL object exists.
//         -1 with errno EAGAIN if the operation has to be repeated.
//         0 with errno EINVAL if the peer closed the SSL connection.
//         Otherwise the unmodified result of SSL_write.
int SSLSocket::write(const char *data, int len) {
	// Without an SSL object (not connected yet or setup failed) there is
	// nothing to write to; passing NULL to OpenSSL must be avoided
	if ( _ssl == NULL ) {
		errno = EBADF;
		return -1;
	}

	int ret = SSL_write(_ssl, data, len);
	if ( ret > 0 ) {
		_bytesSent += ret;
		return ret;
	}

	switch ( SSL_get_error(_ssl, ret) ) {
		// All "try again" conditions are mapped to the same errno
		case SSL_ERROR_WANT_X509_LOOKUP:
		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			errno = EAGAIN;
			return -1;
		case SSL_ERROR_ZERO_RETURN:
			errno = EINVAL;
			return 0;
		default:
			break;
	}

	return ret;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Reads up to len bytes from the SSL connection and adds the number of
// bytes read to the received-bytes counter.
// @param data Destination buffer
// @param len  Capacity of the destination buffer in bytes
// @return Number of bytes read (> 0) on success.
//         -1 with errno EBADF if no SSL object exists.
//         -1 with errno EAGAIN if the operation has to be repeated.
//         0 with errno EINVAL if the peer closed the SSL connection.
//         Otherwise the unmodified result of SSL_read.
int SSLSocket::read(char *data, int len) {
	// Without an SSL object (not connected yet or setup failed) there is
	// nothing to read from; passing NULL to OpenSSL must be avoided
	if ( _ssl == NULL ) {
		errno = EBADF;
		return -1;
	}

	int ret = SSL_read(_ssl, data, len);
	if ( ret > 0 ) {
		_bytesReceived += ret;
		return ret;
	}

	switch ( SSL_get_error(_ssl, ret) ) {
		// All "try again" conditions are mapped to the same errno
		case SSL_ERROR_WANT_X509_LOOKUP:
		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			errno = EAGAIN;
			return -1;
		case SSL_ERROR_ZERO_RETURN:
			errno = EINVAL;
			return 0;
		default:
			break;
	}

	return ret;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Establishes a TCP connection and performs the SSL client handshake on it.
// Any SSL object and context from a previous connection (or passed to the
// constructor) are released first and a new client context is created.
// @param hostname Host name or address string
// @param port     TCP port
// @return The status of Socket::connect() if the TCP connection fails,
//         ConnectError if the SSL setup or handshake fails, Success
//         otherwise. After a failed SSL setup or handshake the descriptor
//         is closed.
// NOTE(review): no peer verification is configured here; presumably callers
// inspect peerCertificate() themselves - confirm.
Socket::Status SSLSocket::connect(const std::string &hostname, uint16_t port) {
	cleanUp();

	_ctx = SSL_CTX_new(SSLv23_client_method());
	if ( _ctx == NULL ) {
		CAPS_DEBUG("Invalid SSL context");
		return ConnectError;
	}

	SSL_CTX_set_mode(_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);

	Status s = Socket::connect(hostname, port);
	if ( s != Success )
		return s;

	_ssl = SSL_new(_ctx);
	if ( _ssl == NULL ) {
		CAPS_DEBUG("Failed to create SSL context");
		// Do not leave an open descriptor behind that reports the socket
		// as valid although no SSL object exists
		close();
		return ConnectError;
	}

	if ( SSL_set_fd(_ssl, _fd) != 1 ) {
		CAPS_DEBUG("Failed to attach socket to SSL object");
		close();
		return ConnectError;
	}

	SSL_set_shutdown(_ssl, 0);
	SSL_set_connect_state(_ssl);

	// SSL_connect returns 1 only for a completed handshake; 0 signals a
	// handshake that failed with a controlled shutdown, negative values a
	// fatal error. Both are failures.
	int err = SSL_connect(_ssl);
	if ( err <= 0 ) {
		CAPS_ERROR("Failed to connect with SSL, error %d",
		           SSL_get_error(_ssl, err));
		close();
		return ConnectError;
	}

	return Success;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Returns a pointer to the ID of the current SSL session. The memory belongs
// to the session object and must not be freed by the caller.
// @return The session ID bytes, or NULL if there is no SSL object or no
//         session. The number of valid bytes is given by sessionIDLength().
const unsigned char *SSLSocket::sessionID() const {
	if ( _ssl == NULL ) return NULL;

#if OPENSSL_VERSION_NUMBER < 0x10100000L
	return _ssl->session?_ssl->session->session_id:NULL;
#else
	// SSL_SESSION_get_id yields the session ID itself, matching the legacy
	// branch; SSL_SESSION_get0_id_context would return the ID *context*
	const SSL_SESSION *session = SSL_get0_session(_ssl);
	return session?SSL_SESSION_get_id(session, NULL):NULL;
#endif
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Returns the length in bytes of the ID of the current SSL session.
// @return The session ID length, or 0 if there is no SSL object or no
//         session
unsigned int SSLSocket::sessionIDLength() const {
	if ( _ssl == NULL ) return 0;

#if OPENSSL_VERSION_NUMBER < 0x10100000L
	return _ssl->session?_ssl->session->session_id_length:0;
#else
	const SSL_SESSION *session = SSL_get0_session(_ssl);
	if ( session == NULL ) return 0;

	unsigned int len = 0;
	SSL_SESSION_get_id(session, &len);
	return len;
#endif
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Returns the certificate presented by the peer as reported by
// SSL_get_peer_certificate, or NULL if no SSL object exists.
// NOTE(review): per the OpenSSL documentation the returned certificate holds
// its own reference which the caller releases with X509_free - confirm that
// callers do so.
X509 *SSLSocket::peerCertificate() {
	return _ssl != NULL ? SSL_get_peer_certificate(_ssl) : NULL;
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// Frees the SSL object and the SSL context, in that order, and resets both
// pointers so that the function can be called repeatedly.
void SSLSocket::cleanUp() {
	if ( _ssl != NULL ) {
		SSL_free(_ssl);
		_ssl = NULL;
	}

	if ( _ctx != NULL ) {
		SSL_CTX_free(_ctx);
		_ctx = NULL;
	}
}
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
#endif
// >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
}
}