Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

SSLSocket.cc

00001 /*
00002  * Copyright 2003 Michael A. Marsh, Cornell University. All rights reserved.
00003  * This software is released under the modified BSD license.
00004  * See the file LICENSE in the top-level directory for details.
00005  */
00006 //
00007 // $Id: SSLSocket.cc,v 1.3 2004/05/19 15:56:57 mmarsh Exp $
00008 //
00009 // $Log: SSLSocket.cc,v $
00010 // Revision 1.3  2004/05/19 15:56:57  mmarsh
00011 // *** empty log message ***
00012 //
00013 // Revision 1.2  2003/11/04 22:20:38  mmarsh
00014 // General code cleanup, including the addition of a few new exception
00015 // classes.
00016 //
00017 //
00018 
#include "SSLSocket.h"
#include "CODEX_Quorum/RemoteServer.h"
#include <sys/errno.h>
#include <openssl/err.h>
#include <climits>
#include <iostream>
#include <sstream>

#include "timing.h"
00027 
00028 using namespace CODEX_SSL;
00029 using namespace CODEX_Quorum;
00030 
00031 SSLSocket::SSLSocket( SSL_CTX*  ctx      ,
00032                       int       domain   ,
00033                       int       type     ,
00034                       int       protocol ,
00035                       bool      blocking ) :
00036    SocketBase(domain, type, protocol, blocking),
00037    m_ctx( ctx ),
00038    m_ssl_con( 0 ),
00039    m_needRead( false ),
00040    m_needWrite( false )
00041 {
00042 }
00043 
00044 SSLSocket::SSLSocket( const SSLSocket& aOther ) :
00045    SocketBase( aOther ),
00046    m_ctx( aOther.m_ctx ),
00047    m_ssl_con( 0 ),
00048    m_needRead( false ),
00049    m_needWrite( false )
00050 {
00051 }
00052 
00053 SSLSocket::~SSLSocket()
00054 {
00055    if ( 0 != m_ssl_con )
00056    {
00057       SSL_set_shutdown( m_ssl_con, SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN );
00058       SSL_shutdown( m_ssl_con );
00059       //close( SSL_get_fd( m_ssl_con ) );
00060       SSL_free( m_ssl_con );
00061    }
00062 }
00063 
00064 int
00065 SSLSocket::set_fd( fd_set* fd_bitmap, StateType s ) const
00066 {
00067    bool ok = true;
00068    switch(s)
00069    {
00070       case SocketBase::kRead :
00071          if ( m_needWrite )
00072          {
00073             ok = false;
00074          }
00075          break;
00076       case SocketBase::kWrite :
00077          if ( m_needRead )
00078          {
00079             ok = false;
00080          }
00081          break;
00082    }
00083    if ( ok )
00084    {
00085       return SocketBase::set_fd( fd_bitmap, s );
00086    }
00087    return socket();
00088 }
00089 
00090 bool
00091 SSLSocket::isset_fd( const fd_set* fd_bitmap, StateType s ) const
00092 {
00093    if ( m_needWrite && ( SocketBase::kRead == s ) )
00094    {
00095       return false;
00096    }
00097    if ( m_needRead && ( SocketBase::kWrite == s ) )
00098    {
00099       return false;
00100    }
00101    // Because SSL abstracts a layer on top of the raw sockets, it is
00102    // possible that there's data to read from the SSL structure, but
00103    // the socket itself has nothing left to read.
00104    if ( ( SocketBase::kRead == s ) &&
00105         ( 0 != m_ssl_con ) &&
00106         ( SSL_pending( m_ssl_con ) > 0 ) )
00107    {
00108       return true;
00109    }
00110    return SocketBase::isset_fd( fd_bitmap, s );
00111 }
00112 
00113 size_t
00114 SSLSocket::readFrom( void* output, size_t maxSize ) const
00115 {
00116    if ( m_needWrite )
00117    {
00118       // We should not have been called.
00119       return 0;
00120    }
00121    if ( 0 == m_ssl_con )
00122    {
00123       return 0;
00124    }
00125 #ifdef TIMING
00126    SSLTimer.start();
00127 #endif
00128    m_needRead = false;
00129    int bytesRead = SSL_read( m_ssl_con, output, maxSize );
00130    switch ( SSL_get_error( m_ssl_con, bytesRead ) )
00131    {
00132       case SSL_ERROR_NONE :
00133          break;
00134       case SSL_ERROR_ZERO_RETURN :
00135 #ifdef TIMING
00136          SSLTimer.stop();
00137 #endif
00138          throw CODEX_Quorum::QSESocketBaseSocketClosed( __FILE__ , __LINE__ ,
00139                                                         socket(), 0 );
00140          break;
00141       case SSL_ERROR_WANT_READ :
00142       case SSL_ERROR_WANT_WRITE :
00143       case SSL_ERROR_WANT_CONNECT :
00144       case SSL_ERROR_WANT_X509_LOOKUP :
00145          // The action will need to be repeated.
00146          m_needRead = true;
00147          bytesRead = 0;
00148          break;
00149       case SSL_ERROR_SYSCALL :
00150       case SSL_ERROR_SSL :
00151          unsigned long err = ERR_get_error();
00152 #ifdef TIMING
00153          SSLTimer.stop();
00154 #endif
00155          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00156          break;
00157    }
00158 #ifdef TIMING
00159    SSLTimer.stop();
00160 #endif
00161    return bytesRead;
00162 }
00163 
00164 int
00165 SSLSocket::internal_write( const unsigned char* output, size_t maxSize ) const
00166 {
00167    if ( m_needRead )
00168    {
00169       // We should not have been called.
00170       return 0;
00171    }
00172    if ( 0 == m_ssl_con )
00173    {
00174       return 0;
00175    }
00176 #ifdef TIMING
00177    SSLTimer.start();
00178 #endif
00179    m_needWrite = false;
00180    int bytesWritten = SSL_write( m_ssl_con, output, maxSize );
00181    switch ( SSL_get_error( m_ssl_con, bytesWritten ) )
00182    {
00183       case SSL_ERROR_NONE :
00184          break;
00185       case SSL_ERROR_ZERO_RETURN :
00186 #ifdef TIMING
00187          SSLTimer.stop();
00188 #endif
00189          throw CODEX_Quorum::QSESocketBaseSocketClosed( __FILE__ , __LINE__ ,
00190                                                         socket(), 0 );
00191          break;
00192       case SSL_ERROR_WANT_READ :
00193       case SSL_ERROR_WANT_WRITE :
00194       case SSL_ERROR_WANT_CONNECT :
00195       case SSL_ERROR_WANT_X509_LOOKUP :
00196          // The action will need to be repeated.
00197          m_needWrite = true;
00198          bytesWritten = 0;
00199          break;
00200       case SSL_ERROR_SYSCALL :
00201       case SSL_ERROR_SSL :
00202          unsigned long err = ERR_get_error();
00203 #ifdef TIMING
00204          SSLTimer.stop();
00205 #endif
00206          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00207          break;
00208    }
00209 #ifdef TIMING
00210    SSLTimer.stop();
00211 #endif
00212    return bytesWritten;
00213 }
00214 
00215 SocketBase*
00216 SSLSocket::clone()
00217 {
00218    SocketBase* copy = new SSLSocket( *this );
00219    return copy;
00220 }
00221 
00222 void
00223 SSLSocket::connect( const RemoteServer& server )
00224 {
00225    if ( 0 != m_ssl_con )
00226    {
00227       // The lazy thing to do is just get rid of the existing context,
00228       // since it should always be 0 before calling connect().
00229       SSL_free( m_ssl_con );
00230       m_ssl_con = 0;
00231    }
00232    try
00233    {
00234       // Set up the socket through the base class.
00235       SocketBase::connect( server );
00236 
00237       m_ssl_con = SSL_new( m_ctx );
00238       if ( 0 == m_ssl_con )
00239       {
00240          unsigned long err = ERR_get_error();
00241          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00242       }
00243       SSL_set_fd( m_ssl_con, socket() );
00244 
00245       int c;
00246       do
00247       {
00248          c = SSL_connect(m_ssl_con);
00249          if ( c <= 0 )
00250          {
00251             switch ( SSL_get_error( m_ssl_con, c ) )
00252             {
00253                case SSL_ERROR_WANT_READ :
00254                case SSL_ERROR_WANT_WRITE :
00255                   break;
00256                default :
00257                   unsigned long err = ERR_get_error();
00258                   throw QSESSLSocket( __FILE__ , __LINE__ , err );
00259             }
00260          }
00261       } while ( c <= 0 );
00262    }
00263    catch ( QSESSLSocket& e )
00264    {
00265       e.report();
00266       if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00267       m_ssl_con = 0;
00268       throw;
00269    }
00270    catch ( ... )
00271    {
00272       if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00273       m_ssl_con = 0;
00274       throw;
00275    }
00276 }
00277 
00278 void
00279 SSLSocket::finish_accept()
00280 {
00281    if ( 0 != m_ssl_con )
00282    {
00283       // The lazy thing to do is just get rid of the existing context,
00284       // since it should always be 0 before calling finish_accept().
00285       SSL_free( m_ssl_con );
00286       m_ssl_con = 0;
00287    }
00288    try
00289    {
00290       m_ssl_con = (SSL*) SSL_new(m_ctx);
00291       if ( 0 == m_ssl_con )
00292       {
00293          unsigned long err = ERR_get_error();
00294          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00295       }
00296       SSL_clear(m_ssl_con);
00297       if ( ! SSL_set_fd( m_ssl_con, socket() ) )
00298       {
00299          unsigned long err = ERR_get_error();
00300          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00301       }
00302       SSL_set_accept_state(m_ssl_con);
00303       // NOTE:  SSL_accept *does not* call an underlying accept(), it just
00304       //        performs the SSL part of the transaction.
00305       int a;
00306       do
00307       {
00308          a = SSL_accept(m_ssl_con);
00309          if ( a <= 0 )
00310          {
00311             switch ( SSL_get_error( m_ssl_con, a ) )
00312             {
00313                case SSL_ERROR_WANT_READ :
00314                case SSL_ERROR_WANT_WRITE :
00315                   break;
00316                default :
00317                   unsigned long err = ERR_get_error();
00318                   throw QSESSLSocket( __FILE__ , __LINE__ , err );
00319             }
00320          }
00321       } while ( a <= 0 );
00322    }
00323    catch ( QSESSLSocket& e )
00324    {
00325       e.report();
00326       if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00327       m_ssl_con = 0;
00328       throw;
00329    }
00330    catch ( ... )
00331    {
00332       if ( 0 != m_ssl_con ) SSL_free( m_ssl_con );
00333       m_ssl_con = 0;
00334       throw;
00335    }
00336 }
00337 
00340 SSLSocketBuilder::SSLSocketBuilder( SSL_METHOD*  meth,
00341                                     const X509*  cert,
00342                                     const RSA*   privKey,
00343                                     const char*  ciphers,
00344                                     const char*  caCertFile,
00345                                     const char*  hostCertFile,
00346                                     int          verify,
00347                                     int          domain,
00348                                     int          type,
00349                                     int          protocol,
00350                                     bool         blocking ) :
00351    CODEX_Quorum::SocketBuilder( domain, type, protocol, blocking )
00352 {
00353    try
00354    {
00355       m_ctx = SSL_CTX_new(meth);
00356       if ( 0 == m_ctx )
00357       {
00358          throw SSLNullContextException( __FILE__ , __LINE__ );
00359       }
00360 //      if ( ! SSL_CTX_use_certificate( m_ctx, X509_dup((X509*)cert) ) )
00361       if ( ! SSL_CTX_use_certificate( m_ctx, (X509*)cert ) )
00362       {
00363          unsigned long err = ERR_get_error();
00364          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00365       }
00366       if ( ! SSL_CTX_use_RSAPrivateKey( m_ctx, (RSA*)privKey ) )
00367 //                                        RSAPrivateKey_dup((RSA*)privKey) ) )
00368       {
00369          unsigned long err = ERR_get_error();
00370          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00371       }
00372       if ( ! SSL_CTX_check_private_key( m_ctx ) )
00373       {
00374          unsigned long err = ERR_get_error();
00375          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00376       }
00377       if ( 0 == ciphers )
00378       {
00379          throw SSLNullCiphersException( __FILE__ , __LINE__ );
00380       }
00381       SSL_CTX_set_cipher_list( m_ctx, ciphers );
00382       if ( ! SSL_CTX_load_verify_locations( m_ctx, caCertFile, 0 ) )
00383       {
00384          unsigned long err = ERR_get_error();
00385          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00386       }
00387       if ( ! SSL_CTX_set_default_verify_paths( m_ctx ) )
00388       {
00389          unsigned long err = ERR_get_error();
00390          throw QSESSLSocket( __FILE__ , __LINE__ , err );
00391       }
00392       if ( 0 == verify )
00393       {
00394          throw SSLVerificationFlagsException( __FILE__ , __LINE__ );
00395       }
00396       SSL_CTX_set_verify( m_ctx, verify, 0 );
00397       SSL_CTX_set_client_CA_list( m_ctx,
00398                                   SSL_load_client_CA_file(hostCertFile) );
00399    }
00400    catch ( ... )
00401    {
00402       if ( 0 != m_ctx ) SSL_CTX_free( m_ctx );
00403       throw;
00404    }
00405 }
00406 
00407 SSLSocketBuilder::~SSLSocketBuilder()
00408 {
00409    if ( 0 != m_ctx ) SSL_CTX_free( m_ctx );
00410 }
00411 
00412 SocketBase*
00413 SSLSocketBuilder::operator()() const
00414 {
00415    return new SSLSocket( m_ctx, m_domain, m_type, m_protocol, m_blocking );
00416 }
00417 
00418 
00419 //------ Exceptions ------//
00420 
00421 void
00422 QSESSLSocket::errMsg() const
00423 {
00424    cerr << "SSL error:\n   "
00425         << ERR_error_string(error(),0);
00426 }
00427 
00428 void
00429 SSLExceptionBase::report() const
00430 {
00431    cerr << "At line " << m_line
00432         << " of file " << m_fname
00433         << "\nSSL exception: ";
00434    derivedMsg();
00435    cerr << endl;
00436 }
00437 
00438 void
00439 SSLNullContextException::derivedMsg() const
00440 {
00441    cerr << "Invalid context";
00442 }
00443 
00444 void
00445 SSLNullCiphersException::derivedMsg() const
00446 {
00447    cerr << "Invalid cipher list";
00448 }
00449 
00450 void
00451 SSLVerificationFlagsException::derivedMsg() const
00452 {
00453    cerr << "Invalid SSL verification flags";
00454 }

Generated on Fri May 6 17:41:16 2005 for COrnell Data EXchange (CODEX) by  doxygen 1.4.1