module Happstack.Server.Internal.TLS where
import Control.Concurrent (forkIO, killThread, myThreadId)
import Control.Exception (finally)
import Control.Exception.Extensible as E
import Control.Monad (forever, when)
import Data.Time (UTCTime)
import Happstack.Server.Internal.Listen (listenOn)
import Happstack.Server.Internal.Handler (request)
import Happstack.Server.Internal.Socket (acceptLite)
import Happstack.Server.Internal.TimeoutManager (cancel, initialize, register)
import Happstack.Server.Internal.TimeoutSocketTLS as TSS
import Happstack.Server.Internal.Types (Request, Response)
import Network.Socket (HostName, PortNumber, Socket, getSocketName, sClose, socketPort)
import OpenSSL (withOpenSSL)
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL.Session as SSL
import Happstack.Server.Internal.TimeoutIO (TimeoutIO(toHandle, toShutdown))
import Happstack.Server.Types (LogAccess, logMAccess)
import System.IO.Error (isFullError)
import System.Log.Logger (Priority(..), logM)
#ifndef mingw32_HOST_OS
import System.Posix.Signals (Handler(Ignore), installHandler, openEndedPipe)
#endif
-- | Emit a log message at the given priority, tagged with this module's
-- logger name so it can be filtered/configured via hslogger.
log' :: Priority -> String -> IO ()
log' priority message = logM "Happstack.Server.Internal.TLS" priority message
-- | Settings for the HTTPS listener started by 'listenTLS'.
data TLSConf = TLSConf
  { tlsPort      :: Int
    -- ^ TCP port passed to 'listenOn'.
  , tlsCert      :: FilePath
    -- ^ Certificate file loaded into the OpenSSL context.
  , tlsKey       :: FilePath
    -- ^ Private-key file loaded into the OpenSSL context.
  , tlsTimeout   :: Int
    -- ^ Per-connection timeout; 'listenTLS' multiplies it by 10^6 before
    -- handing it to the timeout manager (presumably seconds -> microseconds;
    -- confirm against the TimeoutManager module).
  , tlsLogAccess :: Maybe (LogAccess UTCTime)
    -- ^ Access logger forwarded to the request handler.
  , tlsValidator :: Maybe (Response -> IO Response)
    -- ^ Optional response validator. NOTE(review): not consulted by
    -- 'listenTLS' itself; presumably applied by a caller -- confirm.
  }
-- | Baseline 'TLSConf': port 443, timeout 30, access logging through
-- 'logMAccess' and no validator.  The certificate and key paths are left
-- empty, so callers are expected to fill them in before calling 'listenTLS'.
nullTLSConf :: TLSConf
nullTLSConf = TLSConf
  { tlsCert      = ""
  , tlsKey       = ""
  , tlsPort      = 443
  , tlsTimeout   = 30
  , tlsValidator = Nothing
  , tlsLogAccess = Just logMAccess
  }
-- | A listening socket bundled with the OpenSSL context used to run the
-- TLS handshake on the connections accepted from it (e.g. as built by
-- 'httpsOnSocket').
data HTTPS = HTTPS
  { httpsSocket :: Socket      -- ^ socket that connections are accepted on
  , sslContext  :: SSLContext  -- ^ context holding certificate, key and ciphers
  }
-- | Pair a listening socket with a freshly configured OpenSSL context.
--
-- The context gets the private key from @key@, the certificate from @cert@
-- and OpenSSL's default cipher list.  If OpenSSL reports that the key does
-- not belong to the certificate, an 'ErrorCall' naming both files is thrown
-- (the same exception type the previous 'error' call produced, but raised
-- with 'throwIO' so it is ordered with the surrounding IO actions).
httpsOnSocket :: FilePath   -- ^ certificate file
              -> FilePath   -- ^ private-key file
              -> Socket     -- ^ already-listening socket
              -> IO HTTPS
httpsOnSocket cert key socket = do
  ctx <- SSL.context
  SSL.contextSetPrivateKeyFile ctx key
  SSL.contextSetCertificateFile ctx cert
  SSL.contextSetDefaultCiphers ctx
  keyMatches <- SSL.contextCheckPrivateKey ctx
  -- The message previously read "OpenTLS" (an SSL->TLS rename artefact) and
  -- did not say which files were involved.
  when (not keyMatches) $
    E.throwIO $ E.ErrorCall $
      "OpenSSL certificate and key do not match (certificate: "
        ++ cert ++ ", key: " ++ key ++ ")."
  return (HTTPS socket ctx)
-- | Accept one connection on the listening socket of an 'HTTPS' value and
-- run the server side of the TLS handshake on it.
--
-- Returns the established 'SSL' connection plus the peer host and port
-- reported by 'acceptLite'.  If wrapping the socket or the handshake fails
-- (or the thread receives an exception meanwhile), the freshly accepted
-- socket is closed before the exception is re-thrown, so a failed handshake
-- no longer leaks a file descriptor.
--
-- NOTE: the handshake runs in the calling thread and blocks it until the
-- peer completes (or aborts) the negotiation.
acceptTLS :: HTTPS -> IO (SSL, HostName, PortNumber)
acceptTLS (HTTPS listenSock ctx) =
  E.bracketOnError (acceptLite listenSock) (\(sck, _, _) -> sClose sck) $
    \(sck, peer, port) -> do
      ssl <- SSL.connection ctx sck
      SSL.accept ssl
      return (ssl, peer, port)
-- | Start an HTTPS server as described by the 'TLSConf', dispatching every
-- request to the given handler.  Does not return while the listener runs.
--
-- Runs 'withOpenSSL' once (to initialise the OpenSSL library), opens the
-- listening socket, builds the OpenSSL context and hands over to
-- 'listenTLS''.  If the context cannot be built (e.g. the key does not
-- match the certificate), the listening socket is closed before the
-- exception propagates; once 'listenTLS'' is running it owns the socket.
listenTLS :: TLSConf
          -> (Request -> IO Response)
          -> IO ()
listenTLS tlsConf hand = do
  withOpenSSL $ return ()
  tlsSocket <- listenOn (tlsPort tlsConf)
  https <- httpsOnSocket (tlsCert tlsConf) (tlsKey tlsConf) tlsSocket
             `E.onException` sClose tlsSocket
  listenTLS' (tlsTimeout tlsConf) (tlsLogAccess tlsConf) tlsSocket https hand
-- | Run the HTTPS accept loop on an already-listening socket.
--
-- Every accepted connection is handed to its own thread.  That thread does
-- the TLS handshake and then serves the connection with 'request'.  Both
-- phases are guarded by the timeout manager, which is initialised with
-- @timeout * 10^6@ (presumably microseconds, i.e. @timeout@ is in seconds;
-- confirm against the TimeoutManager module).
--
-- The handshake deliberately runs in the per-connection thread, not the
-- accept loop, for two reasons:
--
-- * a slow or stalled client can no longer block all other accepts;
-- * a failed handshake (e.g. an OpenSSL protocol error) is logged and that
--   connection dropped, instead of escaping the accept loop and shutting
--   the whole listener down.
--
-- Connections are accepted on @httpsSocket https@; @socket@ is used for the
-- start-up log line and is closed when this function exits for any reason.
listenTLS' :: Int -> Maybe (LogAccess UTCTime) -> Socket -> HTTPS -> (Request -> IO Response) -> IO ()
listenTLS' timeout mlog socket https hand = body `finally` sClose socket
  where
    body :: IO ()
    body = do
#ifndef mingw32_HOST_OS
      -- Ignore SIGPIPE so that writing to a peer that has gone away does not
      -- terminate the whole process.
      _ <- installHandler openEndedPipe Ignore Nothing
#endif
      tm <- initialize (timeout * (10 ^ (6 :: Int)))
      let -- Run the TLS handshake under the timeout manager.  On timeout the
          -- thread is killed; on any failure the raw socket is closed before
          -- the exception propagates.
          -- NOTE(review): if the timeout fires in the instant between a
          -- successful handshake and 'cancel', the SSL connection is not
          -- shut down; the window is tiny but not closed.
          handshake :: Socket -> IO SSL
          handshake sck = do
            tid <- myThreadId
            hsHandle <- register tm (killThread tid)
            ssl <- openTLS sck `E.onException` (cancel hsHandle >> sClose sck)
            cancel hsHandle
            return ssl

          -- Serve an established TLS connection, then shut it down.  The
          -- timeout action shuts the SSL session down and kills this thread.
          work :: (SSL, HostName, PortNumber) -> IO ()
          work (ssl, hn, p) = do
            tid <- myThreadId
            thandle <- register tm $ do
              SSL.shutdown ssl SSL.Unidirectional `E.catch` ignoreSSLException
              killThread tid
            let timeoutIO = TSS.timeoutSocketIO thandle ssl
            request timeoutIO mlog (hn, fromIntegral p) hand
              `E.catches` [ Handler ignoreConnectionAbruptlyTerminated
                          , Handler ehs
                          ]
            cancel (toHandle timeoutIO)
            toShutdown timeoutIO `E.catch` ignoreSSLException

          -- Per-connection thread body: handshake followed by 'work'.
          -- Failures are reported through the same handlers as request errors.
          serve :: (Socket, HostName, PortNumber) -> IO ()
          serve (sck, hn, p) =
            (handshake sck >>= \ssl -> work (ssl, hn, p))
              `E.catches` [ Handler ignoreConnectionAbruptlyTerminated
                          , Handler ehs
                          ]

          -- The accept loop only accepts and forks; it never blocks on a peer.
          loop :: IO ()
          loop = forever (acceptLite (httpsSocket https) >>= forkIO . serve)

          infi :: IO ()
          infi = loop `catchSome` pe >> infi
      sockName <- getSocketName socket
      -- 'show' of the socket address already includes the port; appending
      -- 'socketPort' again printed it twice.
      log' NOTICE ("Listening on https://" ++ show sockName)
      infi

    -- Wrap an accepted socket in an SSL connection and run the server side
    -- of the handshake on it.
    openTLS :: Socket -> IO SSL
    openTLS sck = do
      conn <- SSL.connection (sslContext https) sck
      SSL.accept conn
      return conn

    pe :: SomeException -> IO ()
    pe e = log' ERROR ("ERROR in https accept thread: " ++ show e)

    ignoreSSLException :: SSL.SomeSSLException -> IO ()
    ignoreSSLException _ = return ()

    ignoreConnectionAbruptlyTerminated :: SSL.ConnectionAbruptlyTerminated -> IO ()
    ignoreConnectionAbruptlyTerminated _ = return ()

    -- Log unexpected failures, but stay quiet about the ThreadKilled used by
    -- the timeout manager to abort a connection.
    ehs :: SomeException -> IO ()
    ehs x = when (fromException x /= Just ThreadKilled) $
      log' ERROR ("HTTPS request failed with: " ++ show x)

    -- Keep the accept loop alive across the listed transient exceptions.  A
    -- "full" IOException is ignored; any other IOException is re-thrown
    -- (with 'throwIO', so it is raised in IO order) and ends the listener.
    catchSome :: IO () -> (SomeException -> IO ()) -> IO ()
    catchSome op h =
      op `E.catches`
        [ Handler ignoreConnectionAbruptlyTerminated
        , Handler $ \(e :: ArithException) -> h (toException e)
        , Handler $ \(e :: ArrayException) -> h (toException e)
        , Handler $ \(e :: IOException) ->
            if isFullError e
              then return ()
              else E.throwIO e
        ]