{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
module StableMarriage.GaleShapley
( Men(..)
, Women(..)
, World
, meets
, PO.Ordering(..)
) where
import Prelude hiding (Ordering(..), compare)
import Control.Arrow ((&&&))
import Data.List (sortOn, groupBy, splitAt)
import Data.Poset as PO (Ordering(..), sortBy')
import Data.Function (on)
-- | The proposing side of the matching.
class Men m where
  -- | The type of partner this kind of man proposes to.
  type W m :: *

  -- | The women this man can still propose to. The proposal round
  -- ('attack') always targets the head of this list, and a man whose list
  -- is empty makes no further proposals.
  loves :: m -> [W m]

  -- | Applied to every man a woman sends away (see 'counter').
  -- NOTE(review): presumably this drops the head of 'loves' so the man
  -- moves on to his next choice -- confirm against the instances.
  forget :: m -> m
-- | The choosing side of the matching. The woman type @w@ is fixed by the
-- man type through the @w ~ W m@ superclass equality, and must be orderable
-- because women are used as grouping keys.
class (Ord w, Men m, w ~ W m) => Women m w where
  -- | Whether the woman is willing to be matched with this man at all.
  -- Men failing this test are always sent away by 'counter'.
  acceptable :: w -> m -> Bool

  -- | Rank two men from this woman's point of view, expressed as a
  -- 'PO.Ordering' from "Data.Poset".
  compare :: w -> m -> m -> PO.Ordering

  -- | How many of the given men the woman keeps in one selection round.
  -- The default keeps a single partner.
  limit :: w -> [m] -> Int
  limit _ _ = 1
-- | The state threaded through the matching: every woman paired with the men
-- currently attached to her, together with the men attached to nobody.
-- The class constraints live inside the synonym itself, which is why the
-- module enables @RankNTypes@.
type World w m = (Men m, Women m w, w ~ W m) => ([(w, [m])], [m])
-- | Alternate proposal rounds ('attack') and selection rounds ('counter')
-- until the resulting world is 'stable'. At least one full round is always
-- performed before stability is checked.
marriage :: World w m -> World w m
marriage world
  | stable world' = world'
  | otherwise     = marriage world'
  where
    -- The world after one proposal round followed by one selection round.
    world' = counter (attack world)
-- | A world is stable once every unattached man has nobody left to propose
-- to. Only the list of unattached men is inspected; the women's side of the
-- world plays no part in the decision.
stable :: (Men m, Women m w, w ~ W m) => World w m -> Bool
stable (_, ms) = all resigned ms
  where
    -- A man is resigned when his list of remaining choices is empty.
    resigned :: Men m => m -> Bool
    resigned = null . loves
-- | One proposal round.
--
-- Every unattached man with a non-empty 'loves' list proposes to the first
-- woman on it and is appended to the men already attached to her. Men with
-- nobody left to propose to remain in the unattached list. The resulting
-- association list is sorted by woman, with one entry per woman.
attack :: World w m -> World w m
attack (cs, ms) = (cs', ms')
  where
    cs' = join cs (propose ms)
    ms' = despair ms

    -- Sort an association list by key and merge the values of each run of
    -- equal keys with the supplied function. The sort is stable, so values
    -- keep their input order within a key.
    regroup :: Ord k => ([v] -> r) -> [(k, v)] -> [(k, r)]
    regroup merge = concatMap collapse . groupBy ((==) `on` fst) . sortOn fst
      where
        -- 'groupBy' never yields an empty run; the first equation only
        -- keeps the match total.
        collapse []                  = []
        collapse run@((key, _) : _) = [(key, merge (map snd run))]

    -- Collect this round's proposals, grouped by the woman proposed to.
    propose :: (Ord w, Men m, Women m w, w ~ W m) => [m] -> [(w, [m])]
    propose = regroup id . concatMap next
      where
        -- A man proposes to his first remaining choice, if he has one.
        next :: (Men m, Women m w, w ~ W m) => m -> [(w, m)]
        next m = case loves m of
          []      -> []
          (w : _) -> [(w, m)]

    -- Merge the new proposals into the existing attachments; for each woman
    -- the previously attached men come before the new suitors.
    join :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> [(w, [m])] -> [(w, [m])]
    join old new = regroup concat (old ++ new)

    -- Keep only the men who have nobody left to propose to.
    despair :: Men m => [m] -> [m]
    despair = filter (null . loves)
-- | One selection round.
--
-- Each woman ranks the men attached to her with 'sortBy'' (driven by her
-- 'acceptable' and 'compare') and keeps the first @limit w suitors@ of that
-- ranking. The rest of the ranking, plus every man failing 'acceptable',
-- is passed through 'forget' and appended to the unattached men.
--
-- NOTE(review): 'sortBy'' comes from "Data.Poset"; presumably it returns
-- only the acceptable men in preference order -- confirm.
counter :: World w m -> World w m
counter (cs, ms) = (kept, ms ++ map forget dismissed)
  where
    (kept, dismissed) = choice cs

    -- Let every woman judge her suitors, then split the verdicts into the
    -- surviving attachments and one flat list of dismissed men.
    choice :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> World w m
    choice = (map fst &&& concatMap snd) . map judge

    -- One woman's verdict: her retained men and the men she sends away.
    judge :: (Men m, Women m w, w ~ W m) => (w, [m]) -> ((w, [m]), [m])
    judge (w, suitors) = ((w, chosen), surplus ++ unfit)
      where
        ok                = acceptable w
        ranked            = sortBy' (ok, compare w) suitors
        (chosen, surplus) = splitAt (limit w suitors) ranked
        unfit             = filter (not . ok) suitors
-- | Run the matching for the given men and women. Every woman starts with
-- no men attached and every man starts unattached; the result is the world
-- reached when 'marriage' stops.
meets :: (Men m, Women m w, w ~ W m) => [m] -> [w] -> World w m
meets ms ws = marriage ([(w, []) | w <- ws], ms)