module Network.TLS.Sending (
encodePacket
, encodePacket13
, updateHandshake
, updateHandshake13
) where
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Parameters
import Network.TLS.Record
import Network.TLS.Record.Layer
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role(..))
import Network.TLS.Util
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef
-- | Marshal a pre-TLS 1.3 'Packet' into wire bytes using the given record
-- layer, stamping every record with the version returned by
-- 'decideRecordVersion'.
--
-- The packet is turned into plaintext fragments by 'packetToFragments'
-- (bounded by 'ctxFragmentSize'). Each fragment is wrapped in its own
-- 'Record' and encoded with 'recordEncode'. On success the encoded pieces are
-- joined with 'mconcat'; an encoding failure is returned as 'Left'
-- (presumably the first one, as collected by 'forEitherM' — confirm).
--
-- When the packet is 'ChangeCipherSpec', the pending transmit state is
-- installed via 'switchTxEncryption' after encoding. Note that this happens
-- whatever the encoding result was.
encodePacket :: Monoid bytes
             => Context -> RecordLayer bytes -> Packet -> IO (Either TLSError bytes)
encodePacket ctx recordLayer pkt = do
    (ver, _) <- decideRecordVersion ctx
    fragments <- packetToFragments ctx (ctxFragmentSize ctx) pkt
    let toRecord = Record (packetType pkt) ver . fragmentPlaintext
    result <- forEitherM (map toRecord fragments) (recordEncode recordLayer)
    when (pkt == ChangeCipherSpec) $
        switchTxEncryption ctx
    return (mconcat <$> result)
-- | Turn a pre-TLS 1.3 packet into the plaintext fragments that will each be
-- wrapped in one record.
--
-- * 'Handshake': every message is serialised (and recorded in the handshake
--   state) by 'updateHandshake'. The results are concatenated and split by
--   'getChunks' according to @len@.
-- * 'Alert', 'ChangeCipherSpec', 'AppData': exactly one fragment each. No
--   splitting is done here (NOTE(review): presumably AppData is fragmented by
--   the caller — confirm).
--
-- NOTE(review): outgoing handshake messages always pass 'ClientRole' to
-- 'updateHandshake'. Presumably 'updateVerifiedData' interprets that role
-- relative to the context's own role — confirm before changing.
packetToFragments :: Context -> Maybe Int -> Packet -> IO [ByteString]
packetToFragments ctx len pkt = case pkt of
    Handshake hss -> do
        msgs <- mapM (updateHandshake ctx ClientRole) hss
        return (getChunks len (B.concat msgs))
    Alert a -> return [encodeAlerts a]
    ChangeCipherSpec -> return [encodeChangeCipherSpec]
    AppData x -> return [x]
-- | Install the pending transmit record state ('hstPendingTxState' of the
-- handshake state) as the context's active transmit state ('ctxTxState').
-- 'encodePacket' calls this after a ChangeCipherSpec has been encoded.
--
-- Fails with the @"tx-state"@ error from 'fromJust' if no pending transmit
-- state is present. In that case 'ctxTxState' keeps its previous value.
--
-- Additionally sets 'ctxNeedEmptyPacket' when all of the following hold:
-- the negotiated version is TLS 1.0 or older, this context is a client,
-- the new state's cipher has a positive block size, and
-- 'supportedEmptyPacket' is enabled.
switchTxEncryption :: Context -> IO ()
switchTxEncryption ctx = do
    tx <- usingHState ctx (fromJust "tx-state" <$> gets hstPendingTxState)
    (ver, cc) <- usingState_ ctx $ (,) <$> getVersion <*> isClientContext
    -- 'tx' comes out of a lazy fmap, so it may still be an unevaluated
    -- 'fromJust' thunk. Force it before publishing it in the shared MVar.
    -- A missing pending state then fails here (modifyMVar_ restores the old
    -- contents on exception) instead of at some later use of ctxTxState.
    modifyMVar_ (ctxTxState ctx) (\_ -> return $! tx)
    -- NOTE(review): presumably the empty-record countermeasure for CBC
    -- ciphers in TLS <= 1.0 — confirm against the consumer of this flag.
    when (ver <= TLS10 && cc == ClientRole && isCBC tx
            && supportedEmptyPacket (ctxSupported ctx)) $
        writeIORef (ctxNeedEmptyPacket ctx) True
  where
    -- A cipher whose bulk algorithm has a non-zero block size.
    isCBC st = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher st)
-- | Serialise a pre-TLS 1.3 handshake message, update the TLS and handshake
-- state with it, and return the encoded bytes.
--
-- * For 'Finished', the verify data is stored with 'updateVerifiedData'
--   using the given role.
-- * The encoding is appended via 'addHandshakeMessage' when
--   'certVerifyHandshakeMaterial' holds for the message.
-- * The encoding is fed to 'updateHandshakeDigest' when
--   'finishHandshakeTypeMaterial' holds for its handshake type.
updateHandshake :: Context -> Role -> Handshake -> IO ByteString
updateHandshake ctx role hs = do
    recordVerifiedData hs
    usingHState ctx $ do
        when (certVerifyHandshakeMaterial hs) $
            addHandshakeMessage encoded
        when (finishHandshakeTypeMaterial (typeOfHandshake hs)) $
            updateHandshakeDigest encoded
    return encoded
  where
    encoded = encodeHandshake hs

    -- Only Finished messages carry verify data worth remembering.
    recordVerifiedData (Finished fdata) =
        usingState_ ctx (updateVerifiedData role fdata)
    recordVerifiedData _ = return ()
-- | TLS 1.3 counterpart of 'encodePacket': marshal a 'Packet13' into wire
-- bytes using the given record layer.
--
-- Every record is stamped with 'TLS12' as its record version
-- (NOTE(review): presumably the TLS 1.3 legacy_record_version — confirm).
-- The content type comes from 'contentType'. Fragments from
-- 'packetToFragments13' (bounded by 'ctxFragmentSize') are each encoded with
-- 'recordEncode13' and, on success, joined with 'mconcat'. An encoding
-- failure is returned as 'Left'. This function does not touch the transmit
-- state itself.
encodePacket13 :: Monoid bytes
               => Context -> RecordLayer bytes -> Packet13 -> IO (Either TLSError bytes)
encodePacket13 ctx recordLayer pkt = do
    fragments <- packetToFragments13 ctx (ctxFragmentSize ctx) pkt
    let records = [ Record (contentType pkt) TLS12 (fragmentPlaintext frag)
                  | frag <- fragments ]
    fmap mconcat <$> forEitherM records (recordEncode13 recordLayer)
-- | Turn a TLS 1.3 packet into the plaintext fragments that will each be
-- wrapped in one record.
--
-- Handshake messages are serialised (and added to the transcript where
-- applicable) by 'updateHandshake13'. They are then concatenated and split by
-- 'getChunks' according to @len@. Alerts, application data and
-- ChangeCipherSpec each produce exactly one fragment.
packetToFragments13 :: Context -> Maybe Int -> Packet13 -> IO [ByteString]
packetToFragments13 ctx len pkt = case pkt of
    Handshake13 hss ->
        getChunks len . B.concat <$> traverse (updateHandshake13 ctx) hss
    Alert13 a -> pure [encodeAlerts a]
    AppData13 x -> pure [x]
    ChangeCipherSpec13 -> pure [encodeChangeCipherSpec]
-- | Serialise a TLS 1.3 handshake message and return the encoded bytes.
--
-- NewSessionTicket13 and KeyUpdate13 are returned without touching the
-- handshake state. Every other message is passed to 'updateHandshakeDigest'
-- and then 'addHandshakeMessage'.
--
-- If the message is a ServerHello13 whose random satisfies
-- 'isHelloRetryRequest', 'wrapAsMessageHash13' runs first.
-- NOTE(review): presumably this rewrites the transcript so far into the
-- message_hash form required for HelloRetryRequest (RFC 8446 §4.4.1) —
-- confirm in Handshake.State13.
updateHandshake13 :: Context -> Handshake13 -> IO ByteString
updateHandshake13 ctx hs
    | excludedFromTranscript = return encoded
    | otherwise = usingHState ctx $ do
        when helloRetry wrapAsMessageHash13
        updateHandshakeDigest encoded
        addHandshakeMessage encoded
        return encoded
  where
    encoded = encodeHandshake13 hs

    excludedFromTranscript = case hs of
        NewSessionTicket13{} -> True
        KeyUpdate13{} -> True
        _ -> False

    helloRetry = case hs of
        ServerHello13 srand _ _ _ -> isHelloRetryRequest srand
        _ -> False