-- | \"Open\" functions, working on functors instead of trees.

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Open
  (
    toList , toRevList
  -- * Accumulating maps
  , mapAccumL  , mapAccumR
  , mapAccumL_ , mapAccumR_
  -- * Open functions
  , holes   , holesList
  , apply   , builder
  -- * Individual elements
  , project , unsafeProject
  , sizeF
  -- * Enumerations
  , enumerate
  , enumerateWith
  , enumerateWith_
  -- * Shapes
  , Hole(..) , Shape , shape
  -- * Zips
  , zipF , unzipF
  , zipWithF  , unsafeZipWithF
  , zipWithFM , unsafeZipWithFM
  )
where

--------------------------------------------------------------------------------

import Control.Monad ( liftM )
import Data.Foldable
import Data.Traversable ( mapAccumL , mapAccumR )
import Prelude hiding ( foldl , foldr , mapM , mapM_ , concat , concatMap )

import Data.Generics.Fixplate.Base
import Data.Generics.Fixplate.Misc

--------------------------------------------------------------------------------

-- | Equivalent to @reverse . toList@.
--
-- Built with a strict left fold: each step only conses one element onto
-- the accumulator, so forcing it eagerly avoids building a chain of
-- unevaluated 'flip (:)' applications on long structures.
toRevList :: Foldable f => f a -> [a]
toRevList = Data.Foldable.foldl' (flip (:)) []

--------------------------------------------------------------------------------
-- Accumulating maps

-- | Like 'mapAccumL' (left-to-right accumulation), but throws away the
-- final accumulator and keeps only the mapped structure.
mapAccumL_ :: Traversable f => (a -> b -> (a, c)) -> a -> f b -> f c
mapAccumL_ f acc = snd . mapAccumL f acc

-- | Like 'mapAccumR' (right-to-left accumulation), but throws away the
-- final accumulator and keeps only the mapped structure.
mapAccumR_ :: Traversable f => (a -> b -> (a, c)) -> a -> f b -> f c
mapAccumR_ f acc = snd . mapAccumR f acc

--------------------------------------------------------------------------------
-- Open functions

-- | Pairs every child with a function that rebuilds the whole structure
-- with exactly that child replaced by the function's argument.
--
-- Children are numbered in 'mapAccumL' (left-to-right) order. Each
-- replacement function re-traverses the structure with the same numbering,
-- so it substitutes exactly the position it was created for.
holes :: Traversable f => f a -> f (a, a -> f a)
holes tree = snd (mapAccumL withHole (0 :: Int) tree)
  where
    withHole idx child = (idx + 1, (child, replaceAt idx))
    replaceAt idx new  = snd (mapAccumL (swapIn idx new) 0 tree)
    swapIn idx new pos old = (pos + 1, if pos == idx then new else old)

-- | List version of 'holes': every child paired with its replacement
-- function, in 'toList' order.
holesList :: Traversable f => f a -> [(a, a -> f a)]
holesList tree = toList (holes tree)

-- | Applies the given function to each child in turn. The @i@th element of
-- the result is the original structure in which only the @i@th child has
-- been replaced by @f@ applied to that child.
apply :: Traversable f => (a -> a) -> f a -> f (f a)
apply f tree = fmap (\(child, replace) -> replace (f child)) (holes tree)

-- | Builds up a structure from a list of the children: the @i@th child of
-- the given structure (in 'mapAccumL', i.e. left-to-right, order) is
-- replaced by the @i@th element of the list. Surplus list elements are
-- ignored.
--
-- It is unsafe in the sense that it will throw an exception
-- if there are not enough elements in the list.
builder :: Traversable f => f a -> [b] -> f b
builder tree xs = snd (mapAccumL fill xs tree)
  where
    -- Distinct names for the remaining list so the argument 'xs' is not shadowed.
    fill (y:ys) _ = (ys, y)
    -- Reachable whenever the list is shorter than the structure.
    fill []     _ = error "Open/builder: not enough elements in the list"

--------------------------------------------------------------------------------

-- | Extracts the ith child, counting from zero in 'toList' order.
-- Returns 'Nothing' if the index is negative or not smaller than the
-- number of children.
--
-- Implemented as a right fold that threads the current index through a
-- continuation. The traversal therefore stops as soon as the child is
-- found, which also makes in-range lookups work on infinite structures.
project :: Foldable f => Int -> f a -> Maybe a
project i tree
  | i < 0     = Nothing
  | otherwise = foldr step (const Nothing) tree 0
  where
    -- 'rest' continues the search over the remaining children.
    step x rest j = if j == i then Just x else rest (j + 1)

-- | Extracts the ith child (counting from zero, like 'project'), but throws
-- an error instead of returning 'Nothing' when the index is out of range.
unsafeProject :: Foldable f => Int -> f a -> a
unsafeProject i tree =
  case project i tree of
    Just x  -> x
    Nothing -> error "Open/unsafeProject: invalid index"

-- | Number of children. This is the generalization of 'length' to foldable functors:
--
-- > sizeF x = length (toList x)
--
-- The counter is accumulated with a strict left fold. A lazy 'foldl' would
-- build a chain of @(+1)@ thunks as long as the structure before any
-- addition happens.
sizeF :: Foldable f => f a -> Int
sizeF = foldl' (\n _ -> n + 1) 0

--------------------------------------------------------------------------------
-- Enumerations

-- | Labels every child with its position, counting from zero, left to
-- right. The first component of the result is the total number of children.
-- This is just a simple application of 'mapAccumL'.
enumerate :: Traversable f => f a -> (Int, f (Int, a))
enumerate xs = mapAccumL label 0 xs
  where
    label n x = (n + 1, (n, x))

-- | Like 'enumerate', but combines each zero-based position with its child
-- using @h@. Also returns the number of children.
enumerateWith :: Traversable f => (Int -> a -> b) -> f a -> (Int, f b)
enumerateWith h xs = mapAccumL step 0 xs
  where
    step n x = (n + 1, h n x)

-- | Like 'enumerateWith', but discards the number of children.
enumerateWith_ :: Traversable f => (Int -> a -> b) -> f a -> f b
enumerateWith_ h xs = labelled
  where
    (_count, labelled) = enumerateWith h xs

--------------------------------------------------------------------------------
-- Shapes

-- | A type encoding the \"shape\" of the functor data:
-- we ignore all the fields whose type is the parameter type,
-- but remember the rest:
--
-- > newtype Shape f = Shape { unShape :: f Hole }
--
-- Values are normally built with 'shape', which replaces every
-- parameter-typed field by 'Hole'. Comparing two shapes with '=='
-- (available when @EqF f@ holds) decides whether two realizations are
-- compatible. This is the check 'zipWithF' performs before zipping.
newtype Shape f = Shape { unShape :: f Hole }

-- | Extracts the \"shape\" of a functor value by replacing every
-- parameter-typed field with 'Hole'.
shape :: Functor f => f a -> Shape f
shape x = Shape (fmap (\_ -> Hole) x)

-- Two shapes are equal when their hole-filled functor values are equal
-- according to 'equalF'.
instance EqF f => Eq (Shape f) where
  Shape a == Shape b = equalF a b
-- Shapes are ordered by comparing their hole-filled functor values with
-- 'compareF'.
instance OrdF f => Ord (Shape f) where
  compare (Shape a) (Shape b) = compareF a b

-- This dirty trick is needed because we want both a 'Show' instance for
-- 'Shape f' and to let users define their own 'Show' instance for 'Hole'.
-- When showing a shape, the holes are mapped to this module-private type,
-- which always prints as @_@.
data Void = Void

instance Show Void where
  showsPrec _ _ = showString "_"

-- Shows a shape as @Shape <functor value>@, with every hole printed as @_@
-- (via 'Void'). The whole thing is parenthesised when the surrounding
-- precedence exceeds 'app_prec'.
instance (Functor f, ShowF f) => Show (Shape f) where
  showsPrec d (Shape s) = showParen (d > app_prec) body
    where
      body = showString "Shape " . showsPrecF (app_prec + 1) (fmap (\_ -> Void) s)

--------------------------------------------------------------------------------
-- Zips

-- | Zips two structures into a structure of pairs if they are compatible
-- (see 'zipWithF'), and returns 'Nothing' otherwise.
zipF :: (Traversable f, EqF f) => f a -> f b -> Maybe (f (a,b))
zipF x y = zipWithF (\a b -> (a, b)) x y

-- | Splits a structure of pairs into two structures of the same shape:
-- one holding the first components, the other the second components.
unzipF :: Functor f => f (a,b) -> (f a, f b)
unzipF pairs = (firsts, seconds)
  where
    firsts  = fmap fst pairs
    seconds = fmap snd pairs

-- | Zips two structures using a function. Returns 'Nothing' when the
-- two structures have different shapes (as compared via 'shape').
zipWithF :: (Traversable f, EqF f) => (a -> b -> c) -> f a -> f b -> Maybe (f c)
zipWithF f x y
  | shape x /= shape y = Nothing
  | otherwise          = Just (unsafeZipWithF f x y)

-- | Unsafe version of 'zipWithF': does not check if the two structures are compatible.
-- It is left-biased in the sense that the structure of the /first/ argument
-- is retained. The elements of the second argument are consumed in
-- 'toList' order, and any surplus elements are ignored.
-- Throws an error if the second argument has fewer elements than the first.
unsafeZipWithF :: Traversable f => (a -> b -> c) -> f a -> f b -> f c
unsafeZipWithF f x y = snd (mapAccumL step (toList y) x)
  where
    step (b:bs) a = (bs, f a b)
    -- Reachable: nothing guarantees the two structures have equal sizes.
    step []     _ = error "Open/unsafeZipWithF: the second structure has fewer elements than the first"

--------------------------------------------------------------------------------

-- | Monadic version of 'zipWithF'. When the two shapes differ, it returns
-- @Nothing@ without running any action. TODO: better name?
zipWithFM :: (Traversable f, EqF f, Monad m) => (a -> b -> m c) -> f a -> f b -> m (Maybe (f c))
zipWithFM f x y
  | shape x == shape y = unsafeZipWithFM f x y >>= return . Just
  | otherwise          = return Nothing

-- | Monadic version of 'unsafeZipWithF': does not check whether the two
-- structures are compatible. The combining action is threaded over the
-- first structure with 'mapAccumM', while the elements of the second
-- structure are consumed in 'toList' order.
-- Throws an error if the second argument has fewer elements than the first.
unsafeZipWithFM :: (Traversable f, Monad m) => (a -> b -> m c) -> f a -> f b -> m (f c)
unsafeZipWithFM f x y = liftM snd (mapAccumM step (toList y) x)
  where
    step (b:bs) a = f a b >>= \r -> return (bs, r)
    -- Reachable: nothing guarantees the two structures have equal sizes.
    step []     _ = error "Open/unsafeZipWithFM: the second structure has fewer elements than the first"

--------------------------------------------------------------------------------