{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternGuards #-}
module Plugin.Pl.Transform (
    transform,
  ) where

import Plugin.Pl.Common
import Plugin.Pl.PrettyPrinter ()

import qualified Data.Map as M

import Data.Graph (stronglyConnComp, flattenSCC, flattenSCCs)
import Control.Monad.Trans.State

{-
nub :: Ord a => [a] -> [a]
nub = nub' S.empty where
  nub' _ [] = []
  nub' set (x:xs)
    | x `S.member` set = nub' set xs
    | otherwise = x: nub' (x `S.insert` set) xs
-}

-- | Whether a variable name is bound anywhere in a pattern.
occursP :: String -> Pattern -> Bool
occursP v = go
  where
    go (PVar v')      = v == v'
    go (PTuple p1 p2) = go p1 || go p2
    go (PCons p1 p2)  = go p1 || go p2

-- | Number of free occurrences of a variable in an expression.
--
-- A lambda whose pattern binds the variable, or a 'Let' that declares it,
-- shadows it. Everything beneath such a binder contributes 0. Let bindings
-- are treated as recursive: the declared name is not free in the
-- right-hand sides either.
freeIn :: String -> Expr -> Int
freeIn v = count
  where
    count (Var _ v') = if v == v' then 1 else 0
    count (Lambda pat body)
      | v `occursP` pat = 0
      | otherwise       = count body
    count (App e1 e2) = count e1 + count e2
    count (Let ds body)
      | v `elem` map declName ds = 0
      | otherwise = count body + sum [count rhs | Define _ rhs <- ds]

-- | Whether the variable occurs free at least once in the expression.
isFreeIn :: String -> Expr -> Bool
isFreeIn v e = freeIn v e > 0

-- | Build a right-nested tuple expression with the infix "," constructor:
-- @[a, b, c]@ becomes @(a, (b, c))@ and a singleton is returned unchanged.
-- The list must be non-empty.
tuple :: [Expr] -> Expr
tuple [] = error "Plugin.Pl.Transform.tuple: empty list"
tuple es = foldr1 pair es
  where
    pair x y = App (App (Var Inf ",") x) y

-- | Build a right-nested tuple pattern of variables, shaped like the
-- expression 'tuple' builds: @["a", "b", "c"]@ becomes @(a, (b, c))@.
-- The list must be non-empty.
tupleP :: [String] -> Pattern
tupleP [] = error "Plugin.Pl.Transform.tupleP: empty list"
tupleP vs = foldr1 PTuple (map PVar vs)

-- | The declarations in @ds@ whose names occur free in the right-hand side
-- of @d@. These are the outgoing edges of @d@ in the dependency graph that
-- 'unLet' splits into strongly connected components.
dependsOn :: [Decl] -> Decl -> [Decl]
dependsOn ds d = filter ((`isFreeIn` rhs) . declName) ds
  where
    rhs = declExpr d
  
-- | Desugar every 'Let' into a lambda application plus fix'.
--
-- The declarations are grouped into strongly connected components of their
-- dependency graph (see 'dependsOn'). The first component returned by
-- 'stronglyConnComp' (dsYes) is bound all at once by applying fix' to a
-- lambda over a (nested) tuple. The remaining declarations (dsNo) stay in an
-- inner 'Let', which the recursive call handles:
--
-- > let ds in e
-- >   ==>  (\(x1, (x2, ...)) -> let dsNo in e)
-- >          (fix' (\(x1, (x2, ...)) -> (rhs1, (rhs2, ...))))
--
-- Here x1.. and rhs1.. are the names and right-hand sides of dsYes.
-- NOTE(review): correctness relies on stronglyConnComp's documented
-- reverse-topological order, so dsYes depends on no later component.
unLet :: Expr -> Expr
unLet (App e1 e2) = App (unLet e1) (unLet e2)
unLet (Let ds e) =
  case stronglyConnComp [(d, d, dependsOn ds d) | d <- ds] of
    -- Only happens when there are no declarations: just drop the Let.
    [] -> unLet e
    (firstGroup : laterGroups) ->
      let dsYes = flattenSCC firstGroup
          dsNo  = flattenSCCs laterGroups
          pat   = tupleP (map declName dsYes)
          knot  = App fix' (Lambda pat (tuple (map declExpr dsYes)))
      in unLet (App (Lambda pat (Let dsNo e)) knot)
unLet (Lambda v e) = Lambda v (unLet e)
unLet (Var f x) = Var f x

-- | State threaded through 'alphaRename'.
--
-- The first component maps each lambda-bound source variable in scope to the
-- "$n" name it was renamed to.
type Env = (M.Map String String, Int)
-- note: The second component counts the binders entered so far, including
-- ones whose names duplicate (shadow) an earlier binder. Unlike the map's
-- size, it grows on every binding, so each new "$n" name is distinct from
-- the names already in scope.

-- | Give every lambda-bound variable a fresh, unwritable name ("$0", "$1", ...).
--
-- It's a pity we still need this for the pointless transformation.
-- Otherwise a newly created id/const/... could be captured by a lambda,
-- e.g. transform' (\id x -> x) ==> transform' (\id -> id) ==> id
--
-- Free variables are left untouched. The input must already be let-free:
-- a 'Let' trips the assertion.
alphaRename :: Expr -> Expr
alphaRename e = evalState (alpha e) (M.empty, 0)
  where
    alpha :: Expr -> State Env Expr
    -- A bound variable takes its new name; a free one keeps its own.
    alpha (Var f v) = gets (Var f . M.findWithDefault v v . fst)
    alpha (App e1 e2) = App <$> alpha e1 <*> alpha e2
    alpha (Let _ _) = assert False bt
    alpha (Lambda v e') = inEnv (Lambda <$> alphaPat v <*> alpha e')

    -- Act like a reader monad: run the action against the current state and
    -- then discard its changes, so bindings made inside a lambda do not leak
    -- out of that lambda's scope.
    inEnv :: State s a -> State s a
    inEnv f = gets (evalState f)

    -- Rename every variable of a pattern. Each binder records old -> new in
    -- the map and bumps the counter, even if it shadows an existing name.
    alphaPat :: Pattern -> State Env Pattern
    alphaPat (PVar v) = do
      (fm, i) <- get
      let v' = '$' : show i
      put (M.insert v v' fm, i + 1)
      return (PVar v')
    alphaPat (PTuple p1 p2) = PTuple <$> alphaPat p1 <*> alphaPat p2
    alphaPat (PCons p1 p2) = PCons <$> alphaPat p1 <*> alphaPat p2


-- | Convert an expression to point-free form. The pipeline has three steps:
-- desugar lets ('unLet'), rename every lambda-bound variable
-- ('alphaRename'), then eliminate the lambdas (transform').
transform :: Expr -> Expr
transform = transform' . alphaRename . unLet

-- | Infinite generator of variable names. It yields the one-letter names
-- "a".."z", then the two-letter names "aa".."zz", and so on. Within each
-- length the names come in lexicographic order.
varNames :: [String]
varNames = [name | len <- [1 ..], name <- replicateM len usableChars]
  where
    usableChars :: String
    usableChars = ['a' .. 'z']

-- | The first name from 'varNames' that does not occur in the given list.
fresh :: [String] -> String
fresh variables = firstUnused varNames
  where
    firstUnused (candidate : rest)
      | candidate `elem` variables = firstUnused rest
      | otherwise                  = candidate
    -- varNames is infinite, so this case is unreachable for a finite list.
    firstUnused [] = error "Plugin.Pl.Transform.fresh: varNames exhausted"

-- | Every variable name mentioned in an expression: each 'Var' occurrence
-- and each name declared by a 'Let', left to right, duplicates kept.
-- transform' uses it to choose a name that cannot clash.
names :: Expr -> [String]
names expr = collect expr []
  where
    -- Accumulator style: each node prepends onto the names already collected
    -- to its right. A long left-nested 'App' spine therefore costs linear
    -- rather than quadratic time (no repeated left-nested (++)).
    collect :: Expr -> [String] -> [String]
    collect (Var _ str) acc = str : acc
    -- Lambda pattern names are rewritten to be meaningless/unwritable, so we
    -- don't need to include them here. Variables from lambdas used in
    -- expressions are also rewritten, but there's no reason to special-case
    -- them unless scanning the result in `fresh` proves to be slow, which is
    -- doubtful.
    collect (Lambda _ body) acc = collect body acc
    collect (App e1 e2) acc = collect e1 (collect e2 acc)
    collect (Let decls body) acc = foldr collectDecl (collect body acc) decls

    collectDecl :: Decl -> [String] -> [String]
    collectDecl (Define nm rhs) acc = nm : collect rhs acc

-- | Eliminate the lambdas of a let-free, alpha-renamed expression.
--
-- First, tuple and cons patterns are rewritten to a single variable pattern
-- whose components are projected out with fst/snd or head/tail. Then each
-- variable lambda is removed by abstracting the variable out of its body
-- with the combinators id', const', scomb, flip' and comp.
transform' :: Expr -> Expr
transform' expr = go expr
  where
    -- Explicit sharing: the input's names never change, so the helper
    -- variable used to destructure tuple/cons patterns is computed once.
    patVar :: String
    patVar = fresh (names expr)

    go :: Expr -> Expr
    go (Let {}) = assert False bt
    go leaf@(Var {}) = leaf
    go (App e1 e2) = App (go e1) (go e2)
    go (Lambda (PTuple p1 p2) e) = go (destructure "fst" "snd" p1 p2 e)
    go (Lambda (PCons p1 p2) e) = go (destructure "head" "tail" p1 p2 e)
    go (Lambda (PVar v) e) = go (getRidOfV e)
      where
        -- Build an expression equivalent to (\v -> e') that no longer binds v.
        getRidOfV :: Expr -> Expr
        getRidOfV (Var f v')
          | v == v'   = id'
          | otherwise = App const' (Var f v')
        -- Inner lambdas are eliminated first. The assertion checks that the
        -- inner lambda does not rebind v.
        getRidOfV l@(Lambda pat _) =
          assert (not (v `occursP` pat)) (getRidOfV (go l))
        getRidOfV (Let {}) = assert False bt
        getRidOfV e'@(App e1 e2)
          | fr1 && fr2 = App (App scomb (getRidOfV e1)) (getRidOfV e2)
          | fr1 = App (App flip' (getRidOfV e1)) e2
          -- Eta-reduce \v -> e1 v to e1 (v is not free in e1 here).
          | Var _ v' <- e2, v' == v = e1
          | fr2 = App (App comp e1) (getRidOfV e2)
          | otherwise = App const' e'
          where
            fr1 = v `isFreeIn` e1
            fr2 = v `isFreeIn` e2

    -- Rewrite \(p1 <op> p2) -> body into
    --   \patVar -> (\p1 -> \p2 -> body) (proj1 patVar) (proj2 patVar)
    destructure :: String -> String -> Pattern -> Pattern -> Expr -> Expr
    destructure proj1 proj2 p1 p2 body =
      Lambda (PVar patVar)
        (App (App (Lambda p1 (Lambda p2 body)) (project proj1)) (project proj2))
      where
        project name = App (Var Pref name) (Var Pref patVar)