module Numeric.Neural.Model
( ParamFun(..)
, Component(..)
, _weights
, activate
, _component
, Pair(..)
, FEither(..)
, Convolve(..)
, cArr
, cFirst
, cLeft
, cConvolve
, Model(..)
, model
, modelR
, modelError
, descent
, StdModel
, mkStdModel
) where
import Control.Applicative
import Control.Arrow
import Control.Category
import Control.Monad.Par (runPar)
import Control.Monad.Par.Combinator (parMapReduceRange, InclusiveRange(..))
import Data.Profunctor
import Data.MyPrelude
import Prelude hiding (id, (.))
import Data.Utils.Analytic
import Data.Utils.Arrow
import Data.Utils.Statistics (mean)
import Data.Utils.Traversable
-- | A function from @a@ to @b@ that also reads a container of parameters of
-- type @t s@. Composition hands the same parameter container to every stage,
-- so a pipeline built with the 'Category' / 'Arrow' combinators shares one
-- set of parameters.
newtype ParamFun s t a b = ParamFun { runPF :: a -> t s -> b }

instance Category (ParamFun s t) where
    -- The identity returns its input and ignores the parameters.
    id = ParamFun (\x _ -> x)
    -- Run @inner@ first, then @outer@; both receive the same parameters.
    ParamFun outer . ParamFun inner = ParamFun (\x ps -> outer (inner x ps) ps)

instance Arrow (ParamFun s t) where
    -- A plain function never looks at the parameters.
    arr g = ParamFun (const . g)
    first (ParamFun g) = ParamFun (\(x, y) ps -> (g x ps, y))

instance ArrowChoice (ParamFun s t) where
    left (ParamFun g) = ParamFun (\ex ps -> either (\x -> Left (g x ps)) Right ex)

instance ArrowConvolve (ParamFun s t) where
    -- Apply the function to every element, always with the same parameters.
    convolve (ParamFun g) = ParamFun (\xs ps -> fmap (`g` ps) xs)

-- The remaining instances are obtained from the arrow structure.
instance Functor (ParamFun s t a) where
    fmap = fmapArr

instance Applicative (ParamFun s t a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor (ParamFun s t) where
    dimap = dimapArr
-- | A trainable building block: a parameterised function together with its
-- current weights and an action that produces a fresh set of weights.
--
-- The weight container type @t@ is existentially hidden, so components whose
-- weights have different shapes still share the type @Component f g@. The
-- 'Traversable' and 'Applicative' constraints on @t@ are what this module
-- uses to flatten, rebuild and combine weights; 'NFData' lets the weights be
-- forced.
data Component f g = forall t. (Traversable t, Applicative t, NFData (t Double)) => Component
    { weights :: t Double
      -- ^ The current weight values.
    , compute :: forall s. Analytic s => ParamFun s t (f s) (g s)
      -- ^ The function the component computes, polymorphic in the 'Analytic'
      -- scalar type @s@ so it can both be evaluated on 'Double's and be
      -- differentiated with respect to the weights.
    , initR :: forall m. MonadRandom m => m (t Double)
      -- ^ An action, in any 'MonadRandom' monad, producing a fresh set of
      -- weights (used by 'modelR').
    }
-- | Lens onto the weights of a component, flattened into a list.
--
-- Viewing returns the weights in the traversal order of the hidden weight
-- container. Setting rebuilds the container from the list with 'fromList'.
-- If 'fromList' rejects the list (presumably because it has the wrong number
-- of elements), forcing the new weights raises an 'error' with a descriptive
-- message instead of an anonymous pattern-match failure.
_weights :: Lens' (Component f g) [Double]
_weights = lens
    (\(Component ws _ _) -> toList ws)
    (\(Component _ c i) ws ->
        -- Bound lazily, like the original pattern binding: the error only
        -- fires once the weights are actually demanded.
        let ws' = case fromList ws of
                Just rebuilt -> rebuilt
                Nothing      -> error
                    "Numeric.Neural.Model._weights: fromList could not rebuild the weights from the given list (wrong number of weights?)"
        in  Component ws' c i)
-- | Evaluate a component on an input, using the component's current weights.
activate :: Component f g -> f Double -> g Double
activate comp input = case comp of
    Component ws run _ -> runPF run input ws
-- | A container with no elements; it is the weight container of components
-- without trainable parameters (see 'cArr').
data Empty a = Empty deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance Applicative Empty where
    pure _ = Empty
    -- Both arguments are matched, so both are forced.
    (<*>) Empty Empty = Empty

instance NFData (Empty a) where
    rnf = \Empty -> ()
-- | The pointwise product of two functors: an @s a@ next to a @t a@. Composite
-- components use it to hold the weights of both parts.
data Pair s t a = Pair (s a) (t a) deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance (NFData (s a), NFData (t a)) => NFData (Pair s t a) where
    -- Fully force the left half, then the right half.
    rnf (Pair l r) = rnf l `seq` rnf r

instance (Applicative s, Applicative t) => Applicative (Pair s t) where
    -- Lift the value into both halves (reader-style 'liftA2').
    pure = liftA2 Pair pure pure
    -- Apply each half independently.
    Pair fl fr <*> Pair xl xr = Pair (fl <*> xl) (fr <*> xr)
-- | Components compose like functions. The composite's weights pair the outer
-- (left operand) weights with the inner (right operand) weights, and its
-- initialisation runs the outer initialiser before the inner one.
instance Category Component where
    id = cArr id
    Component wOuter cOuter rOuter . Component wInner cInner rInner = Component
        { weights = Pair wOuter wInner
        , initR   = Pair <$> rOuter <*> rInner
        , compute = ParamFun $ \x (Pair pOuter pInner) ->
            -- The inner component runs first; its output feeds the outer one.
            runPF cOuter (runPF cInner x pInner) pOuter
        }
-- | Lift a differentiable function into a component that has no weights
-- (its weight container is 'Empty').
cArr :: Diff f g -> Component f g
cArr (Diff h) = Component
    { initR   = pure Empty
    , compute = arr h
    , weights = Empty
    }
-- | Run a component on the first half of a 'Pair' and pass the second half
-- through untouched. Weights and initialisation are those of the wrapped
-- component.
cFirst :: Component f g -> Component (Pair f h) (Pair g h)
cFirst (Component w run gen) = Component
    { initR   = gen
    , weights = w
    , compute = ParamFun $ \(Pair lhs rhs) ps ->
        let lhs' = runPF run lhs ps
        in  Pair lhs' rhs
    }
-- | The sum of two functors over the same element type.
data FEither f g a = FLeft (f a) | FRight (g a)
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

-- | Run a component on the 'FLeft' case and pass an 'FRight' value through
-- unchanged. Weights and initialisation are those of the wrapped component.
cLeft :: Component f g -> Component (FEither f h) (FEither g h)
cLeft (Component w run gen) = Component
    { initR   = gen
    , weights = w
    , compute = ParamFun $ \input ps -> case input of
        FRight other -> FRight other
        FLeft xs     -> FLeft (runPF run xs ps)
    }
-- | The composition of two functors: an @f@ of @g@s.
data Convolve f g a = Convolve (f (g a))
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

-- | Apply a component to every inner element of a 'Convolve', sharing the
-- same weights across all positions.
cConvolve :: Functor h => Component f g -> Component (Convolve h f) (Convolve h g)
cConvolve (Component w run gen) = Component
    { initR   = gen
    , weights = w
    , compute = ParamFun $ \(Convolve xss) ps ->
        Convolve (fmap (\xs -> runPF run xs ps) xss)
    }
instance NFData (Component f g) where
    -- Only the weights are forced; the compute and init functions are not.
    rnf comp = case comp of
        Component w _ _ -> rnf w
-- | A component packaged with everything needed to use and train it.
--
-- In @Model f g a b c@ the component maps @f@ to @g@, training samples have
-- type @a@, and the model maps inputs of type @b@ to outputs of type @c@
-- (see 'model').
data Model :: (* -> *) -> (* -> *) -> * -> * -> * -> * where
    Model :: (Functor f, Functor g)
          => Component f g
             -- ^ the trainable component
          -> (a -> (f Double, Diff g Identity))
             -- ^ splits a training sample into a component input and a
             -- differentiable error function on the component's output
          -> (b -> f Double)
             -- ^ converts a model input into a component input
          -> (g Double -> c)
             -- ^ converts a component output into a model output
          -> Model f g a b c
instance Profunctor (Model f g a) where
    -- Pre-process inputs with @pre@ and post-process outputs with @post@;
    -- the component and the error function are left unchanged.
    dimap pre post (Model comp err inp out) =
        Model comp err (\x -> inp (pre x)) (\y -> post (out y))
instance NFData (Model f g a b c) where
    -- Forcing a model forces its component; the error function and the
    -- input/output conversions are not forced.
    rnf m = case m of
        Model comp _ _ _ -> rnf comp
-- | Lens onto the component of a model; all other parts are preserved.
_component :: Lens' (Model f g a b c) (Component f g)
_component = lens getComponent setComponent
  where
    getComponent :: Model f' g' a' b' c' -> Component f' g'
    getComponent (Model comp _ _ _) = comp

    setComponent :: Model f' g' a' b' c' -> Component f' g' -> Model f' g' a' b' c'
    setComponent (Model _ e i o) comp = Model comp e i o
-- | Run a model: convert the input, activate the component with its current
-- weights, and convert the result.
model :: Model f g a b c -> b -> c
model (Model comp _ pre post) = \input -> post (activate comp (pre input))
-- | Replace the component's weights with a fresh set obtained from its
-- 'initR' action. Everything else in the model is kept.
modelR :: MonadRandom m => Model f g a b c -> m (Model f g a b c)
modelR (Model comp e i o) = case comp of
    Component _ run gen -> (\fresh -> Model (Component fresh run gen) e i o) <$> gen
-- | Build the error of a component on one sample as a differentiable function
-- of the component's weights.
--
-- Given the model's error function @e@, a sample @x@ and the component's
-- compute function @f@, the resulting 'Diff' maps a weight container @t s@
-- to the (Identity-wrapped) error for @x@:
--
--   1. @e x@ yields the component input (as 'Double's) and an error
--      function @h@ on the component's output;
--   2. that input is lifted into the scalar type @s@ with 'fromDouble';
--   3. @f@ runs on the lifted input with the given weights;
--   4. @h@ is applied to the component's output.
--
-- Because the sample is fixed, the weights are the only argument of the
-- result, so differentiating it (as 'descent' does) gives the gradient with
-- respect to the weights.
errFun :: forall f t a g. Functor f
       => (a -> (f Double, Diff g Identity))
       -> a
       -> (forall s. Analytic s => ParamFun s t (f s) (g s))
       -> Diff t Identity
errFun e x f = Diff $ runPF f' x where
    -- Parameterised pipeline from a sample to its error; it is only ever
    -- run on the fixed sample @x@ (see @runPF f' x@ above).
    f' :: forall s. Analytic s => ParamFun s t a (Identity s)
    f' = proc z -> do
        let (x', Diff h) = e z
            -- Lift the Double input into the scalar type s.
            x''          = fromDouble <$> x'
        y <- f -< x''
        returnA -< h y
-- | Error of the model on a single sample, evaluated at the component's
-- current weights.
modelError' :: Model f g a b c -> a -> Double
modelError' (Model comp e _ _) x = case comp of
    Component ws run _ -> runIdentity (runDiff (errFun e x run) ws)
-- | Average ('mean') of the per-sample errors over a collection of samples,
-- at the component's current weights.
-- NOTE(review): the result for an empty collection is whatever 'mean'
-- returns for an empty list.
modelError :: Foldable h => Model f g a b c -> h a -> Double
modelError m = mean . fmap (modelError' m) . toList
-- | Perform a single step of batch gradient descent.
--
-- For each sample, the error given by the model's error function is
-- differentiated with respect to the component's weights. The per-sample
-- gradients are scaled by @eta / n@ (@n@ = number of samples) and summed in
-- parallel with 'parMapReduceRange'. The sum is then subtracted from the
-- current weights.
--
-- Returns a pair:
--
--   * the mean of the per-sample error values reported by 'gradWith'', i.e.
--     computed at the weights before the update;
--   * the updated model.
--
-- An empty collection of samples leaves the model unchanged and reports an
-- error of @0@.
descent :: (Foldable h)
        => Model f g a b c -- ^ model to train
        -> Double          -- ^ learning rate @eta@
        -> h a             -- ^ training samples
        -> (Double, Model f g a b c)
descent m0@(Model c e i o) eta xs = case c of
    Component ws f r
        | null samples -> (0, m0)
        | otherwise    ->
            let n     = length samples
                n'    = fromIntegral n
                scale = eta / n'
                -- Error and scaled gradient for the sample at index j. The
                -- error is pre-divided by n, so the reduction yields the mean.
                -- NOTE(review): (!!) costs O(j), making the map O(n^2) overall;
                -- an array-backed index would avoid this.
                q j = do
                    let x          = samples !! j
                        (err', g') = gradWith' (\_ dw -> scale * dw) (errFun e x f) ws
                    return (err' / n', g')
                -- Reduction: add errors and gradients component-wise. The
                -- seed (0, pure 0) is the identity of this operation, as
                -- parMapReduceRange may use it in several partial reductions.
                s (err', g') (err'', g'') = return (err' + err'', (+) <$> g' <*> g'')
                (err, grad) = runPar $ parMapReduceRange (InclusiveRange 0 $ pred n) q s (0, pure 0)
                -- Gradient descent moves the weights against the gradient.
                ws' = (-) <$> ws <*> grad
            in  (err, Model (Component ws' f r) e i o)
  where
    samples = toList xs
-- | A model whose training samples are (input, target) pairs.
type StdModel f g b c = Model f g (b, c) b c

-- | Build a 'StdModel'. A training pair @(x, y)@ is split into the converted
-- input @i x@ and the error function @e y@ derived from the target.
mkStdModel :: (Functor f, Functor g)
           => Component f g
           -> (c -> Diff g Identity)
           -> (b -> f Double)
           -> (g Double -> c)
           -> StdModel f g b c
mkStdModel c e i o = Model c (\(x, y) -> (i x, e y)) i o