{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PolyKinds #-}

module Camfort.Specification.Stencils.InferenceBackend where

import Prelude hiding (sum)
import Data.Generics.Uniplate.Operations
import Data.List hiding (sum)
import Data.Data
import Control.Arrow ((***))
import Data.Function

import Camfort.Specification.Stencils.Model
import Camfort.Helpers
import Camfort.Helpers.Vec

import Debug.Trace
import Unsafe.Coerce

import Camfort.Specification.Stencils.Syntax

{- Spans are a pair of a lower and upper bound -}
type Span a = (a, a)

-- | A span whose lower and upper bounds are the same point.
mkTrivialSpan :: a -> Span a
mkTrivialSpan a = (a, a)

-- | Infer a specification from a list of relative index vectors.
-- 'hasDuplicates' yields the index list to analyse together with a flag
-- that 'fromBool' turns into a linearity: linear accesses are wrapped in
-- 'Single', non-linear ones in 'Multiple'.
inferFromIndices :: VecList Int -> Specification
inferFromIndices (VL ixs) = Specification $
    case fromBool mult of
      Linear -> Single $ inferCore ixs'
      NonLinear -> Multiple $ inferCore ixs'
  where
    -- NOTE(review): presumably ixs' is ixs without duplicates and mult
    -- records whether any duplicate was found; confirm against hasDuplicates.
    (ixs', mult) = hasDuplicates ixs

-- | Like 'inferFromIndices', but without any linearity check: the result is
-- always wrapped in 'Multiple' (i.e. treated as non-linear). This is used
-- when the front-end has already performed the linearity check, as an
-- optimisation.
inferFromIndicesWithoutLinearity :: VecList Int -> Specification
inferFromIndicesWithoutLinearity (VL ixs) =
    Specification (Multiple (inferCore ixs))

-- | Core inference pipeline: compute minimal span regions for the indices,
-- turn them into a spatial specification, then simplify that specification.
inferCore :: (IsNatural n, Permutable n) => [Vec n Int] -> Approximation Spatial
inferCore ixs = simplify (fromRegionsToSpec (inferMinimalVectorRegions ixs))

-- | Apply 'simplifySpatial' to every spatial specification inside an
-- approximation.
simplify :: Approximation Spatial -> Approximation Spatial
simplify approx = fmap simplifySpatial approx

-- | Simplify a sum-of-products spatial specification. 'reducor' searches,
-- over permutations of the products, for a 'normaliseNoSort' result with
-- fewer total regions. The result is then put into a canonical order:
-- regions are sorted within each product, and the products are sorted.
simplifySpatial :: Spatial -> Spatial
simplifySpatial (Spatial (Sum products)) =
    Spatial (Sum (canonicalOrder minimal))
  where
    minimal = reducor products normaliseNoSort totalSize
    canonicalOrder = sort . map (\p -> Product (sort (unProd p)))
    -- Total number of regions across all products.
    totalSize :: [RegionProd] -> Int
    totalSize = foldr (\p acc -> length (unProd p) + acc) 0

-- | Given a list, a list->list transformer and a size function, search for
-- a minimal transformed list. The transformer is applied to each permutation
-- of the list in turn. Whenever the result is strictly smaller (by @size@)
-- than the permutation it came from, the search restarts from the
-- permutations of that smaller result. If no permutation shrinks, the
-- transformer applied to the last permutation is returned.
--
-- Note: 'permutations' yields n! lists, so this is only practical for short
-- inputs.
reducor :: [a] -> ([a] -> [a]) -> ([a] -> Int) -> [a]
reducor xs f size = reducor' (permutations xs)
    where
      -- 'permutations' never returns an empty list (permutations [] == [[]]),
      -- so the search always ends in the singleton case.
      reducor' [y] = f y
      reducor' (y:ys)
        | size y' < size y = reducor' (permutations y')
        | otherwise        = reducor' ys
        where y' = f y

-- | Combine the specifications of all regions with 'sum', starting from
-- 'zero'. Each region is converted by 'toSpecND'.
fromRegionsToSpec :: IsNatural n => [Span (Vec n Int)] -> Approximation Spatial
fromRegionsToSpec regions = foldr addRegion zero regions
  where
    addRegion region acc = sum (toSpecND region) acc

-- | Convert an n-dimensional region into an exact spatial specification or
-- a bound of spatial specifications. A 1-D spec is built for each dimension
-- (numbered from 1) with 'toSpec1D', and the per-dimension specs are
-- combined with 'prod'.
toSpecND :: Span (Vec n Int) -> Approximation Spatial
toSpecND = toSpecPerDim 1
  where
    -- Convert the region one dimension at a time. The empty remainder
    -- (no dimensions left) yields 'one'. The signature is polymorphic in n
    -- because each recursive call is on a vector one element shorter.
    toSpecPerDim :: Int -> Span (Vec n Int) -> Approximation Spatial
    toSpecPerDim _ (Nil, Nil)             = one
    toSpecPerDim d (Cons l ls, Cons u us) =
      prod (toSpec1D d l u) (toSpecPerDim (d + 1) (ls, us))

-- | Given a dimension identifier and the lower and upper bound of a region
-- in that dimension, build the simple directional spec for it.
-- Contiguous regions give an 'Exact' spec. Any other region falls through
-- to an 'upperBound' approximation.
toSpec1D :: Dimension -> Int -> Int -> Approximation Spatial
toSpec1D dim l u
    -- If either bound equals 'absoluteRep', the result is an empty product.
    | l == absoluteRep || u == absoluteRep = exact [Product []]
    | l == 0 && u == 0                     = single (Centered 0 dim True)
    | l < 0 && u == 0                      = single (Backward (abs l) dim True)
    | l < 0 && u == (-1)                   = single (Backward (abs l) dim False)
    | l == 0 && u > 0                      = single (Forward u dim True)
    | l == 1 && u > 0                      = single (Forward u dim False)
    | l < 0 && u > 0 =
        if abs l == u
          then single (Centered u dim True)
          else exact [ Product [Backward (abs l) dim True]
                     , Product [Forward  u       dim True] ]
    -- Represents a non-contiguous region
    | otherwise = upperBound $ Spatial (Sum [Product [fallback]])
  where
    exact products = Exact (Spatial (Sum products))
    single region  = exact [Product [region]]
    fallback
      | l > 0     = Forward u dim True
      | otherwise = Backward (abs l) dim True

{- Normalise a span into the form (lower, upper). Only the first component
   of each endpoint is compared: if the second endpoint's first component
   is smaller, the two endpoints are swapped. -}
normaliseSpan :: Span (Vec n Int) -> Span (Vec n Int)
normaliseSpan s@(Nil, Nil) = s
normaliseSpan (lower@(Cons lo _), upper@(Cons hi _))
    | hi < lo   = (upper, lower)
    | otherwise = (lower, upper)

{- `spanBoundingBox` creates a span which is a bounding box over two spans.
   Both spans are first passed through 'normaliseSpan', which orders each
   span's endpoints by their first component only. -}
spanBoundingBox :: Span (Vec n Int) -> Span (Vec n Int) -> Span (Vec n Int)
spanBoundingBox a b = boundingBox' (normaliseSpan a) (normaliseSpan b)
  where
    -- Per dimension, take the smaller lower bound and the larger upper bound.
    -- The signature is polymorphic in n because the recursion is on shorter
    -- vectors.
    boundingBox' :: Span (Vec n Int) -> Span (Vec n Int) -> Span (Vec n Int)
    boundingBox' (Nil, Nil) (Nil, Nil)
        = (Nil, Nil)
    boundingBox' (Cons l1 ls1, Cons u1 us1) (Cons l2 ls2, Cons u2 us2)
        = let (ls', us') = boundingBox' (ls1, us1) (ls2, us2)
           in (Cons (min l1 l2) ls', Cons (max u1 u2) us')

{-| Merge two spans that are consecutive in their first component, i.e.
    (lower1, upper1) and (lower2, upper2) where the first component of
    lower2 is the first component of upper1 plus one, and all remaining
    components of the lower bounds and of the upper bounds are equal.
    Returns a singleton list holding the merged span, or the empty list
    when the spans cannot be merged. Two empty spans merge to the empty
    span. -}
composeConsecutiveSpans :: Span (Vec n Int)
                        -> Span (Vec n Int) -> [Span (Vec n Int)]
composeConsecutiveSpans (Nil, Nil) (Nil, Nil) = [(Nil, Nil)]
composeConsecutiveSpans (Cons lo1 restLo1, Cons hi1 restHi1)
                        (Cons lo2 restLo2, Cons hi2 restHi2) =
    [ (Cons lo1 restLo1, Cons hi2 restHi2)
    | restLo1 == restLo2
    , restHi1 == restHi2
    , hi1 + 1 == lo2 ]

{-| |inferMinimalVectorRegions| is a key part of the algorithm. From a list
    of n-dimensional relative indices it infers a list of (possibly
    overlapping) 1-dimensional spans (vectors) within the n-dimensional
    space. Each index starts as a trivial span. 'allRegionPermutations'
    followed by 'minimaliseRegions' is then applied repeatedly until the
    list of spans stops changing. -}
inferMinimalVectorRegions :: (Permutable n) => [Vec n Int] -> [Span (Vec n Int)]
inferMinimalVectorRegions indices = converge (map mkTrivialSpan indices)
  where
    converge current
      | next == current = next
      | otherwise       = converge next
      where next = minimaliseRegions (allRegionPermutations current)

{-| Map a list of n-dimensional spans of relative indices to all possible
    contiguous spans within the n-dimensional space (a single pass).

    Each span is viewed under every permutation of its dimensions. Under
    each permutation, the spans are sorted by lower bound, and neighbours
    that 'composeConsecutiveSpans' accepts are merged. The merged spans are
    mapped back with the matching unpermutation, concatenated, and
    deduplicated. -}
allRegionPermutations :: (Permutable n)
                      => [Span (Vec n Int)] -> [Span (Vec n Int)]
allRegionPermutations =
  nub . concat . unpermuteIndices . map (coalesceRegions >< id) . groupByPerm . map permutationss
    where
      {- Permutations of the indices in a span
         (permutes the lower and upper bounds in the same way) -}
      permutationss :: Permutable n
                   => Span (Vec n Int)
                   -> [(Span (Vec n Int), Vec n Int -> Vec n Int)]
      -- Since the permutation ordering is identical for lower & upper bound,
      -- reuse the same unpermutation
      permutationss (l, u) = map (\((l', un1), (u', un2)) -> ((l', u'), un1))
                           $ zip (permutationsV l) (permutationsV u)

      sortByFst        = sortBy (\(l1, u1) (l2, u2) -> compare l1 l2)

      -- Transpose so that the k-th group holds every span's k-th permutation.
      -- All spans are permuted in the same order, so the group's first
      -- unpermutation serves the whole group ('transpose' never produces an
      -- empty group, so 'head' is safe here).
      groupByPerm  :: [[(Span (Vec n Int), Vec n Int -> Vec n Int)]]
                   -> [( [Span (Vec n Int)] , Vec n Int -> Vec n Int)]
      groupByPerm      = map (\ixP -> let unPerm = snd $ head ixP
                                      in (map fst ixP, unPerm)) . transpose

      coalesceRegions :: [Span (Vec n Int)] -> [Span (Vec n Int)]
      coalesceRegions  = nub . foldL composeConsecutiveSpans . sortByFst

      unpermuteIndices :: [([Span (Vec n Int)], Vec n Int -> Vec n Int)]
                       -> [[Span (Vec n Int)]]
      unpermuteIndices = nub . map (\(rs, unPerm) -> map (unPerm *** unPerm) rs)

-- | Reduce a list two adjacent elements at a time with a non-deterministic
-- (list-valued) combining operation. When the operation returns no results
-- for a pair, the first element is kept and the reduction moves on. When it
-- returns results, they replace the pair at the front of the list and the
-- reduction continues from there. Lists shorter than two are returned as-is.
foldL :: (a -> a -> [a]) -> [a] -> [a]
foldL combine = go
  where
    go (x : y : rest) =
      case combine x y of
        []     -> x : go (y : rest)
        merged -> go (merged ++ rest)
    go short = short

{-| Collapse the regions into a smaller set. Every region that lies strictly
    inside one or more other regions of the input is replaced by those
    enclosing regions, and a region with no strict enclosure is kept.
    Duplicates are then removed with 'nub'. -}
minimaliseRegions :: [Span (Vec n Int)] -> [Span (Vec n Int)]
minimaliseRegions [] = []
minimaliseRegions regions = nub (concatMap enclosersOf regions)
  where
    -- The input regions that strictly contain r, or [r] itself if none do.
    enclosersOf r =
      case [ outer | outer <- regions, containedWithin r outer, r /= outer ] of
        []        -> [r]
        enclosing -> enclosing

{-| Binary predicate: @containedWithin inner outer@ holds when, in every
    dimension, the bounds of @inner@ lie within the bounds of @outer@
    (inclusive). -}
containedWithin :: Span (Vec n Int) -> Span (Vec n Int) -> Bool
containedWithin (Nil, Nil) (Nil, Nil) = True
containedWithin (Cons lo lows, Cons hi highs) (Cons lo' lows', Cons hi' highs') =
    fitsHere && containedWithin (lows, highs) (lows', highs')
  where
    fitsHere = lo' <= lo && hi <= hi'

{-| Defines the (total) class of vector sizes which are permutable, along with
    the permutation function, which pairs each permutation with its
    'unpermute' operation. -}
class Permutable (n :: Nat) where
  -- From a vector of length n to a list of 'selections'. For a non-empty
  -- vector each selection is a triple of a selected element, the rest of
  -- the vector, and a function to 'unselect' (rebuild the original vector);
  -- see 'Selection' for the exact shape per size.
  selectionsV :: Vec n a -> [Selection n a]
  -- From a vector of length n to a list of its permutations, each paired
  -- with the 'unpermute' function that restores the original order.
  permutationsV :: Vec n a -> [(Vec n a, Vec n a -> Vec n a)]

-- 'Selection' is a size-indexed family which gives the type of selections
-- for each size:
--    Z is trivial (the element type itself; no vector can be selected from)
--    (S n) provides a triple of the selected element, the remaining vector,
--           and the 'unselect' function for returning the original value
type family Selection n a where
            Selection Z a = a
            Selection (S n) a = (a, Vec n a, a -> Vec n a -> Vec (S n) a)

instance Permutable Z where
  -- The empty vector has exactly one permutation (itself, undone by 'id')
  -- and nothing to select.
  permutationsV Nil = [(Nil, id)]
  selectionsV   Nil = []

instance Permutable (S Z) where
  -- A one-element vector: the only permutation is the vector itself (undone
  -- by 'id'), and the only selection is its element, leaving 'Nil'.
  permutationsV v@(Cons _ Nil) = [(v, id)]
  selectionsV (Cons x _) = [(x, Nil, Cons)]

-- | Vectors of length two or more. The first selection picks the head.
-- The remaining selections pick an element from the tail (recursively) and
-- keep the head in front of what is left.
instance Permutable (S n) => Permutable (S (S n)) where
  selectionsV (Cons x xs) =
    (x, xs, Cons) : [ (y, Cons x ys, unselect unSel)
                    | (y, ys, unSel) <- selectionsV xs ]
    where
     -- Lift a tail 'unselect' to the whole vector: the head x' stays in
     -- front and the selected element y' is re-inserted into the tail.
     -- (n here is the instance-head variable, scoped via
     -- ScopedTypeVariables.)
     unselect :: (a -> Vec n a -> Vec (S n) a)
              -> (a -> Vec (S n) a -> Vec (S (S n)) a)
     unselect f y' (Cons x' ys') = Cons x' (f y' ys')

  -- Each permutation places one selected element first and permutes the
  -- rest. Its inverse unpermutes the tail, then unselects the head back into
  -- its original position.
  permutationsV xs =
      [ (Cons y zs, \(Cons y' zs') -> unSel y' (unPerm zs'))
        | (y, ys, unSel) <- selectionsV xs,
          (zs,  unPerm)  <- permutationsV ys ]

-- | Vector list representation: a list of vectors that all share one length
-- 'n'. The length is existentially quantified, and its 'IsNatural' and
-- 'Permutable' dictionaries are packed with the list.
data VecList a where
  VL :: (IsNatural n, Permutable n) => [Vec n a] -> VecList a

-- Lists existentially quantify over a vector's size : Exists n . Vec n a
-- (the 'IsNatural' and 'Permutable' dictionaries for n are packed with the
-- vector).
data List a where
     List :: (IsNatural n, Permutable n) => Vec n a -> List a

-- | The empty existentially-sized list (wraps the length-0 vector 'Nil').
lnil :: List a
lnil = List Nil
-- | Prepend an element to an existentially-sized list.
-- The input is matched separately at lengths 0, 1 and >= 2 so that a
-- 'Permutable' instance exists for the longer result in each case:
-- length 1 uses 'Permutable (S Z)' directly, and length 2 uses
-- 'Permutable (S (S Z))' from 'Permutable (S Z)'. For length >= 3, the
-- instance is built from the 'Permutable' dictionary packed with the input
-- vector. A single generic equation would not give this evidence.
lcons :: a -> List a -> List a
lcons x (List Nil) = List (Cons x Nil)
lcons x (List (Cons y Nil)) = List (Cons x (Cons y Nil))
lcons x (List (Cons y (Cons z xs))) = List (Cons x (Cons y (Cons z xs)))

-- | Convert a plain list into an existentially-sized 'List', consing one
-- element at a time with 'lcons' onto 'lnil'.
fromList :: [a] -> List a
fromList []       = lnil
fromList (x : xs) = lcons x (fromList xs)

-- | Build a 'VecList' from a list of index lists.
--
-- Pre-condition: the input is a 'rectangular' list of lists (i.e. all
-- internal lists have the same size). This is NOT checked. The equality of
-- the vector lengths is asserted with 'unsafeCoerce', so a ragged input
-- would make that asserted type equality false.
fromLists :: [[Int]] -> VecList Int
fromLists [] = VL ([] :: [Vec Z Int])
fromLists (xs:xss) = consList (fromList xs) (fromLists xss)
  where
    consList :: List Int -> VecList Int -> VecList Int
    consList (List vec) (VL [])       = VL [vec]
    consList (List vec) (VL (v:rest))
      = let (vec', v') = zipVec vec v
        in  -- Force the pre-condition equality
          case (preCondition v' rest, preCondition vec' rest) of
            (ReflEq, ReflEq) -> VL (vec' : (v' : rest))
      where
        -- At the moment the pre-condition is 'assumed' rather than checked,
        -- hence the unsafeCoerce. TODO: rewrite to check the lengths.
        preCondition :: Vec n a -> [Vec n1 a] -> EqT n n1
        preCondition _ _ = unsafeCoerce ReflEq

-- | Type-equality witness: pattern matching on 'ReflEq' brings the equality
-- @a ~ b@ into scope for the type checker.
data EqT (a :: k) (b :: k) where
    ReflEq :: EqT a a

