module Kempe.IR.Opt ( optimize
                    ) where

import           Kempe.IR.Type

-- | Run the IR peephole passes over a list of statements.
--
-- Passes are applied innermost first: 'removeNop', then 'successiveBumps'
-- twice (so that bumps left adjacent by the first run get another chance to
-- combine), and finally 'sameTarget'.
optimize :: [Stmt] -> [Stmt]
optimize stmts =
    sameTarget (successiveBumps (successiveBumps (removeNop stmts)))

-- | Often IR generation will leave us with something like
--
-- > (movtemp datapointer (+ (reg datapointer) (int 8)))
-- > (movtemp datapointer (- (reg datapointer) (int 8)))
--
-- i.e. push a value and immediately pop it for use.
--
-- This is silly and we remove it in this pass. More generally, any two
-- adjacent statements that each bump the data pointer by a constant are
-- folded into a single bump by their net offset, or dropped entirely when the
-- offsets cancel. The folded bump is re-examined, so a whole run of bumps
-- collapses in one pass.
--
-- Also take the opportunity to simplify stuff like
--
-- > (movmem (- (reg datapointer) (int 8)) (mem [8] (- (reg datapointer) (int 0))))
-- > (movmem (- (reg datapointer) (int 0)) (mem [8] (- (reg datapointer) (int 8))))
--
-- where the second move copies back the value the first one just stored, so
-- only the first is kept.
successiveBumps :: [Stmt] -> [Stmt]
successiveBumps = go
    where
        go (s0 : s1 : ss)
            | Just i <- dpOffset s0
            , Just i' <- dpOffset s1 =
                -- Fold both bumps into one by their signed net offset.
                -- @+i@ followed by @-i'@ moves the pointer by @i - i'@
                -- (not @-(i - i')@), and symmetrically for @-i@ then @+i'@.
                let net = i + i'
                in if net == 0
                    then go ss
                    -- re-examine: the merged bump may combine with the next one
                    else go (bump net : ss)
        go (st@(MovMem e0 k (Mem 8 e1)) : MovMem e0' k' (Mem 8 e1') : ss)
            -- After the first move, e0 already holds the value at e1, so
            -- moving e0 back into e1 changes nothing.
            -- NOTE(review): this assumes e0 and e1 do not partially overlap.
            | k == k' && e0 == e1' && e1 == e0' = st : go ss
        go (s : ss) = s : go ss
        go [] = []

        -- Signed constant offset by which a statement bumps the data
        -- pointer, or 'Nothing' if it is not such a bump.
        dpOffset (MovTemp DataPointer (ExprIntBinOp IntPlusIR (Reg DataPointer) (ConstInt i))) =
            Just i
        dpOffset (MovTemp DataPointer (ExprIntBinOp IntMinusIR (Reg DataPointer) (ConstInt i))) =
            Just (negate i)
        dpOffset _ = Nothing

        -- Bump the data pointer by a signed offset, always emitting a
        -- non-negative constant.
        bump n
            | n >= 0 =
                MovTemp DataPointer (ExprIntBinOp IntPlusIR (Reg DataPointer) (ConstInt n))
            | otherwise =
                MovTemp DataPointer (ExprIntBinOp IntMinusIR (Reg DataPointer) (ConstInt (negate n)))

-- | Stuff like
--
-- > (movmem (- (reg datapointer) (int 8)) (mem [8] (- (reg datapointer) (int 0))))
-- > (movmem (- (reg datapointer) (int 8)) (mem [8] (- (reg datapointer) (int 16))))
--
-- Basically if two successive 'Stmt's write to the same location (same address
-- expression, same size), only bother with the second one. The surviving store
-- is re-examined, so a whole run of stores to one location collapses to the
-- last one.
--
-- The first store is kept when the second store's source may read the
-- location being written, e.g.
--
-- > (movmem a (int 1))
-- > (movmem a (+ (mem [8] a) (int 1)))
--
-- since dropping it would change the value the second store computes.
sameTarget :: [Stmt] -> [Stmt]
sameTarget (MovMem e0 k _ : st@(MovMem e0' k' src') : ss)
    | k == k' && e0 == e0' && not (mayRead e0 src') = sameTarget (st : ss)
  where
    -- Conservative syntactic check: could evaluating the expression read the
    -- memory at address @loc@? Constructors not matched below are assumed to
    -- possibly read it.
    -- NOTE(review): syntactically different addresses are treated as
    -- non-overlapping, which ignores partial aliasing (e.g. an 8-byte load
    -- 4 bytes away from @loc@).
    mayRead loc (Mem _ addr)         = addr == loc || mayRead loc addr
    mayRead loc (ExprIntBinOp _ a b) = mayRead loc a || mayRead loc b
    mayRead _   (Reg _)              = False
    mayRead _   (ConstInt _)         = False
    mayRead _   _                    = True
sameTarget (s : ss) = s : sameTarget ss
sameTarget [] = []

-- | Drop statements that have no effect:
--
-- * moving a temporary into itself, directly or plus/minus the constant 0;
--
-- * moving a memory location onto itself with a load of the same size as the
--   store.
removeNop :: [Stmt] -> [Stmt]
removeNop = filter (not . isNop)
    where
        isNop :: Stmt -> Bool
        isNop (MovTemp t (Reg t')) = t == t'
        isNop (MovTemp t (ExprIntBinOp IntPlusIR (Reg t') (ConstInt 0))) = t == t'
        isNop (MovTemp t (ExprIntBinOp IntMinusIR (Reg t') (ConstInt 0))) = t == t'
        -- The Eq on Exp is kinda weird, but if the syntax trees are the same
        -- then they're certainly equivalent semantically. The two sizes must
        -- agree as well.
        -- NOTE(review): assumes MovMem's size is the store width and Mem's is
        -- the load width. With differing widths the move is not obviously a
        -- no-op, so it is kept.
        isNop (MovMem e k (Mem k' e')) = k == k' && e == e'
        isNop _ = False