module Haskell_ML.FCN
( FCNet(), TrainEvo(..)
, randNet, runNet, netTest, hiddenStruct
, getWeights, getBiases
, trainNTimes
) where
import Control.Monad.Random
import Data.Binary
import Data.List
import Data.Singletons.Prelude
import Data.Singletons.TypeLits
import Data.Vector.Storable (toList)
import GHC.Generics (Generic)
import Numeric.LinearAlgebra.Static
import Haskell_ML.Util
-- | A fully connected network with @i@ inputs and @o@ outputs.
--
-- The hidden-layer structure @hs@ of the wrapped 'Network' does not appear
-- in the resulting type, so it is hidden from users of 'FCNet'.
data FCNet :: Nat -> Nat -> * where
  FCNet :: Network i hs o -> FCNet i o
-- | Build an 'FCNet' whose layers are all randomly initialised.
--
-- The list of hidden-layer widths is promoted to a type-level list with
-- 'withSomeSing' and handed to 'randNetwork''.
randNet
  :: (KnownNat i, KnownNat o, MonadRandom m)
  => [Integer]       -- ^ widths of the hidden layers
  -> m (FCNet i o)
randNet hs = withSomeSing hs (\shs -> FCNet <$> randNetwork' shs)
-- | Bookkeeping returned alongside a trained network.
data TrainEvo = TrainEvo
  { accs  :: [Double]
    -- ^ accuracy values collected while training
  , diffs :: [([[Double]], [[Double]])]
    -- ^ pairs of (weight changes, bias changes) collected while training
  }
-- | Train a network repeatedly over the same training pairs.
--
-- Delegates to 'trainNTimes'' with empty accumulators.
trainNTimes
  :: (KnownNat i, KnownNat o)
  => Int                      -- ^ number of training rounds
  -> Double                   -- ^ learning rate
  -> FCNet i o                -- ^ network to train
  -> [(R i, R o)]             -- ^ (input, target) training pairs
  -> (FCNet i o, TrainEvo)
trainNTimes n rate net prs = trainNTimes' [] [] n rate net prs
-- | Worker for 'trainNTimes'.
--
-- Each round trains the network once over all pairs with 'trainNet', then
-- appends to the accumulators:
--
-- * the accuracy of the newly trained network on the training pairs, and
-- * the element-wise change (new minus old) of every weight and bias.
--
-- A round count of zero or less returns the network unchanged together with
-- whatever has been accumulated so far.
trainNTimes'
  :: (KnownNat i, KnownNat o)
  => [Double]                      -- ^ accuracies accumulated so far
  -> [([[Double]], [[Double]])]    -- ^ weight/bias changes accumulated so far
  -> Int                           -- ^ remaining training rounds
  -> Double                        -- ^ learning rate
  -> FCNet i o                     -- ^ current network
  -> [(R i, R o)]                  -- ^ (input, target) training pairs
  -> (FCNet i o, TrainEvo)
trainNTimes' accsSoFar diffsSoFar n rate net prs
  -- `n <= 0` rather than a literal `0` so a negative count cannot recurse
  -- forever.
  | n <= 0    = (net, TrainEvo accsSoFar diffsSoFar)
  | otherwise =
      trainNTimes' (accsSoFar ++ [acc]) (diffsSoFar ++ [diff]) (n - 1) rate net' prs
  where
    net' = trainNet rate net prs
    res  = runNet net' (map fst prs)
    ref  = map snd prs
    acc  = classificationAccuracy res ref
    -- Per-layer parameter deltas: trained network minus previous network.
    diff = ( zipWith (zipWith (-)) (getWeights net') (getWeights net)
           , zipWith (zipWith (-)) (getBiases net') (getBiases net)
           )
-- | Evaluate the network on every input vector, preserving order.
runNet
  :: (KnownNat i, KnownNat o)
  => FCNet i o   -- ^ network to evaluate
  -> [R i]       -- ^ input vectors
  -> [R o]
runNet (FCNet net) inputs = [ runNetwork net v | v <- inputs ]
-- | Serialisation is delegated to 'putFCNet' and 'getFCNet'.
instance (KnownNat i, KnownNat o) => Binary (FCNet i o) where
  get = getFCNet
  put = putFCNet
-- | Self-contained smoke test: train a @2 -> [16, 8] -> 1@ network to
-- recognise two discs and render its response as ASCII art.
--
-- @n@ random 2-D points are generated; a point's target is 1 when it lies
-- within distance 0.33 of @(0.33, 0.33)@ or of @(-0.33, -0.33)@, otherwise 0.
-- After a single 'sgd' pass the trained network is sampled on a 51 x 21 grid
-- covering @[-1, 1] x [-1, 1]@ and each response is mapped to a character.
netTest
  :: MonadRandom m
  => Double     -- ^ learning rate
  -> Int        -- ^ number of random training points
  -> m String
netTest rate n = do
  inps <- replicateM n $ do
    s <- getRandom
    -- NOTE(review): presumably 'Uniform' samples [0, 1), which this rescales
    -- to [-1, 1) -- confirm against hmatrix.
    return $ randomVector s Uniform * 2 - 1
  let outs = flip map inps $ \v ->
        if v `inCircle` (fromRational 0.33, 0.33)
             || v `inCircle` (fromRational (-0.33), 0.33)
          then fromRational 1
          else fromRational 0
  net0 :: Network 2 '[16, 8] 1 <- randNetwork
  let trained = sgd rate (zip inps outs) net0
      -- x in [0..50] and y in [0..20] are both rescaled to [-1, 1].
      outMat = [ [ render (norm_2 (runNetwork trained (vector [x / 25 - 1, y / 10 - 1])))
                 | x <- [0..50] ]
               | y <- [0..20] ]
      render r | r <= 0.2  = ' '
               | r <= 0.4  = '.'
               | r <= 0.6  = '-'
               | r <= 0.8  = '='
               | otherwise = '#'
  return $ unlines outMat
  where
    -- True when @v@ lies within Euclidean distance @r@ of the centre @o@.
    inCircle :: KnownNat n => R n -> (R n, Double) -> Bool
    v `inCircle` (o, r) = norm_2 (v - o) <= r
-- | Widths of the hidden layers of the wrapped network.
hiddenStruct :: FCNet i o -> [Integer]
hiddenStruct fc = case fc of
  FCNet net -> hiddenStruct' net
-- | Walk a 'Network' and collect the width of every hidden layer.
--
-- A network made of a single output layer has no hidden layers.
hiddenStruct' :: Network i hs o -> [Integer]
hiddenStruct' (W _) = []
hiddenStruct' (_ :&~ (rest :: Network h hs' o)) =
  natVal (Proxy @h) : hiddenStruct' rest
-- | Weights of every layer, one flat list per layer.
getWeights :: (KnownNat i, KnownNat o) => FCNet i o -> [[Double]]
getWeights fc = case fc of
  FCNet net -> getWeights' net
-- | Flatten each layer's weight matrix row by row, yielding one list per
-- layer in network order.
getWeights' :: (KnownNat i, KnownNat o) => Network i hs o -> [[Double]]
getWeights' (W (Layer _ ws)) =
  [ [ x | row <- toRows ws, x <- toList (extract row) ] ]
getWeights' (Layer _ ws :&~ rest) =
  [ x | row <- toRows ws, x <- toList (extract row) ] : getWeights' rest
-- | Biases of every layer, one list per layer.
getBiases :: (KnownNat i, KnownNat o) => FCNet i o -> [[Double]]
getBiases fc = case fc of
  FCNet net -> getBiases' net
-- | Extract each layer's bias vector as a list, in network order.
getBiases' :: (KnownNat i, KnownNat o) => Network i hs o -> [[Double]]
getBiases' (W (Layer bs _))         = [toList (extract bs)]
getBiases' (Layer bs _ :&~ rest)    = toList (extract bs) : getBiases' rest
-- | One fully connected layer mapping @i@ inputs to @o@ outputs.
data Layer i o = Layer
  { biases :: !(R o)
    -- ^ one bias per output
  , nodes  :: !(L o i)
    -- ^ weight matrix with @o@ rows and @i@ columns
  }
  deriving (Show, Generic)

-- | Generic-derived serialisation.
instance (KnownNat i, KnownNat o) => Binary (Layer i o)
-- | Generate a layer with random biases and weights.
--
-- Two seeds are drawn from the ambient random monad: the first seeds the
-- weight matrix ('gaussianSample'), the second the bias vector
-- ('randomVector' with 'Gaussian').
randLayer
  :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
  => m (Layer i o)
randLayer = do
  weightSeed :: Int <- getRandom
  biasSeed   :: Int <- getRandom
  let identity = eye
      bs       = randomVector biasSeed Gaussian
      -- Mean is the diagonal of the identity; covariance is the identity.
      ws       = gaussianSample weightSeed (takeDiag identity) (sym identity)
  pure (Layer bs ws)
-- | A chain of layers from @i@ inputs to @o@ outputs whose hidden-layer
-- widths are tracked in the type-level list @hs@.
data Network :: Nat -> [Nat] -> Nat -> * where
  -- | The final (output) layer; no hidden layers remain.
  W     :: !(Layer i o)
        -> Network i '[] o
  -- | Prepend a layer producing @h@ values to a network consuming @h@ values.
  (:&~) :: KnownNat h
        => !(Layer i h)
        -> !(Network h hs o)
        -> Network i (h ': hs) o

infixr 5 :&~
-- | Generate a random network whose hidden structure is fixed by its type.
randNetwork
  :: forall m i hs o. (MonadRandom m, KnownNat i, SingI hs, KnownNat o)
  => m (Network i hs o)
randNetwork = randNetwork' (sing :: Sing hs)
-- | Generate a random network by recursing over the singleton describing
-- its hidden-layer widths, creating one random layer per step.
randNetwork'
  :: forall m i hs o. (MonadRandom m, KnownNat i, KnownNat o)
  => Sing hs
  -> m (Network i hs o)
randNetwork' SNil                = W <$> randLayer
randNetwork' (SNat `SCons` rest) = (:&~) <$> randLayer <*> randNetwork' rest
-- | Serialise every layer of a network, first layer first.
putNet
  :: (KnownNat i, KnownNat o)
  => Network i hs o
  -> Put
putNet (W w)        = put w
putNet (w :&~ rest) = do
  put w
  putNet rest
-- | Deserialise a network whose hidden structure is described by the given
-- singleton, reading layers in the order 'putNet' writes them.
getNet
  :: forall i hs o. (KnownNat i, KnownNat o)
  => Sing hs
  -> Get (Network i hs o)
getNet SNil                = W <$> get
getNet (SNat `SCons` rest) = do
  w    <- get
  tl   <- getNet rest
  return (w :&~ tl)
-- | Serialisation via 'putNet' / 'getNet'; the hidden structure needed for
-- decoding comes from the 'SingI' instance of @hs@.
instance (KnownNat i, SingI hs, KnownNat o) => Binary (Network i hs o) where
  get = getNet sing
  put = putNet
-- | Serialise an 'FCNet': the hidden-layer widths are written first,
-- followed by the layers themselves.
putFCNet
  :: (KnownNat i, KnownNat o)
  => FCNet i o
  -> Put
putFCNet (FCNet net) = put (hiddenStruct' net) >> putNet net
-- | Deserialise an 'FCNet': read the hidden-layer widths, promote them to a
-- singleton, then read the layers with 'getNet'.
getFCNet
  :: (KnownNat i, KnownNat o)
  => Get (FCNet i o)
getFCNet = get >>= \hs -> withSomeSing hs (\shs -> FCNet <$> getNet shs)
-- | Affine part of a layer: weights applied to the input, plus the biases.
-- No activation function is applied here.
runLayer
  :: (KnownNat i, KnownNat o)
  => Layer i o
  -> R i
  -> R o
runLayer (Layer bs ws) input = bs + (ws #> input)
-- | Feed an input vector forward through every layer, applying 'logistic'
-- after each layer (including the last).
runNetwork
  :: (KnownNat i, KnownNat o)
  => Network i hs o
  -> R i
  -> R o
runNetwork = \case
  W w        -> \(!v) -> activate w v
  w :&~ rest -> \(!v) -> runNetwork rest (activate w v)
  where
    -- A layer followed by the logistic activation.
    activate :: (KnownNat a, KnownNat b) => Layer a b -> R a -> R b
    activate w = logistic . runLayer w
-- | Run one 'sgd' pass over the training pairs and re-wrap the result.
trainNet
  :: (KnownNat i, KnownNat o)
  => Double          -- ^ learning rate
  -> FCNet i o       -- ^ network to train
  -> [(R i, R o)]    -- ^ (input, target) training pairs
  -> FCNet i o
trainNet rate fc trn_prs = case fc of
  FCNet net -> FCNet (sgd rate trn_prs net)
-- | Stochastic gradient descent: apply 'sgdStep' once per training pair, in
-- list order, threading the updated network through a strict left fold.
sgd
  :: forall i hs o. (KnownNat i, KnownNat o)
  => Double            -- ^ learning rate
  -> [(R i, R o)]      -- ^ (input, target) training pairs
  -> Network i hs o    -- ^ starting network
  -> Network i hs o
sgd rate trn_prs net = foldl' step net trn_prs
  where
    step = sgdStep rate
-- | One gradient-descent update of the whole network for a single
-- (input, target) pair, using back-propagation.
--
-- The local worker @go@ runs forward through the layers and, on the way
-- back, returns both the updated sub-network and the gradient with respect
-- to that sub-network's input, which the preceding layer needs.
sgdStep
  :: forall i hs o. (KnownNat i, KnownNat o)
  => Double            -- ^ learning rate
  -> Network i hs o    -- ^ network before the update
  -> (R i, R o)        -- ^ (input, target) training pair
  -> Network i hs o
sgdStep rate net trn_pr = fst $ go x0 net
  where
    x0     = fst trn_pr
    target = snd trn_pr

    go :: forall j js. KnownNat j
       => R j
       -> Network j js o
       -> (Network j js o, R j)
    -- Output layer: the delta comes from the difference to the target.
    go !x (W w@(Layer wB wN))
      = let y    = runLayer w x
            o    = logistic y
            dEdy = logistic' y * (o - target)
            -- Step against the gradient, scaled by the learning rate.
            wB'  = wB - konst rate * dEdy
            wN'  = wN - konst rate * (dEdy `outer` x)
            w'   = Layer wB' wN'
            -- Gradient w.r.t. this layer's input, using the old weights.
            dWs  = tr wN #> dEdy
        in  (W w', dWs)
    -- Hidden layer: the delta comes from the gradient returned by the rest
    -- of the network.
    go !x (w@(Layer wB wN) :&~ n)
      = let y          = runLayer w x
            o          = logistic y
            (n', dWs') = go o n
            dEdy       = logistic' y * dWs'
            wB'        = wB - konst rate * dEdy
            wN'        = wN - konst rate * (dEdy `outer` x)
            w'         = Layer wB' wN'
            dWs        = tr wN #> dEdy
        in  (w' :&~ n', dWs)
-- | The logistic (sigmoid) function, @1 / (1 + exp (-x))@.
--
-- It increases monotonically from 0 to 1 and equals 0.5 at 0.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))
-- | Derivative of 'logistic', expressed through the function itself:
-- @logistic x * (1 - logistic x)@.
logistic' :: Floating a => a -> a
logistic' x = logix * (1 - logix)
  where
    logix = logistic x