module Data.Derive.Fold(makeFold) where
import Language.Haskell
import Data.Derive.Internal.Derivation
import Data.List
import Data.Generics.Uniplate.DataOnly
-- | Derivation for the @Fold@ class of generated code: it hands the data
-- declaration to 'mkFold' and simplifies the resulting declarations.  The
-- first component of the input pair is ignored.
makeFold :: Derivation
makeFold = derivationCustom "Fold" generate
  where
    generate (_, decl) = Right (simplify (mkFold decl))
-- | Build the declarations for the fold over a data type: a type signature
-- (from 'foldType') and one equation per constructor.
--
-- Equation @i@ takes one argument per constructor, binds only the @i@-th to
-- @f@ (ignoring the rest), matches constructor @i@ with its fields bound to
-- @x1 .. xk@, and returns @f x1 .. xk@.
--
-- Returns no declarations when the type's name is not an identifier, or when
-- the type has no constructors: in the latter case an empty 'FunBind' would
-- leave the type signature without a binding, which does not compile.
mkFold :: DataDecl -> [Decl ()]
mkFold d
  | not (isIdent (dataDeclName d)) = []
  | null ctors = []
  | otherwise =
      [ TypeSig () [name n] (foldType d)
      , FunBind () (zipWith f [0 ..] ctors)
      ]
  where
    ctors = dataDeclCtors d
    n = "fold" ++ title (dataDeclName d)

    -- Equation for the i-th constructor c.
    f i c = Match () (name n) pat (UnGuardedRhs () bod) Nothing
      where
        -- One argument slot per constructor; only slot i is named.
        pat =
          [if j == i then pVar "f" else PWildCard () | j <- [0 .. length ctors - 1]]
            ++ [PParen () $ PApp () (qname $ ctorDeclName c) (map pVar vars)]
        bod = apps (var "f") (map var vars)
        vars = ['x' : show k | k <- [1 .. length (ctorDeclFields c)]]
-- | The type of the generated fold: one function argument per constructor
-- (taking that constructor's field types and returning the result type @v@),
-- then the data type itself, then @v@.
--
-- @v@ is a fresh type variable that does not occur anywhere in the data type
-- or in any constructor's field types.
foldType :: DataDecl -> Type ()
foldType d = tyFun $ map f ctors ++ [dt, v]
  where
    ctors = dataDeclCtors d
    dt = dataDeclType d
    -- Every type term in the declared type and in every field type; the
    -- result variable must not coincide with any type variable among them
    -- (field types can mention variables absent from the type head, e.g.
    -- existentials).
    used = universe dt ++ concatMap (concatMap (universe . snd) . ctorDeclFields) ctors
    -- Infinite supply of valid type-variable names: a .. z, then a1, a2, ...
    -- (enumerating past 'z' would yield punctuation, not identifiers).
    candidates =
      map tyVar $ map return ['a' .. 'z'] ++ ['a' : show i | i <- [1 :: Integer ..]]
    -- 'used' is finite and 'candidates' is infinite and duplicate-free, so
    -- the filtered list is never empty.
    v = head $ filter (`notElem` used) candidates
    f c = TyParen () $ tyFun $ map snd (ctorDeclFields c) ++ [v]