module Neet.Network (
modSig
, Network(..)
, Neuron(..)
, mkPhenotype
, stepNeuron
, stepNetwork
, snapshot
, pushThrough
, getOutput
) where
import Data.Set (Set)
import qualified Data.Set as S
import Data.List (sortBy, foldl')
import qualified Data.IntMap as IM
import Data.IntMap (IntMap)
import Neet.Genome
import Data.Function
-- | Steepened logistic sigmoid from the NEAT paper: @1 / (1 + e^(-4.9 d))@.
--
-- The function is increasing, maps 0 to 0.5, and saturates towards 0 for
-- large negative inputs and towards 1 for large positive inputs.
modSig :: Double -> Double
-- The exponent must be negated; without the minus sign the curve is
-- decreasing and every neuron's response would be inverted.
modSig d = 1 / (1 + exp (-4.9 * d))
-- | One node of a 'Network' phenotype.
data Neuron = Neuron
  { activation  :: Double
    -- ^ Current output value of the neuron (0 for a freshly built network).
  , connections :: IntMap Double
    -- ^ Incoming connections: source node id mapped to connection weight.
  , yHeight     :: Rational
    -- ^ The gene's @yHint@; 'pushThrough' evaluates neurons in ascending
    --   order of this value.
  , neurType    :: NodeType
    -- ^ Node type copied from the genome (e.g. 'Input', 'Output').
  }
  deriving (Show)
-- | A network phenotype: every neuron keyed by node id, plus the ids of the
-- input and output neurons.
data Network = Network
  { netInputs  :: [NodeId]
    -- ^ Input node ids. Values passed to 'stepNetwork' / 'pushThrough' are
    --   bound to these positionally, followed by a constant 1.
  , netOutputs :: [NodeId]
    -- ^ Output node ids; outputs are reported in this order.
  , netState   :: IntMap Neuron
    -- ^ All neurons, keyed by node id.
  , netDepth   :: Int
    -- ^ Number of distinct @yHint@ values among the genome's nodes
    --   (as computed by 'mkPhenotype').
  }
  deriving (Show)
-- | Recompute a neuron's activation from a map of node activations.
--
-- The new activation is 'modSig' of the weighted sum of the incoming
-- connections: each weight is multiplied by its source's value in @acts@.
-- Connections, height and type are left unchanged. Every source id must be a
-- key of @acts@, otherwise the @IM.!@ lookup raises an error.
stepNeuron :: IntMap Double -> Neuron -> Neuron
stepNeuron acts neuron = neuron { activation = modSig netInput }
  where
    netInput = IM.foldlWithKey' accumulate 0 (connections neuron)
    accumulate total src weight = total + (acts IM.! src) * weight
-- | Advance every neuron by one synchronous step.
--
-- All new activations are computed from the activations stored before this
-- step. The one exception is the input nodes, whose values are replaced by
-- the supplied inputs followed by a constant 1, bound positionally to
-- 'netInputs'.
stepNetwork :: Network -> [Double] -> Network
stepNetwork net ins = net { netState = IM.map (stepNeuron previous) neurons }
  where
    neurons = netState net
    boundInputs = zip (map getNodeId (netInputs net)) (ins ++ [1])
    -- Current activations with the input values clamped in.
    previous =
      foldl' (\m (k, v) -> IM.insert k v m) (IM.map activation neurons) boundInputs
-- | Step the network @netDepth net - 1@ times, feeding the same inputs on
-- every step (see 'stepNetwork').
--
-- The step count is presumably meant to let a signal travel from the lowest
-- to the highest @yHeight@ level (confirm for recurrent networks). Networks
-- with a depth of 0 or 1 are returned unchanged.
snapshot :: Network -> [Double] -> Network
snapshot net ds = go (netDepth net - 1)
  where
    -- @go n@ is the network after @n@ steps. Stopping at any @n <= 0@
    -- (not only exactly 0) prevents unbounded recursion when the depth is 0.
    go :: Int -> Network
    go n
      | n <= 0    = net
      | otherwise = stepNetwork (go (n - 1)) ds
-- | Evaluate the network in a single pass, ignoring its stored activations,
-- and return the activations of 'netOutputs' in that order.
--
-- Inputs, followed by a constant 1, are bound positionally to 'netInputs'.
-- Every non-input neuron is then evaluated once, in ascending 'yHeight'
-- order (ties keep ascending node-id order, since 'sortBy' is stable). Each
-- neuron reads only activations produced earlier in this pass.
--
-- A connection whose source was not evaluated earlier in the pass causes an
-- @IM.!@ lookup error when that value is demanded. This function therefore
-- presumably targets feed-forward networks (confirm).
pushThrough :: Network -> [Double] -> [Double]
pushThrough net inputs = [ computed IM.! getNodeId o | o <- netOutputs net ]
  where
    seeded :: IntMap Double
    seeded = IM.fromList (zip (map getNodeId (netInputs net)) (inputs ++ [1]))

    -- Filtering before a stable sort yields the same order as sorting first.
    schedule :: [(Int, Neuron)]
    schedule =
      sortBy (compare `on` (yHeight . snd))
        . IM.toList
        . IM.filter ((/= Input) . neurType)
        $ netState net

    computed :: IntMap Double
    computed = foldl' evaluate seeded schedule

    evaluate acc (nId, neuron) =
      IM.insert nId (activation (stepNeuron acc neuron)) acc
-- | Build the 'Network' phenotype of a 'Genome'.
--
-- * Each node gene becomes a 'Neuron' with activation 0 and no connections,
--   keeping the gene's node type and @yHint@ (stored as 'yHeight').
-- * Each enabled connection gene adds an entry (source id -> weight) to the
--   'connections' of its @connOut@ neuron. Disabled genes are skipped.
--   A gene whose target node is missing is silently ignored ('IM.adjust').
--   Genes are applied in ascending key order, so a later gene with the same
--   source and target overwrites an earlier one.
-- * 'netInputs' / 'netOutputs' hold the ids of the 'Input' / 'Output' nodes
--   in ascending order.
-- * 'netDepth' is the number of distinct @yHint@ values.
mkPhenotype :: Genome -> Network
mkPhenotype Genome{..} =
  Network { netInputs  = idsOfType Input
          , netOutputs = idsOfType Output
          , netState   = wired
          , netDepth   = S.size heights
          }
  where
    idsOfType :: NodeType -> [NodeId]
    idsOfType t = map NodeId . IM.keys $ IM.filter ((== t) . nodeType) nodeGenes

    heights :: Set Rational
    heights = S.fromList . map Neet.Genome.yHint $ IM.elems nodeGenes

    -- One unconnected neuron per node gene.
    blank :: IntMap Neuron
    blank = IM.map (\(NodeGene nt yh) -> Neuron 0 IM.empty yh nt) nodeGenes

    wired :: IntMap Neuron
    wired = IM.foldl' wire blank connGenes

    wire acc ConnGene{..}
      | connEnabled =
          IM.adjust (link (getNodeId connIn) connWeight) (getNodeId connOut) acc
      | otherwise = acc

    link src w neuron = neuron { connections = IM.insert src w (connections neuron) }
-- | The stored activations of the output neurons, in 'netOutputs' order.
--
-- This only reads the current state; it does not step the network. Every
-- output id must be a key of 'netState', otherwise the @IM.!@ lookup errors.
getOutput :: Network -> [Double]
getOutput net =
  [ activation (netState net IM.! getNodeId outId) | outId <- netOutputs net ]