{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns    #-}

module Ide.Plugin.CallHierarchy.Query (
  incomingCalls
, outgoingCalls
, getSymbolPosition
) where

import qualified Data.Text                      as T
import           Database.SQLite.Simple
import           Development.IDE.GHC.Compat
import           HieDb                          (HieDb (getConn), Symbol (..))
import           Ide.Plugin.CallHierarchy.Types
import           Prelude                        hiding (mod)

-- | Find the call sites of a symbol, together with the declarations that
-- enclose them.
--
-- Each returned 'Vertex' is built from one result row:
--
--   * the enclosing declaration's module, occurrence name, source path and
--     span (@mods.mod@, @decls.occ@, @mods.hs_src@, @decls.sl..decls.ec@);
--   * followed by the span of the reference to @symbol@ inside it
--     (@refs.sl..refs.ec@).
--
-- A reference is considered to lie inside a declaration when both come from
-- the same HIE file and the reference starts strictly after the declaration's
-- start and ends no later than the declaration's end.
incomingCalls :: HieDb -> Symbol -> IO [Vertex]
incomingCalls (getConn -> conn) symbol = do
    let (o, m, u) = parseSymbol symbol
    query conn
        (Query $ T.pack $ concat
            [ "SELECT mods.mod, decls.occ, mods.hs_src, decls.sl, decls.sc, "
            , "decls.el, decls.ec, refs.sl, refs.sc, refs.el, refs.ec "
            , "FROM refs "
            , "JOIN decls ON decls.hieFile = refs.hieFile "
            , "JOIN mods ON mods.hieFile = decls.hieFile "
            , "where "
            -- Only references to the requested symbol.
            , "(refs.occ = ? AND refs.mod = ? AND refs.unit = ?) "
            , "AND "
            -- Drop the symbol's own declaration, so references it makes to
            -- itself are not reported as incoming calls.
            , "(decls.occ != ? OR mods.mod != ? OR mods.unit != ?) "
            , "AND "
            -- The reference starts after the declaration starts ...
            , "((refs.sl = decls.sl AND refs.sc > decls.sc) OR (refs.sl > decls.sl)) "
            , "AND "
            -- ... and ends no later than the declaration ends.
            ,"((refs.el = decls.el AND refs.ec <= decls.ec) OR (refs.el < decls.el))"
            ]
        )
        -- The first triple binds the reference filter.
        -- The second triple binds the self-exclusion filter.
        (o, m, u, o, m, u)

-- | Find the symbols that are referenced from within the declaration of the
-- given symbol.
--
-- The query locates the declaration of @symbol@ (@decls@/@mods@ matched on
-- occurrence name, module and unit). It collects the references (@refs@) in
-- the same HIE file that lie within that declaration's span. Each reference is
-- then resolved to a definition (@defs@) with the same occurrence name.
-- The definition must also:
--
--   * have a matching @decls@ row (@rd@) in the definition's HIE file; and
--   * sit in a module (@rm@) whose module/unit equal the reference's and
--     whose HIE file is the definition's.
--
-- Each returned 'Vertex' is built from the callee's module, occurrence name,
-- source path and definition span, followed by the span of the reference.
-- Rows whose callee is @symbol@ itself are excluded.
outgoingCalls :: HieDb -> Symbol -> IO [Vertex]
outgoingCalls (getConn -> conn) symbol = do
    let (o, m, u) = parseSymbol symbol
    query conn
        (Query $ T.pack $ concat
            [ "SELECT rm.mod, defs.occ, rm.hs_src, defs.sl, defs.sc, defs.el, defs.ec, "
            , "refs.sl, refs.sc, refs.el, refs.ec "
            , "from refs "
            , "JOIN defs ON defs.occ = refs.occ "
            , "JOIN decls rd ON rd.hieFile = defs.hieFile AND rd.occ = defs.occ "
            , "JOIN mods rm ON rm.mod = refs.mod AND rm.unit = refs.unit AND rm.hieFile = defs.hieFile "
            , "JOIN decls ON decls.hieFile = refs.hieFile "
            , "JOIN mods ON mods.hieFile = decls.hieFile "
            , "where "
            -- Restrict the enclosing declaration to the requested symbol.
            , "(decls.occ = ? AND mods.mod = ? AND mods.unit = ?) "
            , "AND "
            -- Do not report the symbol as calling itself.
            , "(defs.occ != ? OR rm.mod != ? OR rm.unit != ?) "
            , "AND "
            -- The reference starts after the declaration starts ...
            , "((refs.sl = decls.sl AND refs.sc >  decls.sc) OR (refs.sl > decls.sl)) "
            , "AND "
            -- ... and ends no later than the declaration ends.
            , "((refs.el = decls.el AND refs.ec <= decls.ec) OR (refs.el < decls.el))"
            ]
        )
        -- The first triple selects the caller's declaration.
        -- The second triple binds the self-exclusion filter.
        (o, m, u, o, m, u)

-- | Look up the start positions (line, column) of references whose
-- occurrence name equals the vertex's @occ@ and that lie inside the vertex's
-- span.
--
-- "Inside" uses the same test as the call queries in this module. The
-- reference starts strictly after @(sl, sc)@ and ends at or before
-- @(el, ec)@.
--
-- NOTE(review): the query filters only on @occ@ and the line/column span. It
-- does not restrict by HIE file or by @hieSrc@, so references in other files
-- whose coordinates fall in the same range may also be returned. Confirm
-- this is intended.
getSymbolPosition :: HieDb -> Vertex -> IO [SymbolPosition]
getSymbolPosition (getConn -> conn) Vertex{..} = do
    query conn
        (Query $ T.pack $ concat
            [ "SELECT refs.sl, refs.sc from refs where "
            , "(occ = ?) "
            , "AND "
            , "((refs.sl = ? AND refs.sc > ?) OR (refs.sl > ?)) "
            , "AND "
            , "((refs.el = ? AND refs.ec <= ?) OR (refs.el < ?))"
            ]
        )
        -- The parameter order mirrors the placeholders above:
        -- occ, then start (sl, sc, sl), then end (el, ec, el).
        (occ, sl, sc, sl, el, ec, el)

-- | Split a hiedb 'Symbol' into its occurrence name, module name and unit.
-- These are the three values the call-hierarchy SQL queries bind to
-- identify a symbol.
parseSymbol :: Symbol -> (OccName, ModuleName, Unit)
parseSymbol Symbol{..} =
    (symName, moduleName symModule, moduleUnit symModule)