00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
#include "SSLSocket.h"
#include "CODEX_Quorum/RemoteServer.h"
#include <sys/errno.h>
#include <openssl/err.h>
#include <climits>
#include <iostream>
#include <sstream>

#include "timing.h"

using namespace CODEX_SSL;
using namespace CODEX_Quorum;
00030
00031 SSLSocket::SSLSocket( SSL_CTX* ctx ,
00032 int domain ,
00033 int type ,
00034 int protocol ,
00035 bool blocking ) :
00036 SocketBase(domain, type, protocol, blocking),
00037 m_ctx( ctx ),
00038 m_ssl_con( 0 ),
00039 m_needRead( false ),
00040 m_needWrite( false )
00041 {
00042 }
00043
00044 SSLSocket::SSLSocket( const SSLSocket& aOther ) :
00045 SocketBase( aOther ),
00046 m_ctx( aOther.m_ctx ),
00047 m_ssl_con( 0 ),
00048 m_needRead( false ),
00049 m_needWrite( false )
00050 {
00051 }
00052
00053 SSLSocket::~SSLSocket()
00054 {
00055 if ( 0 != m_ssl_con )
00056 {
00057 SSL_set_shutdown( m_ssl_con, SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN );
00058 SSL_shutdown( m_ssl_con );
00059
00060 SSL_free( m_ssl_con );
00061 }
00062 }
00063
00064 int
00065 SSLSocket::set_fd( fd_set* fd_bitmap, StateType s ) const
00066 {
00067 bool ok = true;
00068 switch(s)
00069 {
00070 case SocketBase::kRead :
00071 if ( m_needWrite )
00072 {
00073 ok = false;
00074 }
00075 break;
00076 case SocketBase::kWrite :
00077 if ( m_needRead )
00078 {
00079 ok = false;
00080 }
00081 break;
00082 }
00083 if ( ok )
00084 {
00085 return SocketBase::set_fd( fd_bitmap, s );
00086 }
00087 return socket();
00088 }
00089
00090 bool
00091 SSLSocket::isset_fd( const fd_set* fd_bitmap, StateType s ) const
00092 {
00093 if ( m_needWrite && ( SocketBase::kRead == s ) )
00094 {
00095 return false;
00096 }
00097 if ( m_needRead && ( SocketBase::kWrite == s ) )
00098 {
00099 return false;
00100 }
00101
00102
00103
00104 if ( ( SocketBase::kRead == s ) &&
00105 ( 0 != m_ssl_con ) &&
00106 ( SSL_pending( m_ssl_con ) > 0 ) )
00107 {
00108 return true;
00109 }
00110 return SocketBase::isset_fd( fd_bitmap, s );
00111 }
00112
00113 size_t
00114 SSLSocket::readFrom( void* output, size_t maxSize ) const
00115 {
00116 if ( m_needWrite )
00117 {
00118
00119 return 0;
00120 }
00121 if ( 0 == m_ssl_con )
00122 {
00123 return 0;
00124 }
00125 #ifdef TIMING
00126 SSLTimer.start();
00127 #endif
00128 m_needRead = false;
00129 int bytesRead = SSL_read( m_ssl_con, output, maxSize );
00130 switch ( SSL_get_error( m_ssl_con, bytesRead ) )
00131 {
00132 case SSL_ERROR_NONE :
00133 break;
00134 case SSL_ERROR_ZERO_RETURN :
00135 #ifdef TIMING
00136 SSLTimer.stop();
00137 #endif
00138 throw CODEX_Quorum::QSESocketBaseSocketClosed( __FILE__ , __LINE__ ,
00139 socket(), 0 );
00140 break;
00141 case SSL_ERROR_WANT_READ :
00142 case SSL_ERROR_WANT_WRITE :
00143 case SSL_ERROR_WANT_CONNECT :
00144 case SSL_ERROR_WANT_X509_LOOKUP :
00145
00146 m_needRead = true;
00147 bytesRead = 0;
00148 break;
00149 case SSL_ERROR_SYSCALL :
00150 case SSL_ERROR_SSL :
00151 unsigned long err = ERR_get_error();
00152 #ifdef TIMING
00153 SSLTimer.stop();
00154 #endif
00155 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00156 break;
00157 }
00158 #ifdef TIMING
00159 SSLTimer.stop();
00160 #endif
00161 return bytesRead;
00162 }
00163
00164 int
00165 SSLSocket::internal_write( const unsigned char* output, size_t maxSize ) const
00166 {
00167 if ( m_needRead )
00168 {
00169
00170 return 0;
00171 }
00172 if ( 0 == m_ssl_con )
00173 {
00174 return 0;
00175 }
00176 #ifdef TIMING
00177 SSLTimer.start();
00178 #endif
00179 m_needWrite = false;
00180 int bytesWritten = SSL_write( m_ssl_con, output, maxSize );
00181 switch ( SSL_get_error( m_ssl_con, bytesWritten ) )
00182 {
00183 case SSL_ERROR_NONE :
00184 break;
00185 case SSL_ERROR_ZERO_RETURN :
00186 #ifdef TIMING
00187 SSLTimer.stop();
00188 #endif
00189 throw CODEX_Quorum::QSESocketBaseSocketClosed( __FILE__ , __LINE__ ,
00190 socket(), 0 );
00191 break;
00192 case SSL_ERROR_WANT_READ :
00193 case SSL_ERROR_WANT_WRITE :
00194 case SSL_ERROR_WANT_CONNECT :
00195 case SSL_ERROR_WANT_X509_LOOKUP :
00196
00197 m_needWrite = true;
00198 bytesWritten = 0;
00199 break;
00200 case SSL_ERROR_SYSCALL :
00201 case SSL_ERROR_SSL :
00202 unsigned long err = ERR_get_error();
00203 #ifdef TIMING
00204 SSLTimer.stop();
00205 #endif
00206 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00207 break;
00208 }
00209 #ifdef TIMING
00210 SSLTimer.stop();
00211 #endif
00212 return bytesWritten;
00213 }
00214
00215 SocketBase*
00216 SSLSocket::clone()
00217 {
00218 SocketBase* copy = new SSLSocket( *this );
00219 return copy;
00220 }
00221
00222 void
00223 SSLSocket::connect( const RemoteServer& server )
00224 {
00225 if ( 0 != m_ssl_con )
00226 {
00227
00228
00229 SSL_free( m_ssl_con );
00230 m_ssl_con = 0;
00231 }
00232 try
00233 {
00234
00235 SocketBase::connect( server );
00236
00237 m_ssl_con = SSL_new( m_ctx );
00238 if ( 0 == m_ssl_con )
00239 {
00240 unsigned long err = ERR_get_error();
00241 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00242 }
00243 SSL_set_fd( m_ssl_con, socket() );
00244
00245 int c;
00246 do
00247 {
00248 c = SSL_connect(m_ssl_con);
00249 if ( c <= 0 )
00250 {
00251 switch ( SSL_get_error( m_ssl_con, c ) )
00252 {
00253 case SSL_ERROR_WANT_READ :
00254 case SSL_ERROR_WANT_WRITE :
00255 break;
00256 default :
00257 unsigned long err = ERR_get_error();
00258 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00259 }
00260 }
00261 } while ( c <= 0 );
00262 }
00263 catch ( QSESSLSocket& e )
00264 {
00265 e.report();
00266 if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00267 m_ssl_con = 0;
00268 throw;
00269 }
00270 catch ( ... )
00271 {
00272 if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00273 m_ssl_con = 0;
00274 throw;
00275 }
00276 }
00277
00278 void
00279 SSLSocket::finish_accept()
00280 {
00281 if ( 0 != m_ssl_con )
00282 {
00283
00284
00285 SSL_free( m_ssl_con );
00286 m_ssl_con = 0;
00287 }
00288 try
00289 {
00290 m_ssl_con = (SSL*) SSL_new(m_ctx);
00291 if ( 0 == m_ssl_con )
00292 {
00293 unsigned long err = ERR_get_error();
00294 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00295 }
00296 SSL_clear(m_ssl_con);
00297 if ( ! SSL_set_fd( m_ssl_con, socket() ) )
00298 {
00299 unsigned long err = ERR_get_error();
00300 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00301 }
00302 SSL_set_accept_state(m_ssl_con);
00303
00304
00305 int a;
00306 do
00307 {
00308 a = SSL_accept(m_ssl_con);
00309 if ( a <= 0 )
00310 {
00311 switch ( SSL_get_error( m_ssl_con, a ) )
00312 {
00313 case SSL_ERROR_WANT_READ :
00314 case SSL_ERROR_WANT_WRITE :
00315 break;
00316 default :
00317 unsigned long err = ERR_get_error();
00318 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00319 }
00320 }
00321 } while ( a <= 0 );
00322 }
00323 catch ( QSESSLSocket& e )
00324 {
00325 e.report();
00326 if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00327 m_ssl_con = 0;
00328 throw;
00329 }
00330 catch ( ... )
00331 {
00332 if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00333 m_ssl_con = 0;
00334 throw;
00335 }
00336 }
00337
00340 SSLSocketBuilder::SSLSocketBuilder( SSL_METHOD* meth,
00341 const X509* cert,
00342 const RSA* privKey,
00343 const char* ciphers,
00344 const char* caCertFile,
00345 const char* hostCertFile,
00346 int verify,
00347 int domain,
00348 int type,
00349 int protocol,
00350 bool blocking ) :
00351 CODEX_Quorum::SocketBuilder( domain, type, protocol, blocking )
00352 {
00353 try
00354 {
00355 m_ctx = SSL_CTX_new(meth);
00356 if ( 0 == m_ctx )
00357 {
00358 throw SSLNullContextException( __FILE__ , __LINE__ );
00359 }
00360
00361 if ( ! SSL_CTX_use_certificate( m_ctx, (X509*)cert ) )
00362 {
00363 unsigned long err = ERR_get_error();
00364 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00365 }
00366 if ( ! SSL_CTX_use_RSAPrivateKey( m_ctx, (RSA*)privKey ) )
00367
00368 {
00369 unsigned long err = ERR_get_error();
00370 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00371 }
00372 if ( ! SSL_CTX_check_private_key( m_ctx ) )
00373 {
00374 unsigned long err = ERR_get_error();
00375 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00376 }
00377 if ( 0 == ciphers )
00378 {
00379 throw SSLNullCiphersException( __FILE__ , __LINE__ );
00380 }
00381 SSL_CTX_set_cipher_list( m_ctx, ciphers );
00382 if ( ! SSL_CTX_load_verify_locations( m_ctx, caCertFile, 0 ) )
00383 {
00384 unsigned long err = ERR_get_error();
00385 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00386 }
00387 if ( ! SSL_CTX_set_default_verify_paths( m_ctx ) )
00388 {
00389 unsigned long err = ERR_get_error();
00390 throw QSESSLSocket( __FILE__ , __LINE__ , err );
00391 }
00392 if ( 0 == verify )
00393 {
00394 throw SSLVerificationFlagsException( __FILE__ , __LINE__ );
00395 }
00396 SSL_CTX_set_verify( m_ctx, verify, 0 );
00397 SSL_CTX_set_client_CA_list( m_ctx,
00398 SSL_load_client_CA_file(hostCertFile) );
00399 }
00400 catch ( ... )
00401 {
00402 if ( 0 != m_ctx ) SSL_CTX_free( m_ctx );
00403 throw;
00404 }
00405 }
00406
00407 SSLSocketBuilder::~SSLSocketBuilder()
00408 {
00409 if ( 0 != m_ctx ) SSL_CTX_free( m_ctx );
00410 }
00411
00412 SocketBase*
00413 SSLSocketBuilder::operator()() const
00414 {
00415 return new SSLSocket( m_ctx, m_domain, m_type, m_protocol, m_blocking );
00416 }
00417
00418
00419
00420
00421 void
00422 QSESSLSocket::errMsg() const
00423 {
00424 cerr << "SSL error:\n "
00425 << ERR_error_string(error(),0);
00426 }
00427
00428 void
00429 SSLExceptionBase::report() const
00430 {
00431 cerr << "At line " << m_line
00432 << " of file " << m_fname
00433 << "\nSSL exception: ";
00434 derivedMsg();
00435 cerr << endl;
00436 }
00437
00438 void
00439 SSLNullContextException::derivedMsg() const
00440 {
00441 cerr << "Invalid context";
00442 }
00443
00444 void
00445 SSLNullCiphersException::derivedMsg() const
00446 {
00447 cerr << "Invalid cipher list";
00448 }
00449
00450 void
00451 SSLVerificationFlagsException::derivedMsg() const
00452 {
00453 cerr << "Invalid SSL verification flags";
00454 }