{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module StableMarriage.GaleShapley
       ( Men(..)
       , Women(..)
       , World
       , meets

       -- re-export
       , PO.Ordering(..)
       ) where

import Prelude hiding (Ordering(..), compare)
import Control.Arrow ((&&&))
import Data.List (sortOn, groupBy, splitAt)
import Data.Poset as PO (Ordering(..), sortBy')
import Data.Function (on)

-- | The proposing side of the matching.
class Men m where
  -- | The type of partner a man of type @m@ proposes to.
  type W m :: *
  -- | The candidates this man can still propose to.  The matching proposes
  -- to the head of this list, and treats an empty list as "nothing left to
  -- propose".
  loves :: m -> [W m]
  -- | Applied to every man who has just been turned down.
  -- NOTE(review): presumably drops the head of 'loves' so the same proposal
  -- is not repeated; confirm against the instances.
  forget :: m -> m

-- | The accepting side of the matching.  @w@ is tied to @m@ through
-- @w ~ 'W' m@, and must be ordered so proposals can be grouped per woman.
class (Ord w, Men m, w ~ W m) => Women m w where
  -- | Whether this woman is willing to be matched with the given man at all.
  -- Men failing this test are always turned down.
  acceptable :: w -> m -> Bool
  -- | This woman's preference between two men, as a 'PO.Ordering'.
  compare :: w -> m -> m -> PO.Ordering
  -- | How many of the given candidates this woman may keep.
  -- Defaults to @1@, i.e. a one-to-one matching.
  limit :: w -> [m] -> Int
  limit _ _ = 1

-- | State of the matching.  The first component pairs each woman with the
-- men currently held for her; the second lists the men not held by anyone.
-- The class constraints are part of the synonym itself.
type World w m = (Men m, Women m w, w ~ W m) => ([(w, [m])], [m])

-- | Run proposal rounds until the world is 'stable'.
--
-- One round is 'attack' (every free man with a candidate left proposes)
-- followed by 'counter' (every woman keeps her preferred candidates and
-- turns the rest down).  The result of the first round that is stable is
-- returned.
--
-- NOTE(review): termination relies on 'forget' eventually emptying each
-- rejected man's 'loves'; confirm against the instances.
marriage :: World w m -> World w m
marriage x =
    let x' = counter $ attack x
    in if stable x'
         then x'
         else marriage x'

-- | A world is stable once every man who is not held by a woman has an
-- empty 'loves' list, i.e. nobody is left who could still propose.
-- The held pairs play no part in the test.
stable :: (Men m, Women m w, w ~ W m) => World w m -> Bool
stable (_, ms) = all resigned ms
    where
      -- A man is resigned when he has no candidates left.
      resigned :: Men m => m -> Bool
      resigned = null . loves

-- | Proposal half of a round.
--
-- Every free man with a candidate left is filed under the head of his
-- 'loves' list ('propose') and merged into the held lists ('join').
-- Only the men with nothing left to propose ('despair') remain free.
attack :: World w m -> World w m
attack (cs, ms) = (join cs (propose ms), despair ms)

-- | Collect the next proposal of every man, grouped by the woman proposed to.
--
-- Each man proposes to the head of his 'loves' list; men with an empty list
-- make no proposal and do not appear in the result.  The result holds one
-- entry per woman who received a proposal, ordered by woman, and within an
-- entry the men keep their input order ('sortOn' is stable).
propose :: (Ord w, Men m, Women m w, w ~ W m) => [m] -> [(w, [m])]
propose = gather . competes
    where
      -- All (woman, suitor) proposals, bucketed so that each bucket is the
      -- competition for a single woman.
      competes :: (Men m, Women m w, w ~ W m, Ord w) => [m] -> [[(w, m)]]
      competes = groupBy ((==) `on` fst) . sortOn fst . concatMap next

      -- The single proposal a man makes this round, if he has any left.
      -- Pattern matching replaces the former @null@/@head@ pair.
      next :: (Men m, Women m w, w ~ W m) => m -> [(w, m)]
      next m = case loves m of
          []      -> []
          (w : _) -> [(w, m)]

      -- Turn each bucket into (woman, her suitors).  'groupBy' never yields
      -- an empty bucket, but matching in the generator keeps this total.
      gather :: (Men m, Women m w, w ~ W m) => [[(w, m)]] -> [(w, [m])]
      gather buckets = [ (w, map snd bucket) | bucket@((w, _) : _) <- buckets ]

-- | Merge two association lists of (woman, men), combining the entries that
-- belong to the same woman.
--
-- The result has one entry per woman occurring in either list, ordered by
-- woman.  Because 'sortOn' is stable, the men from the first list come
-- before the men from the second list within each entry.
join :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> [(w, [m])] -> [(w, [m])]
join cs xs = merge . groupBy ((==) `on` fst) . sortOn fst $ cs ++ xs
    where
      -- Collapse each bucket of same-key entries into a single entry.
      -- 'groupBy' never yields an empty bucket, but matching in the
      -- generator keeps this total.
      merge :: [[(a, [b])]] -> [(a, [b])]
      merge buckets = [ (w, concatMap snd bucket) | bucket@((w, _) : _) <- buckets ]

-- | Keep only the men whose 'loves' list is empty.
--
-- Used by 'attack' for the men who stay free after a proposal round: anyone
-- with a candidate left has been handed to that candidate by 'propose'.
despair :: Men m => [m] -> [m]
despair = filter (null . loves)

-- | Response half of a round.
--
-- Every woman re-selects among the men held for her ('choice').  Each man
-- she turns down has 'forget' applied and is appended after the men who
-- were already free.
counter :: World w m -> World w m
counter (cs, ms) =
    let (kept, rejected) = choice cs
    in (kept, ms ++ map forget rejected)


-- | Let every woman pick among the men held for her.
--
-- Returns the women paired with the men they keep, together with all men
-- turned down by anyone, concatenated in the order of the input entries.
--
-- Written with an explicit argument so that the 'World' result, whose
-- synonym carries a class context, is produced directly instead of through
-- function composition.
choice :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> World w m
choice held = (map fst &&& concatMap snd) (map judge held)
    where
      -- One woman's decision: @((woman, kept), turned down)@.
      -- The candidates are ordered by 'sortBy'' using her acceptability test
      -- and preference; the first @limit@ of them are kept.  Turned down are
      -- the overflow beyond the limit, followed by the unacceptable men.
      -- NOTE(review): this assumes 'sortBy'' drops the men failing the
      -- predicate, otherwise they would be counted twice; confirm against
      -- "Data.Poset".
      judge :: (Men m, Women m w, w ~ W m) => (w, [m]) -> ((w, [m]), [m])
      judge (w, suitors) =
          let ok               = acceptable w
              ranked           = sortBy' (ok, compare w) suitors
              (kept, overflow) = splitAt (limit w suitors) ranked
              unacceptable     = filter (not . ok) suitors
          in ((w, kept), overflow ++ unacceptable)

-- | Entry point: match the given men and women.
--
-- Starts from a world where every woman holds nobody and every man is free,
-- then runs 'marriage' until the matching is stable.
meets :: (Men m, Women m w, w ~ W m) => [m] -> [w] -> World w m
meets ms ws = marriage ([ (w, []) | w <- ws ], ms)