-- Copyright (c) 2017 Uber Technologies, Inc. -- -- Permission is hereby granted, free of charge, to any person obtaining a copy -- of this software and associated documentation files (the "Software"), to deal -- in the Software without restriction, including without limitation the rights -- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -- copies of the Software, and to permit persons to whom the Software is -- furnished to do so, subject to the following conditions: -- -- The above copyright notice and this permission notice shall be included in -- all copies or substantial portions of the Software. -- -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -- THE SOFTWARE. 
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Collects the pairs of columns that a resolved SQL statement relates to
-- one another (through comparisons, @IN@, @LIKE@, @USING@/@NATURAL@ joins and
-- set operations), reported as pairs of fully-qualified column names.
module Database.Sql.Util.Joins (HasJoins(..), JoinsResult) where

import Database.Sql.Type

import qualified Data.Map as M
import           Data.Map (Map)
import qualified Data.Set as S
import           Data.Set (Set)

import Data.Semigroup
import Data.Functor.Identity
import Data.Foldable

import Control.Monad (void, when)
import Control.Monad.Writer (Writer, execWriter, tell)


-- | Accumulated output of the traversal.
data Result = Result
    { resultBindings :: Map ColumnAliasId (Map (RColumnRef ()) FieldChain)
      -- ^ for each column alias, the columns (with field chains) of the
      -- expression that the alias was bound to
    , resultColumns :: Set (Map (RColumnRef ()) FieldChain)
      -- ^ groups of columns observed together in a relating construct
    }

-- | Field-wise combination. A separate 'Semigroup' instance is required for
-- 'Monoid' on base >= 4.11 (GHC 8.4+).
instance Semigroup Result where
    Result bindings columns <> Result bindings' columns' =
        Result (bindings <> bindings') (columns <> columns')

instance Monoid Result where
    mempty = Result mempty mempty
    mappend = (<>)

-- Relationship observed between two columns
type Join = ((FullyQualifiedColumnName, [StructFieldName ()]), (FullyQualifiedColumnName, [StructFieldName ()]))

-- | Traversal monad: a 'Writer' accumulating a 'Result'.
type Scoped a = Writer Result a

type JoinsResult = Set Join

class HasJoins q where
    getJoins :: q -> Set Join

instance HasJoins (Statement d ResolvedNames a) where
    -- | Run the traversal, replace aliases by the columns they were bound to,
    -- and expand every emitted column group into pairs of columns that belong
    -- to different tables.
    getJoins stmt =
        let Result{..} = execWriter $ getJoinsStatement stmt

            -- Replace alias references by the columns recorded for that alias
            -- (recursively). The chain stored against an alias key is dropped,
            -- and an alias with no recorded binding contributes nothing.
            unalias :: Map (RColumnRef ()) FieldChain -> Map (FQColumnName ()) FieldChain
            unalias m = M.fromList $ M.toList m >>= \case
                (RColumnRef fqcn, chain) -> [(fqcn, chain)]
                (RColumnAlias (ColumnAlias _ _ aliasId), _) ->
                    maybe [] (M.toList . unalias) $ M.lookup aliasId resultBindings

            sets = S.map unalias resultColumns

            -- Pair the smallest-keyed column with every remaining column from
            -- a different table, then continue with the remaining columns.
            toPairs m = case M.minViewWithKey m of
                Nothing -> []
                Just ((c@(QColumnName _ (Identity table) _), chain), m') ->
                    let pairs =
                            [ ((fqcnToFQCN c, fields), (fqcnToFQCN c', fields'))
                            | (c'@(QColumnName _ (Identity table') _), chain') <- M.toList m'
                            , table /= table'
                            , fields <- expandChain chain
                            , fields' <- expandChain chain'
                            ]
                     in pairs ++ toPairs m'

         in S.fromList $ toPairs =<< S.toList sets
      where
        -- Enumerate every root-to-leaf path of field names in a chain; an
        -- empty chain yields the single empty path.
        expandChain (FieldChain m)
            | M.null m = [[]]
            | otherwise = do
                (k, v) <- M.toList m
                (k:) <$> expandChain v


-- | Dispatch on the statement kind; statements handled with @pure ()@
-- contribute no joins.
getJoinsStatement :: Statement d ResolvedNames a -> Scoped ()
getJoinsStatement (QueryStmt query) = void $ getJoinsQuery query
getJoinsStatement (InsertStmt insert) = getJoinsInsert insert
getJoinsStatement (UpdateStmt update) = getJoinsUpdate update
getJoinsStatement (DeleteStmt delete) = getJoinsDelete delete
getJoinsStatement (TruncateStmt _) = pure ()
getJoinsStatement (CreateTableStmt create) = getJoinsCreateTable create
getJoinsStatement (AlterTableStmt _) = pure ()
getJoinsStatement (DropTableStmt _) = pure ()
getJoinsStatement (CreateViewStmt create) = void $ getJoinsQuery $ createViewQuery create
getJoinsStatement (DropViewStmt _) = pure ()
getJoinsStatement (CreateSchemaStmt _) = pure ()
getJoinsStatement (GrantStmt _) = pure ()
getJoinsStatement (RevokeStmt _) = pure ()
getJoinsStatement (BeginStmt _) = pure ()
getJoinsStatement (CommitStmt _) = pure ()
getJoinsStatement (RollbackStmt _) = pure ()
getJoinsStatement (ExplainStmt _ _) = pure ()
getJoinsStatement (EmptyStmt _) = pure ()


-- | Output columns of a query. For set operations the columns of the
-- left-hand query are used; wrappers (@WITH@, @ORDER@, @LIMIT@, @OFFSET@)
-- defer to the wrapped query.
queryColumns :: Query ResolvedNames a -> [RColumnRef a]
queryColumns (QueryExcept _ _ query _) = queryColumns query
queryColumns (QueryUnion _ _ _ query _) = queryColumns query
queryColumns (QueryIntersect _ _ query _) = queryColumns query
queryColumns (QueryWith _ _ query) = queryColumns query
queryColumns (QueryOrder _ _ query) = queryColumns query
queryColumns (QueryLimit _ _ query) = queryColumns query
queryColumns (QueryOffset _ _ query) = queryColumns query
queryColumns (QuerySelect _ Select{selectCols = SelectColumns _ selections}) = selections >>= \case
    SelectExpr _ aliases _ -> map RColumnAlias aliases
    SelectStar _ _ (StarColumnNames cols) -> cols


getJoinsCreateTable :: CreateTable d ResolvedNames a -> Scoped ()
getJoinsCreateTable CreateTable{..} = getJoinsTableDefinition createTableDefinition

-- TODO - "join" for columns producing columns? Assuming no...
-- note that defaults cannot reference other tables in Vertica, possibly in other dialects

-- | Only @CREATE TABLE ... AS query@ is traversed.
getJoinsTableDefinition :: TableDefinition d ResolvedNames a -> Scoped ()
getJoinsTableDefinition (TableColumns _ _) = pure ()
getJoinsTableDefinition (TableLike _ _) = pure ()
getJoinsTableDefinition (TableAs _ _ query) = void $ getJoinsQuery query
getJoinsTableDefinition (TableNoColumnInfo _) = pure ()


getJoinsInsert :: Insert ResolvedNames a -> Scoped ()
getJoinsInsert Insert{..} = case insertValues of
    InsertDefaultValues _ -> pure ()
    InsertExprValues _ values -> mapM_ (mapM_ getJoinsDefaultExpr) values
    InsertSelectValues query -> void $ getJoinsQuery query
    InsertDataFromFile _ _ -> pure ()


getJoinsDefaultExpr :: DefaultExpr ResolvedNames a -> Scoped ()
getJoinsDefaultExpr (DefaultValue _) = pure ()
getJoinsDefaultExpr (ExprValue expr) = void $ getJoinsExpr expr


-- | Traverse the SET expressions, then the optional FROM and WHERE parts.
getJoinsUpdate :: Update ResolvedNames a -> Scoped ()
getJoinsUpdate Update{..} = do
    mapM_ (getJoinsDefaultExpr . snd) updateSetExprs
    mapM_ getJoinsTablish updateFrom
    mapM_ getJoinsExpr updateWhere


getJoinsDelete :: Delete ResolvedNames a -> Scoped ()
getJoinsDelete (Delete _ _ (Just expr)) = void $ getJoinsExpr expr
getJoinsDelete (Delete _ _ Nothing) = pure ()


-- | Emit a column group consisting of exactly the two given columns, each
-- with an empty field chain.
emitColumnPair :: RColumnRef a -> RColumnRef b -> Scoped ()
emitColumnPair lcol rcol =
    emit $ M.fromSet (const $ FieldChain M.empty) $ S.fromList [void lcol, void rcol]


-- | Relate the output columns of two queries position by position. Columns
-- without a counterpart on the other side are ignored (behaviour of 'zip').
zipColumns :: Query ResolvedNames a -> Query ResolvedNames a -> Scoped ()
zipColumns lhs rhs =
    forM_ (zip (queryColumns lhs) (queryColumns rhs)) $ uncurry emitColumnPair


-- | Traverse a query. Set operations traverse both sides and then relate
-- their output columns positionally via 'zipColumns'.
getJoinsQuery :: Query ResolvedNames a -> Scoped ()
getJoinsQuery (QuerySelect _ select) = getJoinsSelect select
getJoinsQuery (QueryExcept _ _ lhs rhs) = do
    getJoinsQuery lhs
    getJoinsQuery rhs
    zipColumns lhs rhs

getJoinsQuery (QueryUnion _ _ _ lhs rhs) = do
    getJoinsQuery lhs
    getJoinsQuery rhs
    zipColumns lhs rhs

getJoinsQuery (QueryIntersect _ _ lhs rhs) = do
    getJoinsQuery lhs
    getJoinsQuery rhs
    zipColumns lhs rhs

getJoinsQuery (QueryWith _ ctes query) = do
    mapM_ getJoinsCTE ctes
    getJoinsQuery query

getJoinsQuery (QueryOrder _ orders query) = do
    mapM_ getJoinsOrder orders
    getJoinsQuery query

getJoinsQuery (QueryLimit _ _ query) = getJoinsQuery query
getJoinsQuery (QueryOffset _ _ query) = getJoinsQuery query


-- | Traverse every clause of a SELECT; absent optional clauses are skipped.
getJoinsSelect :: Select ResolvedNames a -> Scoped ()
getJoinsSelect (Select{..}) = do
    getJoinsSelectCols selectCols
    mapM_ getJoinsSelectFrom selectFrom
    mapM_ getJoinsSelectWhere selectWhere
    mapM_ getJoinsSelectTimeseries selectTimeseries
    mapM_ getJoinsSelectGroup selectGroup
    mapM_ getJoinsSelectHaving selectHaving
    mapM_ getJoinsSelectNamedWindow selectNamedWindow


getJoinsSelectFrom :: SelectFrom ResolvedNames a -> Scoped ()
getJoinsSelectFrom (SelectFrom _ tablishes) = mapM_ getJoinsTablish tablishes


getJoinsSelectCols :: SelectColumns ResolvedNames a -> Scoped ()
getJoinsSelectCols (SelectColumns _ selections) = mapM_ getJoinsSelection selections


getJoinsSelectWhere :: SelectWhere ResolvedNames a -> Scoped ()
getJoinsSelectWhere (SelectWhere _ expr) = void $ getJoinsExpr expr


getJoinsSelectTimeseries :: SelectTimeseries ResolvedNames a -> Scoped ()
getJoinsSelectTimeseries (SelectTimeseries _ _ _ partition expr) = do
    mapM_ getJoinsPartition partition
    void $ getJoinsExpr expr


-- | Positional references carry no expression and are skipped.
getJoinsPositionOrExpr :: PositionOrExpr ResolvedNames a -> Scoped ()
getJoinsPositionOrExpr (PositionOrExprPosition _ _ _) = pure ()
getJoinsPositionOrExpr (PositionOrExprExpr expr) = void $ getJoinsExpr expr


getJoinsGroupingElement :: GroupingElement ResolvedNames a -> Scoped ()
getJoinsGroupingElement (GroupingElementExpr _ posOrExpr) = getJoinsPositionOrExpr posOrExpr
getJoinsGroupingElement (GroupingElementSet _ exprs) = mapM_ getJoinsExpr exprs


getJoinsSelectGroup :: SelectGroup ResolvedNames a -> Scoped ()
getJoinsSelectGroup (SelectGroup _ groupingElements) = mapM_ getJoinsGroupingElement groupingElements


getJoinsSelectHaving :: SelectHaving ResolvedNames a -> Scoped ()
getJoinsSelectHaving (SelectHaving _ exprs) = mapM_ getJoinsExpr exprs


getJoinsSelectNamedWindow :: SelectNamedWindow ResolvedNames a -> Scoped ()
getJoinsSelectNamedWindow (SelectNamedWindow _ windows) = mapM_ joins windows
  where
    joins (NamedWindowExpr _ _ windowExpr) = getJoinsWindowExpr windowExpr
    joins (NamedPartialWindowExpr _ _ partialWindowExpr) = getJoinsPartialWindowExpr partialWindowExpr


-- | Record a group of columns as related to one another.
emit :: Map (RColumnRef ()) FieldChain -> Scoped ()
emit cols = tell $ mempty { resultColumns = S.singleton cols }


-- | Record the columns that a column alias stands for.
bind :: ColumnAliasId -> Map (RColumnRef ()) FieldChain -> Scoped ()
bind alias cols = tell $ mempty { resultBindings = M.singleton alias cols }


-- | Traverse an expression, emitting column groups for relating constructs,
-- and return the columns (with accessed field chains) that the expression
-- refers to.
getJoinsExpr :: Expr ResolvedNames a -> Scoped (Map (RColumnRef ()) FieldChain)
getJoinsExpr (BinOpExpr _ op lhs rhs) = do
    lcols <- getJoinsExpr lhs
    rcols <- getJoinsExpr rhs
    let allcols = M.unionWith (<>) lcols rcols
    -- only the listed comparison operators relate their operands
    when (op `elem` ["=", "!=", "<>", "<=>", "==", "<", ">", "<=", ">="]) $ do
        emit allcols
    return allcols

getJoinsExpr (CaseExpr _ cases else_) = do
    -- conditions are traversed for their emitted joins only; the result
    -- columns come from the THEN and ELSE branches
    cols <- mapM (\ (when_, then_) -> getJoinsExpr when_ *> getJoinsExpr then_) cases
    col <- maybe (pure M.empty) getJoinsExpr else_
    return $ M.unionsWith (<>) $ col : cols

getJoinsExpr (LikeExpr _ _ escape pattern expr) = do
    void $ maybe (pure mempty) (getJoinsExpr . escapeExpr) escape
    lcols <- getJoinsExpr $ patternExpr pattern
    rcols <- getJoinsExpr expr
    let allcols = M.unionWith (<>) lcols rcols
    emit allcols
    return allcols

getJoinsExpr (UnOpExpr _ _ expr) = getJoinsExpr expr
getJoinsExpr (ConstantExpr _ _) = return M.empty
getJoinsExpr (ColumnExpr _ column) = return $ M.singleton (void column) $ FieldChain M.empty
getJoinsExpr (InListExpr _ exprs expr) = do
    cols <- M.unionsWith (<>) <$> mapM getJoinsExpr (expr:exprs)
    emit cols
    return cols

getJoinsExpr (InSubqueryExpr _ query expr) = do
    getJoinsQuery query
    columns <- getJoinsExpr expr
    -- The subquery's column is related to the tested expression only when the
    -- subquery projects exactly one column; any other arity adds nothing
    -- instead of failing on a partial pattern match.
    let columns' = case queryColumns query of
            [column] -> M.insert (void column) (FieldChain M.empty) columns
            _ -> columns
    emit columns'
    return columns'

getJoinsExpr (BetweenExpr _ expr start end) =
    M.unionsWith (<>) <$> mapM getJoinsExpr [expr, start, end]

getJoinsExpr (OverlapsExpr _ (r1start, r1end) (r2start, r2end)) =
    M.unionsWith (<>) <$> mapM getJoinsExpr [r1start, r1end, r2start, r2end]

getJoinsExpr (FunctionExpr _ _ _ args params mFilter mOver) = do
    cols <- M.unionsWith (<>) <$> mapM getJoinsExpr (args ++ map snd params)
    mapM_ getJoinsFilter mFilter
    mapM_ getJoinsOverSubExpr mOver
    return cols

getJoinsExpr (AtTimeZoneExpr _ ts tz) = M.unionWith (<>) <$> getJoinsExpr ts <*> getJoinsExpr tz
getJoinsExpr (SubqueryExpr _ query) = do
    getJoinsQuery query
    -- A subquery stands for its column only when it projects exactly one;
    -- any other arity yields no columns instead of failing on a partial
    -- pattern match.
    pure $ case queryColumns query of
        [column] -> M.singleton (void column) $ FieldChain M.empty
        _ -> M.empty

getJoinsExpr (ExistsExpr _ query) = do
    _ <- getJoinsQuery query
    return M.empty

getJoinsExpr (ArrayExpr _ values) = M.unionsWith (<>) <$> mapM getJoinsExpr values
getJoinsExpr (FieldAccessExpr _ expr field) =
    go expr $ FieldChain $ M.singleton (void field) $ FieldChain M.empty
  where
    -- Walk down nested field accesses, building the chain outermost-first.
    -- The chain is kept only when the base is a direct column reference; for
    -- any other base it is discarded and the base is traversed normally.
    go (ColumnExpr _ ref@(RColumnRef _)) chain = return $ M.singleton (void ref) chain
    go (FieldAccessExpr _ expr' field') chain = go expr' $ FieldChain $ M.singleton (void field') chain
    go expr' _ = getJoinsExpr expr'

getJoinsExpr (ArrayAccessExpr _ expr index) = M.unionsWith (<>) <$> mapM getJoinsExpr [expr, index]
getJoinsExpr (TypeCastExpr _ _ expr _) = getJoinsExpr expr
getJoinsExpr (VariableSubstitutionExpr _) = return M.empty


getJoinsFilter :: Filter ResolvedNames a -> Scoped ()
getJoinsFilter (Filter _ expr) = void $ getJoinsExpr expr


-- | A reference to a window by name is not traversed here.
getJoinsOverSubExpr :: OverSubExpr ResolvedNames a -> Scoped ()
getJoinsOverSubExpr (OverWindowExpr _ windowExpr) = getJoinsWindowExpr windowExpr
getJoinsOverSubExpr (OverWindowName _ _) = pure ()
getJoinsOverSubExpr (OverPartialWindowExpr _ partial) = getJoinsPartialWindowExpr partial


getJoinsWindowExpr :: WindowExpr ResolvedNames a -> Scoped ()
getJoinsWindowExpr (WindowExpr _ p os _) = do
    mapM_ getJoinsPartition p
    mapM_ getJoinsOrder os


getJoinsPartialWindowExpr :: PartialWindowExpr ResolvedNames a -> Scoped ()
getJoinsPartialWindowExpr (PartialWindowExpr _ _ p os _) = do
    mapM_ getJoinsPartition p
    mapM_ getJoinsOrder os


getJoinsPartition :: Partition ResolvedNames a -> Scoped ()
getJoinsPartition (PartitionBy _ es) = mapM_ getJoinsExpr es
getJoinsPartition (PartitionBest _) = return ()
getJoinsPartition (PartitionNodes _) = return ()


getJoinsOrder :: Order ResolvedNames a -> Scoped ()
getJoinsOrder (Order _ posOrExpr _ _) = void $ getJoinsPositionOrExpr posOrExpr


-- | Traverse a FROM-clause item. @NATURAL@ and @USING@ joins relate each
-- resolved column pair; @ON@ joins are handled by traversing the condition.
getJoinsTablish :: Tablish ResolvedNames a -> Scoped ()
getJoinsTablish (TablishTable _ _ _) = pure ()

getJoinsTablish (TablishLateralView _ LateralView{..} lhs) = do
    mapM_ getJoinsTablish lhs
    mapM_ getJoinsExpr lateralViewExprs

getJoinsTablish (TablishSubQuery _ _ query) = getJoinsQuery query

getJoinsTablish (TablishJoin _ _ (JoinNatural _ (RNaturalColumns columns)) lhs rhs) = do
    getJoinsTablish lhs
    getJoinsTablish rhs
    forM_ columns $ \ (RUsingColumn lcol rcol) -> emitColumnPair lcol rcol

getJoinsTablish (TablishJoin _ _ (JoinOn expr) lhs rhs) = do
    getJoinsTablish lhs
    getJoinsTablish rhs
    void $ getJoinsExpr expr

getJoinsTablish (TablishJoin _ _ (JoinUsing _ columns) lhs rhs) = do
    getJoinsTablish lhs
    getJoinsTablish rhs
    forM_ columns $ \ (RUsingColumn lcol rcol) -> emitColumnPair lcol rcol


getJoinsCTE :: CTE ResolvedNames a -> Scoped ()
getJoinsCTE (CTE _ _ _ query) = getJoinsQuery query


-- | Traverse a selected expression and bind each of its aliases to the
-- columns the expression refers to. Star selections bind nothing.
getJoinsSelection :: Selection ResolvedNames a -> Scoped ()
getJoinsSelection (SelectStar _ _ _) = pure ()
getJoinsSelection (SelectExpr _ aliases expr) = do
    cols <- getJoinsExpr expr
    forM_ aliases $ \ (ColumnAlias _ _ aliasId) -> bind aliasId cols