{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}

{-|
Module     : Control.Monad.Choice.Random
Copyright  : (c) Eamon Olive, 2020
             (c) Louis Hyde,  2020
License    : AGPL-3
Maintainer : ejolive97@gmail.com
Stability  : experimental

Module for 'UniformRandom', a wrapper that gives any wrapped 'MonadRandom' monad a 'MonadChoice' instance.
'choose' is implemented with 'uniform', so it samples the given container uniformly.

-}
module Control.Monad.Choice.Random
  ( UniformRandom
    ( UniformRandom
    )
  , lift
  , colift
  )
  where

-- Before base 4.10, liftA2 was not a method of the Applicative class, so it is
-- only imported (and defined in the instance below) on base >= 4.10.
import Control.Applicative
  ( Alternative
    ( empty
    , (<|>)
    , some
    , many
    )
#if MIN_VERSION_base(4,10,0)
  , Applicative
    ( liftA2
    )
#endif
  )

import Control.Monad
  ( MonadPlus
    ( mzero
    , mplus
    )
  )
import Control.Monad.Class.Choice
  ( MonadChoice
    ( choose
    )
  )
import Control.Monad.Cont.Class
  ( MonadCont
    ( callCC
    )
  )
import Control.Monad.Error.Class
  ( MonadError
    ( throwError
    , catchError
    )
  )
import qualified Control.Monad.Fail as Fail
  ( MonadFail
    ( fail
    )
  )
import Control.Monad.Fix
  ( MonadFix
    ( mfix
    )
  )
import Control.Monad.IO.Class
  ( MonadIO
    ( liftIO
    )
  )
import Control.Monad.Primitive
  ( PrimMonad
    ( PrimState
    , primitive
    )
  )
import Control.Monad.Random.Class
  ( MonadRandom
    ( getRandomR
    , getRandom
    , getRandomRs
    , getRandoms
    )
  , MonadSplit
    ( getSplit
    )
  , MonadInterleave
    ( interleave
    )
  , uniform
  )
import Control.Monad.Reader.Class
  ( MonadReader
    ( ask
    , local
    , reader
    )
  )
import Control.Monad.RWS.Class
  ( MonadRWS
  )
import Control.Monad.State.Class
  ( MonadState
    ( get
    , put
    , state
    )
  )
import Control.Monad.Writer.Class
  ( MonadWriter
    ( writer
    , tell
    , listen
    , pass
    )
  )
import Control.Monad.Zip
  ( MonadZip
    ( mzip
    , mzipWith
    , munzip
    )
  )

import Data.Foldable
  ( Foldable
    ( fold
    , foldr'
    , foldl'
    , toList
    )
  )
import Data.Functor.Classes
  ( Eq1
    ( liftEq
    )
  , Ord1
    ( liftCompare
    )
  )
#if !MIN_VERSION_base(4,13,0)
import Data.Semigroup
  ( Semigroup
    ( (<>)
    )
  )
#endif

-- | A transparent wrapper around an action of type @r a@.
--
-- It exists for its 'MonadChoice' instance, which implements 'choose' with
-- 'uniform' when @r@ is a 'MonadRandom'. Every other instance in this module
-- simply forwards to the wrapped @r@.
newtype UniformRandom r a = UniformRandom { runUniformRandom :: r a }

-- | Wrap an action in 'UniformRandom'; the data constructor under a function
-- name. Inverse of 'colift'.
lift :: r a -> UniformRandom r a
lift action = UniformRandom action
{-# INLINE lift #-}

-- | Unwrap a 'UniformRandom', returning the underlying action. Inverse of
-- 'lift'.
colift :: UniformRandom r a -> r a
colift (UniformRandom action) = action
{-# INLINE colift #-}

-- | Convenience for writing instances: lift a binary function on underlying
-- actions to one on wrapped actions by unwrapping both arguments, applying the
-- function, and wrapping the result.
lift2 :: (r a -> s b -> t c) -> UniformRandom r a -> UniformRandom s b -> UniformRandom t c
lift2 combine x y = lift (combine (colift x) (colift y))
{-# INLINE lift2 #-}

-- | Maps over the wrapped functor.
instance Functor f => Functor (UniformRandom f) where
  fmap g wrapped = lift (fmap g (colift wrapped))
  {-# INLINE fmap #-}

-- | Every method unwraps its arguments, applies the wrapped functor's own
-- method, and rewraps the result.
instance Applicative f => Applicative (UniformRandom f) where
  pure x = lift (pure x)
  {-# INLINE pure #-}
  wf <*> wx = lift (colift wf <*> colift wx)
  {-# INLINE (<*>) #-}
#if MIN_VERSION_base(4,10,0)
  -- liftA2 only became a class method in base 4.10.
  liftA2 g wx wy = lift (liftA2 g (colift wx) (colift wy))
  {-# INLINE liftA2 #-}
#endif
  wx *> wy = lift (colift wx *> colift wy)
  {-# INLINE (*>) #-}
  wx <* wy = lift (colift wx <* colift wy)
  {-# INLINE (<*) #-}

-- | Monadic sequencing in the wrapped monad.
--
-- @(>>=)@ unwraps the action, binds in the wrapped monad, and unwraps the
-- result of each continuation. @(>>)@ is given the canonical definition
-- @(*>)@ so that, like every other method in this module, it forwards to the
-- wrapped type (via the 'Applicative' instance). Otherwise it would fall back
-- to the class default, which routes it through @(>>=)@.
instance Monad m => Monad (UniformRandom m) where
  action >>= k = lift (colift action >>= \a -> colift (k a))
  {-# INLINE (>>=) #-}
  (>>) = (*>)
  {-# INLINE (>>) #-}

-- | The reason this wrapper exists: 'choose' samples the container with
-- 'uniform' in the wrapped 'MonadRandom' monad.
--
-- NOTE(review): 'uniform' presumably fails on an empty container; confirm
-- against the MonadRandom documentation before calling 'choose' on input that
-- may be empty.
instance (Foldable f, MonadRandom m) => MonadChoice f (UniformRandom m) where
  choose options = lift (uniform options)
  {-# INLINE choose #-}

-- | Random generation is delegated unchanged to the wrapped monad.
instance MonadRandom m => MonadRandom (UniformRandom m) where
  getRandomR bounds = lift (getRandomR bounds)
  {-# INLINE getRandomR #-}
  getRandom = lift getRandom
  {-# INLINE getRandom #-}
  getRandomRs bounds = lift (getRandomRs bounds)
  {-# INLINE getRandomRs #-}
  getRandoms = lift getRandoms
  {-# INLINE getRandoms #-}

-- | Fixed points are taken in the wrapped monad; the knot-tying function's
-- results are unwrapped before being handed to the wrapped 'mfix'.
instance MonadFix m => MonadFix (UniformRandom m) where
  mfix g = lift (mfix (\a -> colift (g a)))
  {-# INLINE mfix #-}

-- | Failure is raised in the wrapped monad.
instance Fail.MonadFail m => Fail.MonadFail (UniformRandom m) where
  fail msg = lift (Fail.fail msg)
  {-# INLINE fail #-}

-- | All methods, including 'some' and 'many', forward to the wrapped
-- 'Alternative' rather than using the class defaults.
instance Alternative f => Alternative (UniformRandom f) where
  empty = lift empty
  {-# INLINE empty #-}
  wx <|> wy = lift (colift wx <|> colift wy)
  {-# INLINE (<|>) #-}
  some wx = lift (some (colift wx))
  {-# INLINE some #-}
  many wx = lift (many (colift wx))
  {-# INLINE many #-}

-- | 'mzero' and 'mplus' come from the wrapped monad.
instance MonadPlus m => MonadPlus (UniformRandom m) where
  mzero = lift mzero
  {-# INLINE mzero #-}
  mplus wx wy = lift (mplus (colift wx) (colift wy))
  {-# INLINE mplus #-}

-- | IO actions are lifted through the wrapped monad.
instance MonadIO m => MonadIO (UniformRandom m) where
  liftIO io = lift (liftIO io)
  {-# INLINE liftIO #-}

-- | Combines the wrapped values with their own '<>'.
instance Semigroup (r a) => Semigroup (UniformRandom r a) where
  lhs <> rhs = lift (colift lhs <> colift rhs)
  {-# INLINE (<>) #-}

-- | 'mempty' is the wrapped 'mempty'; 'mappend' is the canonical alias for
-- '<>'. Before base 4.11, 'Semigroup' was not a superclass of 'Monoid', so
-- the 'Semigroup' constraint that '<>' needs is requested explicitly there.
instance
#if MIN_VERSION_base(4,11,0)
  Monoid (inner x)
#else
  (Monoid (inner x), Semigroup (inner x))
#endif
  => Monoid (UniformRandom inner x) where
  mempty = lift mempty
  {-# INLINE mempty #-}
  mappend = (<>)
  {-# INLINE mappend #-}

-- | Errors are thrown and caught in the wrapped monad; the handler's wrapped
-- result is unwrapped before being passed to the underlying 'catchError'.
instance MonadError err inner => MonadError err (UniformRandom inner) where
  throwError failure = lift (throwError failure)
  {-# INLINE throwError #-}
  catchError action handler =
    lift (catchError (colift action) (\failure -> colift (handler failure)))
  {-# INLINE catchError #-}

-- | The environment is that of the wrapped monad.
instance MonadReader env inner => MonadReader env (UniformRandom inner) where
  ask = lift ask
  {-# INLINE ask #-}
  local modify action = lift (local modify (colift action))
  {-# INLINE local #-}
  reader project = lift (reader project)
  {-# INLINE reader #-}

-- | State is read and written in the wrapped monad.
instance MonadState st inner => MonadState st (UniformRandom inner) where
  get = lift get
  {-# INLINE get #-}
  put newState = lift (put newState)
  {-# INLINE put #-}
  state transition = lift (state transition)
  {-# INLINE state #-}

-- | Every listed method delegates to the wrapped container's own
-- implementation (so partiality of e.g. 'maximum' or 'foldr1' is inherited).
instance Foldable f => Foldable (UniformRandom f) where
  fold xs = fold (colift xs)
  {-# INLINE fold #-}
  foldMap toM xs = foldMap toM (colift xs)
  {-# INLINE foldMap #-}
  foldr step acc xs = foldr step acc (colift xs)
  {-# INLINE foldr #-}
  foldr' step acc xs = foldr' step acc (colift xs)
  {-# INLINE foldr' #-}
  foldl step acc xs = foldl step acc (colift xs)
  {-# INLINE foldl #-}
  foldl' step acc xs = foldl' step acc (colift xs)
  {-# INLINE foldl' #-}
  foldr1 step xs = foldr1 step (colift xs)
  {-# INLINE foldr1 #-}
  foldl1 step xs = foldl1 step (colift xs)
  {-# INLINE foldl1 #-}
  toList xs = toList (colift xs)
  {-# INLINE toList #-}
  null xs = null (colift xs)
  {-# INLINE null #-}
  length xs = length (colift xs)
  {-# INLINE length #-}
  elem needle xs = elem needle (colift xs)
  {-# INLINE elem #-}
  maximum xs = maximum (colift xs)
  {-# INLINE maximum #-}
  minimum xs = minimum (colift xs)
  {-# INLINE minimum #-}
  sum xs = sum (colift xs)
  {-# INLINE sum #-}
  product xs = product (colift xs)
  {-# INLINE product #-}

-- | Traverses the wrapped structure, then rewraps it inside the effect.
instance Traversable t => Traversable (UniformRandom t) where
  traverse visit xs = fmap lift (traverse visit (colift xs))
  {-# INLINE traverse #-}
  sequenceA xs = fmap lift (sequenceA (colift xs))
  {-# INLINE sequenceA #-}
  mapM visit xs = fmap lift (mapM visit (colift xs))
  {-# INLINE mapM #-}
  sequence xs = fmap lift (sequence (colift xs))
  {-# INLINE sequence #-}

-- | Lifted equality compares the unwrapped values.
instance Eq1 f => Eq1 (UniformRandom f) where
  liftEq eq lhs rhs = liftEq eq (colift lhs) (colift rhs)
  {-# INLINE liftEq #-}

-- | Lifted comparison compares the unwrapped values.
instance Ord1 f => Ord1 (UniformRandom f) where
  liftCompare cmp lhs rhs = liftCompare cmp (colift lhs) (colift rhs)
  {-# INLINE liftCompare #-}

-- | Zipping happens in the wrapped monad; 'munzip' rewraps both halves.
instance MonadZip m => MonadZip (UniformRandom m) where
  mzip wx wy = lift (mzip (colift wx) (colift wy))
  {-# INLINE mzip #-}
  mzipWith g wx wy = lift (mzipWith g (colift wx) (colift wy))
  {-# INLINE mzipWith #-}
  munzip wxy = rewrapBoth (munzip (colift wxy))
    where
      -- Strict in the pair, matching a tuple-pattern lambda.
      rewrapBoth (left, right) = (lift left, lift right)
  {-# INLINE munzip #-}

-- | Continuations are captured in the wrapped monad. The escape continuation
-- handed to the user is rewrapped, and the user's body is unwrapped.
instance MonadCont m => MonadCont (UniformRandom m) where
  callCC body = lift (callCC (\escape -> colift (body (\a -> lift (escape a)))))
  {-# INLINE callCC #-}

-- | Equality of the unwrapped values.
instance Eq (r a) => Eq (UniformRandom r a) where
  lhs == rhs = colift lhs == colift rhs
  {-# INLINE (==) #-}
  lhs /= rhs = colift lhs /= colift rhs
  {-# INLINE (/=) #-}

-- | Ordering of the unwrapped values; 'max' and 'min' rewrap the winner.
instance Ord (r a) => Ord (UniformRandom r a) where
  compare lhs rhs = compare (colift lhs) (colift rhs)
  {-# INLINE compare #-}
  lhs < rhs = colift lhs < colift rhs
  {-# INLINE (<) #-}
  lhs <= rhs = colift lhs <= colift rhs
  {-# INLINE (<=) #-}
  lhs > rhs = colift lhs > colift rhs
  {-# INLINE (>) #-}
  lhs >= rhs = colift lhs >= colift rhs
  {-# INLINE (>=) #-}
  max lhs rhs = lift (max (colift lhs) (colift rhs))
  {-# INLINE max #-}
  min lhs rhs = lift (min (colift lhs) (colift rhs))
  {-# INLINE min #-}

-- | Declares no methods: it only bundles the 'MonadReader', 'MonadWriter' and
-- 'MonadState' instances for 'UniformRandom'.
instance MonadRWS r w s inner => MonadRWS r w s (UniformRandom inner) where

-- | Output is accumulated in the wrapped monad.
instance MonadWriter w inner => MonadWriter w (UniformRandom inner) where
  writer entry = lift (writer entry)
  {-# INLINE writer #-}
  tell output = lift (tell output)
  {-# INLINE tell #-}
  listen action = lift (listen (colift action))
  {-# INLINE listen #-}
  pass action = lift (pass (colift action))
  {-# INLINE pass #-}

-- | Generator splitting is performed by the wrapped monad.
instance MonadSplit gen inner => MonadSplit gen (UniformRandom inner) where
  getSplit = lift getSplit
  {-# INLINE getSplit #-}

-- | Shares the wrapped monad's state token, so primitive operations run
-- directly in it.
instance PrimMonad inner => PrimMonad (UniformRandom inner) where
  type PrimState (UniformRandom inner) = PrimState inner
  primitive op = lift (primitive op)
  {-# INLINE primitive #-}

-- | Interleaving is delegated to the wrapped monad.
instance MonadInterleave inner => MonadInterleave (UniformRandom inner) where
  interleave action = lift (interleave (colift action))
  {-# INLINE interleave #-}