module Neet.Genome (
NodeId(..)
, NodeType(..)
, NodeGene(..)
, ConnGene(..)
, InnoId(..)
, ConnSig
, Genome(..)
, fullConn
, mutate
, crossover
, breed
, distance
, renderGenome
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Data.Map.Strict (Map)
import qualified Data.Traversable as T
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Maybe
import Control.Monad.Fresh.Class
import Neet.Parameters
import Data.GraphViz
import Data.GraphViz.Attributes.Complete
-- | Identifier of a node gene; the key type of 'nodeGenes'. It is also the
-- Graphviz node type used by 'graphParams', hence the derived 'PrintDot'.
newtype NodeId = NodeId Int
  deriving (Show, Eq, Ord, PrintDot)
-- | Role of a node. 'fullConn' creates only 'Input' and 'Output' nodes;
-- 'Hidden' nodes are introduced when 'mutateNode' splits a connection.
data NodeType = Input | Hidden | Output
  deriving (Show, Eq)
-- | A node gene.
data NodeGene = NodeGene { nodeType :: NodeType
                           -- ^ role of the node in the network
                         , yHint :: Rational
                           -- ^ layering hint: 'fullConn' gives inputs 0 and
                           -- outputs 1, and a node created by splitting a
                           -- connection gets the mean of its endpoints' hints.
                           -- Used to flag recurrent connections and to group
                           -- nodes into clusters when rendering.
                         }
  deriving (Show)
-- | A connection gene: a weighted, directed link between two nodes.
data ConnGene = ConnGene { connIn :: NodeId
                           -- ^ source node
                         , connOut :: NodeId
                           -- ^ target node
                         , connWeight :: Double
                           -- ^ connection weight
                         , connEnabled :: Bool
                           -- ^ disabled genes stay in the genome (e.g. after
                           -- being split by 'mutateNode') but are skipped by
                           -- 'renderGenome'
                         , connRec :: Bool
                           -- ^ recurrent flag: 'mutateConn' sets it when the
                           -- target's 'yHint' is not greater than the source's
                         }
  deriving (Show)
-- | Innovation number of a connection gene; the key type of 'connGenes'.
-- New values are drawn from the 'MonadFresh' supply (see 'addConn').
newtype InnoId = InnoId Int
  deriving (Show, Eq, Ord)
-- | A genome: its node genes, its connection genes keyed by innovation number,
-- and the id reserved for the next node that gets added.
data Genome =
  Genome { nodeGenes :: Map NodeId NodeGene
           -- ^ all nodes, keyed by id
         , connGenes :: Map InnoId ConnGene
           -- ^ all connections (enabled or not), keyed by innovation number
         , nextNode :: NodeId
           -- ^ id for the next new node; 'fullConn' sets it just past the
           -- input/output ids and 'mutateNode' increments it per split
         }
  deriving (Show)
-- | Build an initial genome in which every input node is connected to every
-- output node.
--
-- One extra input node is added on top of @iSize@ (presumably a bias input —
-- confirm with callers). Node ids are laid out as:
--
--   * inputs:  @1 .. iSize + 1@            ('yHint' 0)
--   * outputs: @iSize + 2 .. iSize + 1 + oSize@  ('yHint' 1)
--
-- Connections are enabled, non-recurrent, and get weights drawn uniformly from
-- @[-weightRange, weightRange]@. They are numbered with innovation ids
-- @1, 2, ..@ in input-major order. NOTE(review): the caller's 'InnoId' supply
-- presumably has to start above these preassigned numbers — verify.
fullConn :: MonadRandom m => Parameters -> Int -> Int -> m Genome
fullConn Parameters{..} iSize oSize = do
  let inCount = iSize + 1
      inIDs = map NodeId [1 .. inCount]
      outIDs = map NodeId [inCount + 1 .. inCount + oSize]
      inputGenes = zip inIDs $ repeat (NodeGene Input 0)
      outputGenes = zip outIDs $ repeat (NodeGene Output 1)
      nodeGenes = M.fromList $ inputGenes ++ outputGenes
      nextNode = NodeId $ inCount + oSize + 1
      nodePairs = (,) <$> inIDs <*> outIDs
  -- The range must be symmetric; (weightRange, weightRange) would give every
  -- connection the same weight.
  weights <- getRandomRs (-weightRange, weightRange)
  let conns = zipWith (\(inN, outN) w -> ConnGene inN outN w True False) nodePairs weights
      connGenes = M.fromList $ zip (map InnoId [1 ..]) conns
  return Genome{..}
-- | Weight mutation.
--
-- With probability 'mutWeightRate' every connection gene is visited. Each one
-- either:
--
--   * gets a brand-new weight drawn uniformly from @[-weightRange, weightRange]@
--     (probability 'newWeightRate'), or
--   * is nudged by an offset drawn uniformly from @[-pertAmount, pertAmount]@.
--
-- Otherwise the genome is returned unchanged. Structure is never altered.
mutateWeights :: MonadRandom m => Parameters -> Genome -> m Genome
mutateWeights Parameters{..} g = do
  roll <- getRandomR (0, 1)
  if roll > mutWeightRate
    then return g
    else do
      conns' <- T.mapM mutOne (connGenes g)
      return g { connGenes = conns' }
  where
    -- Replace or perturb the weight of one connection gene. Both ranges are
    -- symmetric so weights can move (or be reset) in either direction.
    mutOne conn = do
      roll <- getRandomR (0, 1)
      w <- if roll <= newWeightRate
             then getRandomR (-weightRange, weightRange)
             else (connWeight conn +) `liftM` getRandomR (-pertAmount, pertAmount)
      return conn { connWeight = w }
-- | Structural signature of a connection: (source node, target node). Used as
-- the key of the innovation map so that structurally identical connections
-- receive the same 'InnoId' (see 'addConn').
data ConnSig = ConnSig NodeId NodeId
  deriving (Show, Eq, Ord)
-- | The (source, target) signature of a connection gene.
toConnSig :: ConnGene -> ConnSig
toConnSig ConnGene { connIn = src, connOut = dst } = ConnSig src dst
-- | Insert a connection gene into a genome's connection map.
--
-- If the innovation map already knows the gene's 'ConnSig', that innovation
-- number is reused. Otherwise a 'fresh' one is drawn and recorded in the
-- returned innovation map. Any gene already stored under the chosen
-- innovation number is overwritten.
addConn :: MonadFresh InnoId m => ConnGene ->
           (Map ConnSig InnoId, Map InnoId ConnGene) ->
           m (Map ConnSig InnoId, Map InnoId ConnGene)
addConn conn (innos, conns) = do
  (inno, innos') <- case M.lookup sig innos of
    Just known -> return (known, innos)
    Nothing -> do
      new <- fresh
      return (new, M.insert sig new innos)
  return (innos', M.insert inno conn conns)
  where
    sig = toConnSig conn
-- | Structural mutation: possibly add one new connection gene.
--
-- Runs with probability 'addConnRate'. Candidates are all (any node, non-input
-- node) pairs that have no gene yet in this genome. Disabled genes count as
-- present, so they are never re-added here. One candidate is picked uniformly.
-- The new gene:
--
--   * is enabled;
--   * gets a weight drawn uniformly from @[-weightRange, weightRange]@;
--   * is flagged recurrent when the target's 'yHint' is not greater than the
--     source's (this includes self-loops).
--
-- Innovation numbers are shared through the supplied 'ConnSig' map (see
-- 'addConn'). If the roll fails or no pair is free, both the map and the
-- genome are returned unchanged.
mutateConn :: (MonadFresh InnoId m, MonadRandom m) =>
              Parameters -> Map ConnSig InnoId -> Genome -> m (Map ConnSig InnoId, Genome)
mutateConn params innos g = do
  roll <- getRandomR (0, 1)
  -- `null allowed` is only evaluated after a successful roll.
  if roll > addConnRate params || null allowed
    then return (innos, g)
    else do
      (innos', conns') <- addRandConn innos (connGenes g)
      return (innos', g { connGenes = conns' })
  where
    -- Signatures of every connection already in the genome, enabled or not.
    taken :: Map ConnSig Bool
    taken = M.fromList . map (\c -> (toConnSig c, True)) . M.elems . connGenes $ g
    notInput (NodeGene Input _) = False
    notInput _ = True
    nodes = M.toList $ nodeGenes g
    nonInputs = filter (notInput . snd) nodes
    -- The Bool is the recurrent flag: the target does not lie strictly after
    -- the source in the yHint ordering.
    makePair (n1, g1) (n2, g2) = (ConnSig n1 n2, yHint g2 <= yHint g1)
    candidates = M.fromList $ makePair <$> nodes <*> nonInputs
    allowed = M.toList $ M.difference candidates taken
    pickOne :: MonadRandom m => m (ConnSig, Bool)
    pickOne = uniform allowed
    -- Symmetric range; (r, r) would always yield exactly r.
    pickWeight :: MonadRandom m => m Double
    pickWeight = let r = weightRange params in getRandomR (-r, r)
    addRandConn :: (MonadRandom m, MonadFresh InnoId m) =>
                   Map ConnSig InnoId -> Map InnoId ConnGene ->
                   m (Map ConnSig InnoId, Map InnoId ConnGene)
    addRandConn innoMap conns = do
      (ConnSig inNode outNode, recc) <- pickOne
      w <- pickWeight
      let newConn = ConnGene inNode outNode w True recc
      addConn newConn (innoMap, conns)
-- | Structural mutation: possibly split a connection with a new hidden node.
--
-- Runs with probability 'addNodeRate', provided the genome has at least one
-- connection. Without that check, 'uniform' would be called on an empty list
-- and crash.
--
-- A connection @in -> out@ is picked uniformly from all connection genes.
-- NOTE(review): disabled genes are candidates too. The split works as follows:
--
--   * the old gene is disabled;
--   * a 'Hidden' node with id 'nextNode' is inserted, with a 'yHint' halfway
--     between its endpoints;
--   * a gene @in -> new@ with weight 1 is added;
--   * a gene @new -> out@ carrying the old weight is added.
--
-- Both new genes inherit the old gene's recurrent flag. Their innovation
-- numbers go through 'addConn'.
mutateNode :: (MonadRandom m, MonadFresh InnoId m) =>
              Parameters -> Map ConnSig InnoId ->
              Genome -> m (Map ConnSig InnoId, Genome)
mutateNode params innos g = do
  roll <- getRandomR (0, 1)
  if roll <= addNodeRate params && not (M.null conns)
    then addRandNode
    else return (innos, g)
  where
    conns = connGenes g
    nodes = nodeGenes g
    pickConn :: MonadRandom m => m (InnoId, ConnGene)
    pickConn = uniform $ M.toList conns
    newId = nextNode g
    newNextNode = case newId of NodeId x -> NodeId (x + 1)
    addNode :: MonadFresh InnoId m =>
               InnoId -> ConnGene -> m (Map ConnSig InnoId, Genome)
    addNode inno gene = do
      let ConnSig inId outId = toConnSig gene
          inGene = nodes M.! inId
          outGene = nodes M.! outId
          newGene = NodeGene Hidden ((yHint inGene + yHint outGene) / 2)
          newNodes = M.insert newId newGene nodes
          disabledConn = gene { connEnabled = False }
          -- Weight 1 into the new node and the old weight out of it.
          backGene = ConnGene inId newId 1 True (connRec gene)
          forwardGene = ConnGene newId outId (connWeight gene) True (connRec gene)
      (innos', newConns) <- (addConn backGene >=> addConn forwardGene) (innos, conns)
      return (innos', g { nodeGenes = newNodes
                        , connGenes = M.insert inno disabledConn newConns
                        , nextNode = newNextNode
                        })
    addRandNode :: (MonadRandom m, MonadFresh InnoId m) => m (Map ConnSig InnoId, Genome)
    addRandNode = pickConn >>= uncurry addNode
-- | Apply all mutation operators in sequence: weight mutation, then node
-- addition, then connection addition. The innovation map is threaded through
-- the structural steps and returned with the mutated genome.
mutate :: (MonadRandom m, MonadFresh InnoId m) => Parameters -> Map ConnSig InnoId ->
          Genome -> m (Map ConnSig InnoId, Genome)
mutate params innos g = do
  reweighted <- mutateWeights params g
  (innosAfterNode, withNode) <- mutateNode params innos reweighted
  mutateConn params innosAfterNode withNode
-- | Left-biased merge of two maps:
--
--   * keys present in both maps are combined with @comb@;
--   * keys present only in the left map are converted with @mk@;
--   * keys present only in the right map are dropped.
superLeft :: Ord k => (a -> b -> c) -> (a -> c) -> Map k a -> Map k b -> Map k c
superLeft comb mk left right = shared `M.union` leftOnly
  where
    shared = M.intersectionWith comb left right
    -- Disjoint from `shared`, so the union's bias never matters.
    leftOnly = M.map mk (left `M.difference` right)
-- | Return one of the two arguments, each with probability 1/2.
flipCoin :: MonadRandom m => a -> a -> m a
flipCoin heads tails = uniform (heads : [tails])
-- | Cross two connection-gene maps.
--
-- Genes only in the first map are copied unchanged. Genes only in the second
-- map are dropped. For genes present in both, one parent's copy is chosen at
-- random:
--
--   * if both copies are enabled, the chosen copy is used as is;
--   * otherwise the child's gene is disabled with probability
--     'disableChance' and enabled in the remaining cases.
crossConns :: MonadRandom m => Parameters -> Map InnoId ConnGene -> Map InnoId ConnGene ->
              m (Map InnoId ConnGene)
crossConns params primary secondary = T.sequence (superLeft inherit return primary secondary)
  where
    inherit c1 c2
      | connEnabled c1 && connEnabled c2 = flipCoin c1 c2
      | otherwise = do
          chosen <- flipCoin c1 c2
          roll <- getRandomR (0, 1)
          return chosen { connEnabled = not (roll <= disableChance params) }
-- | Cross two node-gene maps. Shared node ids take either parent's gene at
-- random, nodes only in the first map are kept, and nodes only in the second
-- map are dropped.
crossNodes :: MonadRandom m => Map NodeId NodeGene -> Map NodeId NodeGene ->
              m (Map NodeId NodeGene)
crossNodes primary = T.sequence . superLeft flipCoin return primary
-- | Produce a child genome from two parents.
--
-- Genes that appear only in @g2@ are discarded. NOTE(review): this suggests
-- @g1@ is expected to be the fitter parent — confirm with callers. The child
-- reserves the larger of the parents' next-node ids. Nodes are crossed before
-- connections.
crossover :: MonadRandom m => Parameters -> Genome -> Genome -> m Genome
crossover params g1 g2 = do
  childNodes <- crossNodes (nodeGenes g1) (nodeGenes g2)
  childConns <- crossConns params (connGenes g1) (connGenes g2)
  return Genome { nodeGenes = childNodes
                , connGenes = childConns
                , nextNode = max (nextNode g1) (nextNode g2)
                }
-- | Crossover followed by mutation: cross @g1@ with @g2@, then mutate the
-- child, threading the innovation map through the mutation step.
breed :: (MonadRandom m, MonadFresh InnoId m) =>
         Parameters -> Map ConnSig InnoId -> Genome -> Genome ->
         m (Map ConnSig InnoId, Genome)
breed params innos g1 g2 = do
  child <- crossover params g1 g2
  mutate params innos child
-- | Absolute weight difference for every innovation number present in both
-- maps. Non-matching genes are left out of the result.
differences :: Map InnoId ConnGene -> Map InnoId ConnGene -> Map InnoId Double
differences = M.intersectionWith weightGap
  where
    -- Subtraction was missing here (`connWeight c1 connWeight c2`), which
    -- does not type-check.
    weightGap c1 c2 = abs (connWeight c1 - connWeight c2)
-- | NEAT compatibility distance between two genomes:
--
-- > c1 * E + c2 * D + c3 * W
--
-- where:
--
--   * @E@ counts excess genes: unmatched innovation numbers above the smaller
--     of the two genomes' highest innovation numbers;
--   * @D@ counts disjoint genes: unmatched genes at or below that cutoff;
--   * @W@ is the mean absolute weight difference over matching genes, or 0
--     when there are none.
--
-- The coefficients come from 'distParams'. No normalisation by genome size is
-- applied.
distance :: Parameters -> Genome -> Genome -> Double
distance params g1 g2 = c1 * exFactor + c2 * disFactor + c3 * weightFactor
  where
    DistParams c1 c2 c3 _ = distParams params
    conns1 = connGenes g1
    conns2 = connGenes g2
    weightDiffs = differences conns1 conns2
    -- With no matching genes the mean would be 0/0 = NaN, which would poison
    -- every comparison against this distance.
    weightFactor
      | M.null weightDiffs = 0
      | otherwise = M.foldl' (+) 0 weightDiffs / fromIntegral (M.size weightDiffs)
    ids1 = M.keysSet conns1
    ids2 = M.keysSet conns2
    unmatched = (ids1 `S.difference` ids2) `S.union` (ids2 `S.difference` ids1)
    -- S.partition returns (satisfying, not satisfying). Genes inside the
    -- shared range (<= edge) are disjoint, and genes beyond it are excess.
    -- If either genome has no connections, its range is empty, so every
    -- unmatched gene is excess. This also avoids findMax on an empty set.
    (disjoint, excess)
      | S.null ids1 || S.null ids2 = (S.empty, unmatched)
      | otherwise = S.partition (<= edge) unmatched
    edge = min (S.findMax ids1) (S.findMax ids2)
    exFactor = fromIntegral $ S.size excess
    disFactor = fromIntegral $ S.size disjoint
-- | Graphviz settings used by 'renderGenome'.
--
-- The graph is directed, laid out left-to-right with straight ('LineEdges')
-- splines and fixed-size nodes, and each edge is labelled with its weight.
-- Every node is placed in a cluster keyed by its 'yHint':
--
--   * hint 0: named "Input Layer", pinned to the minimum rank, blue nodes;
--   * hint 1: named "Output Layer", pinned to the maximum rank, red nodes;
--   * any other hint: its own cluster, identified by the hint as a number,
--     with sea-green nodes.
--
-- All clusters carry a white @color@ attribute, and nodes are solid circles.
graphParams :: GraphvizParams NodeId NodeGene Double Rational Rational
graphParams =
  Params { isDirected       = True
         , globalAttributes = [ GraphAttrs [RankDir FromLeft, Splines LineEdges]
                              , NodeAttrs [FixedSize SetNodeSize]
                              ]
         , clusterBy        = \(nId, ng) -> C (yHint ng) (N (nId, yHint ng))
         , isDotCluster     = const True
         , clusterID        = layerName
         , fmtCluster       = layerStyle
         , fmtNode          = const []
         , fmtEdge          = \(_, _, w) -> [toLabel w]
         }
  where
    -- Cluster identifier for a yHint value.
    layerName hint = case hint of
      0 -> Str "Input Layer"
      1 -> Str "Output Layer"
      _ -> Num (Dbl (fromRational hint))
    -- Cluster formatting for a yHint value.
    layerStyle hint = case hint of
      0 -> layerAttrs Blue4 [rank MinRank]
      1 -> layerAttrs Red2 [rank MaxRank]
      _ -> layerAttrs SeaGreen []
    -- White cluster colour (plus any rank constraint), and solid circular
    -- nodes drawn in the given X11 colour.
    layerAttrs nodeColour rankAttrs =
      [ GraphAttrs (x11Attr White : rankAttrs)
      , NodeAttrs [Style [SItem Solid []], x11Attr nodeColour, Shape Circle]
      ]
    x11Attr c = Color [WC (X11Color c) Nothing]
-- | Draw a genome on the Graphviz Xlib canvas (an X11 window) using the @dot@
-- layout engine and 'graphParams'. Only enabled connections are drawn;
-- disabled genes are omitted.
renderGenome :: Genome -> IO ()
renderGenome g = runGraphvizCanvas Dot graph Xlib
  where
    nodes = M.toList (nodeGenes g)
    edges = mapMaybe toEdge (M.elems (connGenes g))
    -- (source, target, weight) for enabled genes; Nothing for disabled ones.
    toEdge c
      | connEnabled c = Just (connIn c, connOut c, connWeight c)
      | otherwise = Nothing
    graph = graphElemsToDot graphParams nodes edges