{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | A list diff.
module Data.TreeDiff.List (
    diffBy,
    Edit (..),
) where

import Control.DeepSeq (NFData (..))
import Control.Monad.ST (ST, runST)

import qualified Data.Primitive as P

-- import Debug.Trace

-- | List edit operations.
--
-- The 'Swp' constructor is redundant (it is equivalent to a 'Del' followed
-- by an 'Ins'), but keeping it as its own constructor lets us spot a
-- recursion point when performing tree diffs.
data Edit a
    = -- | insert
      Ins a
    | -- | delete
      Del a
    | -- | copy unchanged
      Cpy a
    | -- | swap, i.e. delete + insert
      Swp a a
  deriving (Show, Eq)

-- | Fully evaluates every element carried by the edit; 'Swp' forces its
-- deleted element before its inserted one.
instance NFData a => NFData (Edit a) where
    rnf edit = case edit of
        Ins x   -> rnf x
        Del x   -> rnf x
        Cpy x   -> rnf x
        Swp x y -> rnf x `seq` rnf y

-- | List difference.
--
-- Computes an edit script turning the first list into the second with the
-- smallest total cost, where 'Cpy' costs 0 and 'Ins', 'Del' and 'Swp' each
-- cost 1. On equal cost the diagonal move ('Cpy' / 'Swp') is preferred,
-- then 'Ins', then 'Del'.
--
-- The implementation fills the classic dynamic-programming edit-distance
-- table one row per element of the first list, keeping only two row
-- buffers of length @length ys@. It does @O(length xs * length ys)@ work.
--
-- The @Show a@ constraint is not used by the implementation; it is kept so
-- the exported signature stays unchanged.
--
-- >>> diffBy (==) "hello" "world"
-- [Swp 'h' 'w',Swp 'e' 'o',Swp 'l' 'r',Cpy 'l',Swp 'o' 'd']
--
-- >>> diffBy (==) "kitten" "sitting"
-- [Swp 'k' 's',Cpy 'i',Cpy 't',Cpy 't',Swp 'e' 'i',Cpy 'n',Ins 'g']
--
-- prop> \xs ys -> length (diffBy (==) xs ys) >= max (length xs) (length (ys :: String))
-- prop> \xs ys -> length (diffBy (==) xs ys) <= length xs + length (ys :: String)
--
diffBy :: forall a. Show a => (a -> a -> Bool) -> [a] -> [a] -> [Edit a]
diffBy _  []  []  = []
diffBy _  []  ys' = map Ins ys'
diffBy _  xs' []  = map Del xs'
diffBy eq xs' ys' = reverse (getCell lcs)
  where
    xn, yn :: Int
    xn = length xs'
    yn = length ys'

    -- Arrays give constant-time access to the n-th / m-th element
    -- inside the loops.
    xs, ys :: P.Array a
    xs = P.arrayFromListN xn xs'
    ys = P.arrayFromListN yn ys'

    -- The table cell for the complete inputs. Its payload is the edit
    -- script in reverse order, because edits are consed on as the table
    -- is filled.
    lcs :: Cell [Edit a]
    lcs = runST $ do
        -- Two row buffers; slot m of a row holds the best cell for turning
        -- the current prefix of xs into ys[0..m].
        buf1 <- P.newArray yn (Cell 0 [])
        buf2 <- P.newArray yn (Cell 0 [])

        -- Row for the empty prefix of xs: reaching ys[0..m] takes m + 1
        -- insertions. Only buf1 needs it: buf2 is the first "curr" row
        -- below, and every slot of "curr" is written before it is read.
        yLoop (Cell 0 []) $ \m (Cell w edit) -> do
            let !cell = Cell (w + 1) (Ins (P.indexArray ys m) : edit)
            P.writeArray buf1 m cell
            return cell

        -- One row per element of xs. Around the cell being computed:
        --
        --   cellC cellT
        --   cellL cellX
        --
        -- cellC / cellL travel in the accumulator along the previous /
        -- current row; cellT is read from the previous row's buffer.
        (lastRow, _, _) <- xLoop (buf1, buf2, Cell 0 []) $ \n (prev, curr, cellC) -> do
            -- Column for the empty prefix of ys: removing xs[0..n] takes
            -- n + 1 deletions.
            let !cellL = case cellC of
                    Cell w edit -> Cell (w + 1) (Del (P.indexArray xs n) : edit)

            yLoop (cellC, cellL) $ \m (cellC', cellL') -> do
                cellT <- P.readArray prev m

                let x, y :: a
                    x = P.indexArray xs n
                    y = P.indexArray ys m

                    -- from diagonal
                    cellX1 :: Cell [Edit a]
                    cellX1
                        | eq x y    = bimap id    (Cpy x :)   cellC'
                        | otherwise = bimap (+ 1) (Swp x y :) cellC'

                    -- from left
                    cellX2 :: Cell [Edit a]
                    cellX2 = bimap (+ 1) (Ins y :) cellL'

                    -- from top
                    cellX3 :: Cell [Edit a]
                    cellX3 = bimap (+ 1) (Del x :) cellT

                -- Evaluate the winning cell right away. Left lazy, every
                -- cell would be stored as a thunk referring to its three
                -- neighbours, keeping the whole xn * yn table alive until
                -- the final result is demanded.
                let !cellX = bestOfThree cellX1 cellX2 cellX3

                P.writeArray curr m cellX
                return (cellT, cellX)

            -- The row just written becomes "prev" for the next element.
            return (curr, prev, cellL)

        P.readArray lastRow (yn - 1)

    -- Strict left fold of @f@ over the indices @[0 .. xn - 1]@ in 'ST';
    -- the accumulator is forced to WHNF at every step.
    xLoop :: acc -> (Int -> acc -> ST s acc) -> ST s acc
    xLoop !acc0 f = go acc0 0
      where
        go !acc !n
            | n < xn    = f n acc >>= \acc' -> go acc' (n + 1)
            | otherwise = return acc

    -- Like 'xLoop', but over @[0 .. yn - 1]@ and discarding the final
    -- accumulator.
    yLoop :: acc -> (Int -> acc -> ST s acc) -> ST s ()
    yLoop !acc0 f = go acc0 0
      where
        go !acc !m
            | m < yn    = f m acc >>= \acc' -> go acc' (m + 1)
            | otherwise = return ()

-- | One entry of the edit-distance table used by 'diffBy': an accumulated
-- cost paired with a payload. Both fields are strict, so constructing a
-- 'Cell' evaluates its cost and payload to weak head normal form.
data Cell a = Cell !Int !a
  deriving (Show)

-- | The payload stored in a 'Cell'; the cost is dropped.
getCell :: Cell a -> a
getCell cell = case cell of
    Cell _ payload -> payload

-- | Choose the cheapest of three cells, comparing only their costs.
--
-- On equal cost the earlier argument wins: the first beats the second and
-- the third, and the second beats the third.
bestOfThree :: Cell a -> Cell a -> Cell a -> Cell a
bestOfThree a b c = cheaper (cheaper a b) c
  where
    -- Keep the left cell unless the right one is strictly cheaper.
    cheaper l@(Cell i _) r@(Cell j _)
        | i <= j    = l
        | otherwise = r

-- | Transform both parts of a 'Cell': the first function is applied to the
-- cost, the second to the payload.
bimap :: (Int -> Int) -> (a -> b) -> Cell a -> Cell b
bimap onCost onPayload cell =
    case cell of
        Cell cost payload -> Cell (onCost cost) (onPayload payload)