-- | Transfers the tables described by a transfer plan from a source database
-- to a destination database through the SQL/CLI (ODBC) interface, either in
-- the calling thread or using a pool of worker threads.
module TransferDB where

import Prelude hiding (fail, log)

import System.IO (hPutStr, hPutStrLn, stderr, putStrLn, hFlush, stdout)

import Control.Concurrent (forkIO, forkFinally, ThreadId)
import Control.Concurrent.STM (TVar, newTVar, modifyTVar, modifyTVar', readTVar, writeTVar,
                               TQueue, newTQueue, readTQueue, writeTQueue,
                               STM, atomically, check, orElse, retry)

import Control.Monad (replicateM_, join, when)
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask, asks, withReaderT)
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Fail (MonadFail, fail)

import Control.Logging (log, withStderrLogging)

import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Ptr (Ptr, castPtr, nullPtr, wordPtrToPtr, ptrToWordPtr)
import Foreign.Storable (peek, poke)

import Data.String(IsString(fromString))
import Data.Char (toLower, toUpper)
import Data.List (find)
import Data.Maybe (fromMaybe)

import SQL.CLI (SQLHDBC, SQLSMALLINT, SQLINTEGER, SQLPOINTER, SQLULEN, SQLLEN,
                sql_handle_dbc, sql_handle_stmt, sql_null_data, sql_all_types,
                sql_default, sql_char, sql_datetime,
                sql_code_date, sql_code_time, sql_code_timestamp,
                sql_type_date, sql_type_time, sql_type_timestamp,
                sql_varchar, sql_attr_imp_row_desc, sql_data_at_exec,
                sql_desc_type, sql_desc_length, sql_desc_octet_length,
                sql_desc_precision, sql_desc_scale, sql_desc_datetime_interval_code,
                sql_commit, sql_rollback)
import SQL.ODBC (sql_longvarchar, sql_binary, sql_longvarbinary, sql_varbinary,
                 sql_interval, sql_attr_autocommit, sql_autocommit_off)
import SQL.CLI.Utils (SQLConfig,
                      ColumnInfo (ci_ColumnName, ci_TableSchem, ci_TableName),
                      allocHandle, freeHandle, tableExists, forAllRecords,
                      collectColumnsInfo, execDirect, prepare, execute,
                      getStorableStmtAttr, numResultCols, paramData, forAllData,
                      putData, getDescField, bindParam, setConnectAttr, endTran,
                      toCLIType)

import Data.List (intercalate)

import TransferPlan (Batch, BatchItem, ColumnName, WhereCondition,
                     scope_Db, scope_UserName, scope_Password, scope_Schema,
                     plan_Source, plan_Destination, plan_Batches,
                     batch_Name, batch_Items, batch_Table, batch_Where, batch_OrderBy)
import Options (Options(option_Plan, option_Threads, option_Count, option_Drop))

import Database.TransferDB.Commons (finally, withConnection)

-- | the size of the buffer used to transfer data
transferBufferSize :: SQLLEN
transferBufferSize = 8192 * 8

-- | The environment of a transfer: the command line options together with
-- the configuration of the CLI implementation in use.
data TransferOptions = TransferOptions {
  to_Options :: Options,
  to_SQLConfig :: SQLConfig}

-- | Selects the batches that will actually be transferred: skips the first
-- @dropCount@ batches and then keeps at most @keepCount@ of the remaining
-- ones. A @keepCount@ that is not positive means "keep all".
selectBatches :: Int -> Int -> [a] -> [a]
selectBatches dropCount keepCount batches
  | keepCount > 0 = take keepCount remaining
  | otherwise     = remaining
  where
    -- 'drop' already returns the whole list for non positive counts
    remaining = drop dropCount batches

-- | Entry point of the transfer. Depending on the requested number of
-- threads, it either transfers all the selected batches sequentially on a
-- single pair of connections, or it starts worker threads that consume the
-- batches from a queue.
transferDB :: ReaderT TransferOptions (MaybeT IO) ()
transferDB = withStderrLogging $ do
  threads <- asks (option_Threads . to_Options)
  -- a non positive number of threads can not run any worker, so it is
  -- handled as single thread mode instead of silently transferring nothing
  if threads <= 0
    then do
      liftIO $ log $ fromString "Transfering database in single thread mode"
      schema1 <- asks (scope_Schema . plan_Source . option_Plan . to_Options)
      schema2 <- asks (scope_Schema . plan_Destination . option_Plan . to_Options)
      batches <- asks (plan_Batches . option_Plan . to_Options)
      sqlConfig <- asks to_SQLConfig
      dropBatches <- asks (option_Drop . to_Options)
      countBatches <- asks (option_Count . to_Options)
      let batches'' = selectBatches dropBatches countBatches batches
      liftIO $ log $ fromString $ "Number of batches: " ++ (show $ length batches'') ++ "/" ++ (show $ length batches)
      result <- withSourceAndDest (\ srcDBC dstDBC ->
        liftIO $ allocaBytes (fromIntegral transferBufferSize) (\ p_transfer_buf ->
          alloca (\ p_transfer_len_or_ind ->
            runMaybeT $ runReaderT
              (mapM_ (transferBatch srcDBC schema1 dstDBC schema2 p_transfer_buf p_transfer_len_or_ind) batches'')
              sqlConfig)))
      maybe (fail "Transfer tables failed") return result
    else do
      liftIO $ log $ fromString $ "Transfering database using " ++ (show threads) ++ " threads"
      batchChan <- liftIO $ atomically newTQueue
      allBatchesPublished <- liftIO $ atomically $ newTVar False
      workerThreads <- liftIO $ atomically $ newTVar 0
      withReaderT (ThreadedTransferSetup batchChan allBatchesPublished workerThreads) $ do
        startWorkerThreads
        publishBatches
        waitForWorkToEnd

-- | Runs an action with 2 connections, to the source and destination dbs.
-- Autocommit is switched off on both connections before the action runs.
withSourceAndDest :: (MonadIO m) => (SQLHDBC -> SQLHDBC -> ReaderT TransferOptions (MaybeT m) a) -> ReaderT TransferOptions (MaybeT m) a
withSourceAndDest f = do
  s1 <- asks (scope_Db . plan_Source . option_Plan . to_Options)
  u1 <- asks (scope_UserName . plan_Source . option_Plan . to_Options)
  p1 <- asks (scope_Password . plan_Source . option_Plan . to_Options)
  s2 <- asks (scope_Db . plan_Destination . option_Plan . to_Options)
  u2 <- asks (scope_UserName . plan_Destination . option_Plan . to_Options)
  p2 <- asks (scope_Password . plan_Destination . option_Plan . to_Options)
  withConnection s1 u1 p1 (\ _ srcDBC -> do
    setConnectAttr srcDBC sql_attr_autocommit (wordPtrToPtr sql_autocommit_off) 0
    withConnection s2 u2 p2 (\ _ dstDBC -> do
      setConnectAttr dstDBC sql_attr_autocommit (wordPtrToPtr sql_autocommit_off) 0
      f srcDBC dstDBC))

-- | Executes one transfer batch. It gets the source db connection, the source schema, the destination db connection, the
-- destination schema, the pointers to allocated
-- buffers for receiving the data and length/indicator value, the batch and returns a 'ReaderT' action expecting
-- the 'SQLConfig' of the current CLI implementation.
--
-- A table is transferred only if it exists in the source schema and a table with
-- the same name (as given, lower case or upper case) exists in the destination schema.
transferBatch :: SQLHDBC -> String -> SQLHDBC -> String -> SQLPOINTER -> Ptr SQLLEN -> Batch -> ReaderT SQLConfig (MaybeT IO) ()
transferBatch srcDBC schema1 dstDBC schema2 p_transfer_buf p_transfer_len_or_ind batch =
  let transfer_table' :: BatchItem -> ReaderT SQLConfig (MaybeT IO) ()
      transfer_table' batchItem = do
        let tableName = batch_Table batchItem
            whereCondition = batch_Where batchItem
            orderBy = batch_OrderBy batchItem
            -- destination table name candidates, in the order they are probed
            candidates = [tableName, map toLower tableName, map toUpper tableName]
        srcExists <- tableExists srcDBC schema1 tableName
        destTableName <- findM (lift . (tableExists dstDBC schema2)) candidates
        liftIO $ log $ fromString $ "------> Source table " ++ tableName ++ (if srcExists then " exists" else " does not exist")
        liftIO $ log $ fromString $ "------> Source table " ++ tableName ++ maybe " has no destination." (" has destination: " ++) destTableName
        case destTableName of
          Just destTableName' | srcExists ->
            transferTable srcDBC dstDBC (batch_Name batch) schema1 tableName schema2 destTableName' orderBy whereCondition p_transfer_buf p_transfer_len_or_ind
          _ -> return ()
  in do
    liftIO $ log $ fromString "---------------------------------"
    liftIO $ log $ fromString $ "Batch: " ++ (batch_Name batch)
    liftIO $ log $ fromString "---------------------------------"
    mapM_ transfer_table' $ batch_Items batch

-- | The environment shared by the publisher and the worker threads in
-- multi-threaded mode.
data ThreadedTransferSetup = ThreadedTransferSetup {
  threaded_BatchChan :: TQueue Batch,          -- ^ the queue for sending the batch information to the worker threads
  threaded_AllBatchesPublished :: TVar Bool,   -- ^ if false, there still are more batches to be published
  threaded_WorkerThreads :: TVar Int,          -- ^ the number of active worker threads
  threaded_TransferOptions :: TransferOptions  -- ^ the transfer options, read from the command line
  }

-- | Create the requested number of worker threads; the requested number of threads is read from the
-- 'threaded_TransferOptions' field of the environment.
--
-- Each worker thread will evaluate the worker function ('transferBatch') for each batch read from the 'threaded_BatchChan' queue.
startWorkerThreads :: (MonadIO m) => ReaderT ThreadedTransferSetup (MaybeT m) ()
startWorkerThreads = do
  threads <- asks (option_Threads . to_Options . threaded_TransferOptions)
  replicateM_ threads runWorkerThread

-- | Start a worker thread responsible with evaluating the worker action function for each
-- batch read from the 'threaded_BatchChan'. It exits when all batches have been published and
-- the 'threaded_BatchChan' queue is empty.
--
-- The worker is registered in 'threaded_WorkerThreads' by the calling thread,
-- /before/ the fork: if the worker registered itself, 'waitForWorkToEnd' could
-- observe a zero count before the worker got the chance to run and return
-- while batches are still waiting in the queue. The worker is unregistered
-- when it terminates, even if it terminates with an exception, so the waiting
-- thread is never left blocked.
runWorkerThread :: (MonadIO m) => ReaderT ThreadedTransferSetup (MaybeT m) ThreadId
runWorkerThread = do
  workerThreadsVar <- asks threaded_WorkerThreads
  setup <- ask
  liftIO $ do
    atomically $ modifyTVar' workerThreadsVar (+ 1)
    forkFinally
      (do log $ fromString "Thread started"
          result <- runMaybeT $ runReaderT processBatches setup
          log $ fromString $ maybe "Thread failed" (const "Thread ended") result)
      (\ outcome -> do
          -- 'forkFinally' hands over the exception instead of printing it
          either (\ e -> log $ fromString $ "Thread failed with exception: " ++ show e) return outcome
          atomically $ modifyTVar' workerThreadsVar (subtract 1))

-- | implements the action that runs in a thread, processing batches enqueued
-- in a queue, one by one. Each worker opens its own pair of connections and
-- allocates its own transfer buffers.
processBatches :: ReaderT ThreadedTransferSetup (MaybeT IO) ()
processBatches = do
  queue <- asks threaded_BatchChan
  allBatchesPublishedVar <- asks threaded_AllBatchesPublished
  sqlConfig <- asks (to_SQLConfig . threaded_TransferOptions)
  schema1 <- asks (scope_Schema . plan_Source . option_Plan . to_Options . threaded_TransferOptions)
  schema2 <- asks (scope_Schema . plan_Destination . option_Plan . to_Options . threaded_TransferOptions)
  withReaderT threaded_TransferOptions $ withSourceAndDest (\ srcDBC dstDBC ->
    liftIO $ allocaBytes (fromIntegral transferBufferSize) (\ p_transfer_buf ->
      alloca (\ p_transfer_len_or_ind ->
        let -- take the next batch or, if the queue is empty and nothing
            -- more will be published, stop; otherwise block until one of
            -- the two becomes possible
            processBatches' :: IO ()
            processBatches' = join $ atomically $ processNextBatch `orElse` checkForEnd

            processNextBatch :: STM (IO ())
            processNextBatch = do
              batch <- readTQueue queue
              return $ do
                result <- runMaybeT $ runReaderT (transferBatch srcDBC schema1 dstDBC schema2 p_transfer_buf p_transfer_len_or_ind batch) sqlConfig
                maybe (log $ fromString $ "batch " ++ (batch_Name batch) ++ " failed")
                      (\ _ -> log $ fromString $ "batch " ++ (batch_Name batch) ++ " finished")
                      result
                -- a failed batch does not stop the worker
                processBatches'

            checkForEnd :: STM (IO ())
            checkForEnd = do
              allBatchesPublished <- readTVar allBatchesPublishedVar
              check allBatchesPublished
              return $ return ()
        in processBatches')))

-- | publish all batches to 'threaded_BatchChan' queue and then mark the
-- publishing as finished
publishBatches :: (MonadIO m) => ReaderT ThreadedTransferSetup m ()
publishBatches = do
  batches <- asks (plan_Batches . option_Plan . to_Options . threaded_TransferOptions)
  dropBatches <- asks (option_Drop . to_Options . threaded_TransferOptions)
  countBatches <- asks (option_Count . to_Options . threaded_TransferOptions)
  queue <- asks threaded_BatchChan
  allBatchesPublishedVar <- asks threaded_AllBatchesPublished
  let batches'' = selectBatches dropBatches countBatches batches
  liftIO $ log $ fromString $ "Number of batches: " ++ (show $ length batches'') ++ "/" ++ (show $ length batches)
  liftIO $ mapM_ (atomically . (writeTQueue queue)) batches''
  liftIO $ atomically $ writeTVar allBatchesPublishedVar True

-- | wait for worker threads to complete work
waitForWorkToEnd :: (MonadIO m) => ReaderT ThreadedTransferSetup m ()
waitForWorkToEnd = do
  allBatchesPublishedVar <- asks threaded_AllBatchesPublished
  workerThreadsVar <- asks threaded_WorkerThreads
  liftIO $ atomically $ do
    allBatchesPublished <- readTVar allBatchesPublishedVar
    workerThreads <- readTVar workerThreadsVar
    check (workerThreads <= 0 && allBatchesPublished) -- don't stop if there still are batches to be published

-- | Transfer data from one table in one database to another table in
-- another database. If all goes well, a line with the batch name, source table name,
-- destination table name, number of records, total transferred size and where condition,
-- separated by comma is displayed on standard out.
--
-- The transfer is committed on both connections when it succeeds; if binding
-- the parameters or transferring the rows fails, both connections are rolled back
-- and the action fails.
transferTable :: SQLHDBC               -- ^ source connection handler
              -> SQLHDBC               -- ^ destination connection handler
              -> String                -- ^ batch name
              -> String                -- ^ source schema name
              -> String                -- ^ source table name
              -> String                -- ^ destination schema name
              -> String                -- ^ destination table name
              -> Maybe [ColumnName]    -- ^ optional order by fields; it should either be Nothing or Just a list with at least one column
              -> Maybe WhereCondition  -- ^ optional where condition
              -> SQLPOINTER            -- ^ a pointer to a transfer buffer of size 'transferBufferSize'; it will be used to read data from the source fields
              -> Ptr SQLLEN            -- ^ a pointer to a buffer used to store column size info
              -> ReaderT SQLConfig (MaybeT IO) ()
transferTable srcDBC dstDBC batchName srcSName srcTName dstSName dstTName orderBy whereCondition p_transfer_buf p_transfer_len_or_ind = do
  liftIO $ log $ fromString $ "Transferring table from source " ++ srcTName ++ " to destination " ++ dstTName

  -- 1. execute the select statement from source table
  liftIO $ log $ fromString "---------------------------- 1. SOURCE TABLE -------------------------------"
  colss <- collectColumnsInfo srcDBC srcSName srcTName
  select <- makeSelectSql colss orderBy whereCondition
  liftIO $ log $ fromString $ batchName ++ " (source-db) executing: " ++ select
  srcStmt <- allocHandle sql_handle_stmt srcDBC
  finally (freeHandle sql_handle_stmt srcStmt) $ do
    execDirect srcStmt select (fail $ batchName ++ " parameter data expected on source select statement")

    -- 2. prepare the insert statement to destination table
    liftIO $ log $ fromString "---------------------------- 2. DESTINATION TABLE -------------------------------"
    colsd <- collectColumnsInfo dstDBC dstSName dstTName
    insert <- makeInsertSql colss colsd
    liftIO $ log $ fromString $ batchName ++ " (destination-db) preparing: " ++ insert
    dstStmt <- allocHandle sql_handle_stmt dstDBC
    finally (freeHandle sql_handle_stmt dstStmt) $ do
      prepare dstStmt insert

      -- 3. set dynamic parameters for the insert statement based on columns
      -- information from the select statement
      liftIO $ log $ fromString "---------------------------- 3. DYNAMIC PARAMETERS -------------------------------"
      numCols <- numResultCols srcStmt
      srcDesc <- getStorableStmtAttr srcStmt sql_attr_imp_row_desc
      let whereConditionInfo' = maybe "(-)" (\ w -> "(" ++ w ++ ")") whereCondition
          whereConditionInfo = " " ++ whereConditionInfo'
      result <- liftIO $ alloca (\ p_transferred ->
        -- p_transferred is a Ptr Int used to keep track of the
        -- transferred size
        alloca (\ p_data_at_exec -> runMaybeT $ do
          -- p_data_at_exec is the address of a buffer used to set OCTET_LENGTH_PTR field
          -- of parameters descriptor for the insert into destination table statement. This
          -- buffer should contain the value 'SQL_DATA_AT_EXEC', so the value of the parameter
          -- will be asked when executing the insert statement.
          liftIO $ poke (p_data_at_exec :: Ptr SQLLEN) sql_data_at_exec
          liftIO $ poke p_transferred 0
          bindResult <- liftIO $ allocaBytes 64 (\ p_buf ->
            let -- binds the parameter number recno of the insert statement using
                -- the description of column number recno of the select statement;
                -- a descriptor field that can not be read falls back to a default
                setParamFromStmt recno = do
                  typeField <- liftIO $ runMaybeT $ do
                    getDescField srcDesc recno sql_desc_type p_buf 64 nullPtr
                    liftIO (peek (castPtr p_buf) :: IO SQLSMALLINT)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": type = " ++ (show typeField)
                  subTypeField <- liftIO $ runMaybeT $ do
                    getDescField srcDesc recno sql_desc_datetime_interval_code p_buf 64 nullPtr
                    liftIO (peek (castPtr p_buf) :: IO SQLSMALLINT)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": subtype = " ++ (show subTypeField)
                  lengthField <- liftIO $ runMaybeT $ do
                    liftIO $ poke ((castPtr p_buf) :: Ptr SQLULEN) 0
                    getDescField srcDesc recno sql_desc_length p_buf 64 nullPtr
                    liftIO (peek (castPtr p_buf) :: IO SQLULEN)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": length = " ++ (show lengthField)
                  -- read only for the (disabled) diagnostic below
                  _octetLenField <- liftIO $ runMaybeT $ do
                    liftIO $ poke ((castPtr p_buf) :: Ptr SQLLEN) 0
                    getDescField srcDesc recno sql_desc_octet_length p_buf 64 nullPtr
                    liftIO (peek (castPtr p_buf) :: IO SQLLEN)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": octet length = " ++ (show _octetLenField)
                  precisionField <- liftIO $ runMaybeT $ do
                    getDescField srcDesc recno sql_desc_precision p_buf 64 nullPtr
                    liftIO (peek p_buf :: IO SQLSMALLINT)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": precision = " ++ (show precisionField)
                  scaleField <- liftIO $ runMaybeT $ do
                    getDescField srcDesc recno sql_desc_scale p_buf 64 nullPtr
                    liftIO (peek p_buf :: IO SQLSMALLINT)
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": scale = " ++ (show scaleField)
                  let typeField' = fromMaybe sql_all_types typeField
                      translatedType = toCLIType typeField'
                      subTypeField' = fromMaybe sql_all_types subTypeField
                      lengthField' = fromMaybe 0 lengthField
                      precisionField' = fromMaybe 0 precisionField
                      scaleField' = fromMaybe 0 scaleField
                      -- the size of the parameter is the length for character, binary
                      -- and date/time types and the precision for all the other types
                      lenprec = if typeField' `elem` [sql_char, sql_varchar, sql_datetime, sql_longvarchar, sql_binary, sql_longvarbinary, sql_varbinary, sql_interval]
                                then lengthField'
                                else fromIntegral $ precisionField'
                      -- a generic date/time type is refined using its subtype code
                      translatedType'
                        | translatedType /= sql_datetime      = translatedType
                        | subTypeField' == sql_code_date      = sql_type_date
                        | subTypeField' == sql_code_time      = sql_type_time
                        | subTypeField' == sql_code_timestamp = sql_type_timestamp
                        | otherwise                           = sql_datetime
                  --liftIO $ log $ fromString $ "recno " ++ (show recno) ++ ": lenprec = " ++ (show lenprec)
                  -- the column number is stored as the "data pointer" of the
                  -- parameter; it is decoded back in transferCol
                  bindParam dstStmt recno sql_default translatedType' lenprec scaleField' (wordPtrToPtr (fromIntegral recno)) p_data_at_exec
                  --liftIO $ log $ fromString "Implementation descriptor fields:"
                  return ()
            in runMaybeT $ sequence_ [setParamFromStmt i | i <- [1..numCols]])
          maybe (fail $ batchName ++ ": set dynamic parameters for insert into the destination table failed") return bindResult

          -- 4. for each row in the result set of select statement, execute the
          -- insert statement
          liftIO $ log $ fromString "---------------------------- 4. TRANSFER -------------------------------"
          let -- inserts the current source row; receives and returns the
              -- number of rows transferred so far
              transferRow :: (MonadIO m, MonadFail m) => Int -> m Int
              transferRow count = seq count $ do
                when (count `mod` 10000 == 0) $ liftIO $ hPutStr stderr "."
                execute dstStmt transferCols
                return $! (count + 1)

              -- supplies the values of all the data-at-execution parameters
              transferCols :: (MonadIO m, MonadFail m) => m ()
              transferCols = paramData dstStmt transferCol

              -- supplies the value of the parameter identified by p_data, which
              -- encodes the column number, and adds its size to p_transferred
              transferCol :: (MonadIO m, MonadFail m) => SQLPOINTER -> m ()
              transferCol p_data = do
                let colno = fromIntegral $ ptrToWordPtr p_data
                transferredSize <- seq colno $ forAllData srcStmt colno sql_default p_transfer_buf transferBufferSize p_transfer_len_or_ind transferChunk 0
                crtTransferred <- liftIO $ peek p_transferred
                seq crtTransferred $ seq transferredSize $ liftIO $ poke p_transferred $! (transferredSize + crtTransferred)

              -- sends the chunk found in the transfer buffer to the destination;
              -- receives and returns the size transferred so far for the column
              transferChunk :: (MonadIO m, MonadFail m) => Int -> m Int
              transferChunk crtsize = seq crtsize $ do
                size <- liftIO $ peek p_transfer_len_or_ind
                -- the indicator holds the size that was still available before
                -- the read, so it may exceed what the buffer actually holds
                let chunksize = if size < transferBufferSize then size else transferBufferSize
                seq chunksize $ putData dstStmt p_transfer_buf chunksize
                return $! if size == sql_null_data then 0 else (fromIntegral chunksize) + crtsize

          liftIO $ log $ fromString $ batchName ++ " - Start transfering table: " ++ srcTName
          count <- forAllRecords srcStmt transferRow 0
          transferred <- liftIO $ peek p_transferred
          liftIO $ log $ fromString $ "\n" ++ batchName ++ " - Transfer completed from " ++ srcTName ++ " to " ++ dstTName ++ whereConditionInfo
          liftIO $ log $ fromString $ batchName ++ " - Transferred " ++ (show count) ++ " records" ++ ", " ++ (show transferred) ++ " bytes"
          liftIO $ putStrLn $ batchName ++ "," ++ srcTName ++ "," ++ dstTName ++ "," ++ (show count) ++ "," ++ (show transferred) ++ "," ++ whereConditionInfo'
          liftIO $ hFlush stdout
          commit))
      maybe (rollback >> (fail $ batchName ++ " - Transferred failed between " ++ srcTName ++ " and " ++ dstTName ++ whereConditionInfo)) return result
  where
    -- rolls back both connections; failures of the rollback itself are ignored
    rollback :: (MonadIO m) => m ()
    rollback = liftIO $ do
      liftIO $ log $ fromString $ batchName ++ " - Rollback transaction"
      _ <- runMaybeT $ endTran sql_handle_dbc dstDBC sql_rollback
      _ <- runMaybeT $ endTran sql_handle_dbc srcDBC sql_rollback
      return ()

    -- commits the destination and then the source connection; if any of
    -- the commits fails, both connections are rolled back and the action fails
    commit :: (MonadIO m, MonadFail m) => m ()
    commit = do
      result <- runMaybeT $ do
        liftIO $ log $ fromString $ batchName ++ " - Commit transaction"
        endTran sql_handle_dbc dstDBC sql_commit
        endTran sql_handle_dbc srcDBC sql_commit
      maybe (rollback >> fail (batchName ++ " - commit failed")) return result

-- | Builds the select statement that reads all the given columns from the
-- table they belong to, with the optional where condition and order by
-- columns. An empty list of order by columns is handled as no ordering.
makeSelectSql :: (MonadIO m, MonadFail m) => [ColumnInfo] -> Maybe [ColumnName] -> Maybe WhereCondition -> m String
makeSelectSql cs orderBy whereCondition = do
  tableName <- extractQualifiedTableName cs
  let select = "select " ++ (fieldsList cs) ++ " from " ++ tableName
      select' = maybe select addWhere whereCondition
      select'' = maybe select' addOrderBy orderBy
      addWhere w = select ++ " where " ++ w
      -- an empty columns list would generate a dangling "order by"
      addOrderBy [] = select'
      addOrderBy fs = select' ++ " order by " ++ (intercalate ", " fs)
  return select''

-- | Builds the insert statement into the destination table. The statement
-- lists, for each source column, the destination column having the same name
-- (as given, lower case or upper case) and has one parameter marker for each
-- of these columns. It fails if a source column has no match in the destination table.
makeInsertSql :: (MonadIO m, MonadFail m)
              => [ColumnInfo]  -- ^ source columns info
              -> [ColumnInfo]  -- ^ destination columns info
              -> m String
makeInsertSql cs' cs = do
  tableName <- extractQualifiedTableName cs
  matchedFields <- findDstFields
  return $ "insert into " ++ tableName ++ " (" ++ (fieldsList' matchedFields) ++ ") values (" ++ values ++ ")"
  where
    dstfields = fields cs
    srcfields = fields cs'
    values = intercalate ", " $ replicate n "?"
    -- one marker for each listed column, that is for each source column;
    -- the destination table may have more columns than the source
    n = length srcfields
    findDstFields = mapM findDstField srcfields
    findDstField f = do
      let dstField = find (`elem` dstfields) [f, lowerF, upperF]
      maybe failf return dstField
      where
        upperF = map toUpper f
        lowerF = map toLower f
        failf = do
          let err = "source field " ++ f ++ " not found in destination table"
          liftIO $ log $ fromString err
          fail err

-- | The comma separated list of the names of the given columns.
fieldsList :: [ColumnInfo] -> String
fieldsList = fieldsList' . fields

-- | The comma separated list of the given names.
fieldsList' :: [String] -> String
fieldsList' = intercalate ", "

-- | The names of the given columns.
fields :: [ColumnInfo] -> [String]
fields = map ci_ColumnName

-- | Extracts the name of the table, qualified with the schema name if the
-- schema name is not empty, from the info of the first column. It fails if
-- the list is empty or if the columns do not all belong to the same schema
-- and table.
extractQualifiedTableName :: (MonadIO m, MonadFail m) => [ColumnInfo] -> m String
extractQualifiedTableName [] = do
  liftIO $ log $ fromString "extractQualifiedTableName called with no column info"
  fail "extractQualifiedTableName failed: columns list is empty"
extractQualifiedTableName (c:cs) =
  let schemaName = ci_TableSchem c
      tableName = ci_TableName c
      otherSchema = find ((schemaName /= ) . ci_TableSchem) cs
      otherTable = find ((tableName /= ) . ci_TableName) cs
  in case (otherSchema, otherTable) of
       (Just s, _) -> do
         let err = "Columns info contain different schema names: " ++ schemaName ++ ", " ++ (ci_TableSchem s)
         liftIO $ log $ fromString err
         fail err
       (Nothing, Just t) -> do
         let err = "Columns info contain different table names: " ++ tableName ++ ", " ++ (ci_TableName t)
         liftIO $ log $ fromString err
         fail err
       (Nothing, Nothing) ->
         return $ if null schemaName then tableName else schemaName ++ "." ++ tableName

-- | Returns the first element of the list for which the monadic predicate
-- holds; the predicate is not evaluated for the elements following it.
findM :: (Monad m) => (a -> m Bool) -> [a] -> m (Maybe a)
findM _ [] = return Nothing
findM f (x:xs) = do
  found <- f x
  if found then return $ Just x else findM f xs