module Bio.Bam.Trim
( trim_3
, trim_3'
, trim_low_quality
, default_fwd_adapters
, default_rev_adapters
, find_merge
, mergeBam
, find_trim
, trimBam
, mergeTrimBam
, twoMins
, merged_seq
, merged_qual
) where
import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Bam.Rmdup ( ECig(..), setMD, toECig )
import Bio.Iteratee
import Bio.Prelude
import Foreign.C.Types ( CInt(..) )
import qualified Data.ByteString as B
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Storable as W
-- | Trims the 3' end of a read for as long as a predicate holds.
--
-- The predicate is offered ever longer stretches taken from the 3' end of
-- the read: bases and qualities listed from the 3' end inwards, starting
-- with the empty stretch.  The run of stretches it accepts, up to the first
-- rejection, determines how many bases are removed.  The actual trimming is
-- delegated to 'trim_3'.  Flag bit 4 (0x10, reverse strand) decides whether
-- the 3' end is at the front or the back of the stored sequence.
trim_3' :: ([Nucleotides] -> [Qual] -> Bool) -> BamRec -> BamRec
trim_3' p b = trim_3 (max 0 (accepted - 1)) b
  where
    -- Orders the stored bases so that the 3' end comes first.
    from_3' :: [x] -> [x]
    from_3' | b_flag b `testBit` 4 = id
            | otherwise            = reverse

    -- Number of stretches, including the empty one, accepted by @p@.  The
    -- empty stretch accounts for the "- 1" above.  If @p@ rejects even that,
    -- nothing is trimmed; -1 is never passed to 'trim_3', where it would
    -- lengthen the CIGAR.
    accepted = length . takeWhile (uncurry p) $
               zip (inits . from_3' . V.toList $ b_seq b)
                   (inits . from_3' . V.toList $ b_qual b)
-- | Removes @l@ bases from the 3' end of a read.
--
-- The following are shortened together:
--
--  * sequence and qualities;
--  * the CIGAR;
--  * the MD information (through 'modMd');
--  * the text-valued tags listed in @trim_set@.
--
-- Flag bit 4 (0x10, reverse strand) decides which end is trimmed.  For
-- forward reads the bases come off the back of the stored record.  For
-- reverse reads they come off the front, and the alignment position moves
-- right by the reference length that was trimmed away.
--
-- A negative @l@ is treated as 0.  Passed through unchanged, it would make
-- 'trim_cigar' lengthen the first operation instead of shortening it.
trim_3 :: Int -> BamRec -> BamRec
trim_3 l0 b | b_flag b `testBit` 4 = trim_rev
            | otherwise            = trim_fwd
  where
    l = max 0 l0

    trim_fwd =
        let (_, cigar') = trim_back_cigar (b_cigar b) l
            c = modMd (takeECig (V.length (b_seq b) - l)) b
        in c { b_seq   = V.take (V.length (b_seq c) - l) (b_seq c)
             , b_qual  = V.take (V.length (b_qual c) - l) (b_qual c)
             , b_cigar = cigar'
             , b_exts  = trim_tags (\t -> B.take (B.length t - l) t) (b_exts c) }

    trim_rev =
        let (off, cigar') = trim_fwd_cigar (b_cigar b) l
            c = modMd (dropECig l) b
        in c { b_seq   = V.drop l (b_seq c)
             , b_qual  = V.drop l (b_qual c)
             , b_pos   = b_pos c + off
             , b_cigar = cigar'
             , b_exts  = trim_tags (B.drop l) (b_exts c) }

    -- Applies @f@ to the text-valued tags named in 'trim_set'; every other
    -- tag is kept as is.
    trim_tags f = map $ \(k, e) -> case e of
        Text t | k `elem` trim_set -> (k, Text (f t))
        _                          -> (k, e)

    trim_set = ["BQ","CQ","CS","E2","OQ","U2"]
-- | Rewrites the record's MD information, viewed as an extended cigar built
-- from the record's CIGAR, and stores the result back with 'setMD'.
-- Records for which 'getMd' finds nothing are returned unchanged.
modMd :: (ECig -> ECig) -> BamRec -> BamRec
modMd f br = case getMd br of
    Nothing -> br
    Just md -> setMD br (f (toECig (b_cigar br) md))
-- | Returns the terminal marker ('WithMD' or 'WithoutMD') of an extended
-- cigar, discarding every operation in front of it.
endOf :: ECig -> ECig
endOf ecig = case ecig of
    WithMD      -> WithMD
    WithoutMD   -> WithoutMD
    Mat' _ rest -> endOf rest
    Ins' _ rest -> endOf rest
    SMa' _ rest -> endOf rest
    Rep' _ rest -> endOf rest
    Del' _ rest -> endOf rest
    Nop' _ rest -> endOf rest
    HMa' _ rest -> endOf rest
    Pad' _ rest -> endOf rest
-- | Keeps the leading @n@ read (query) positions of an extended cigar and
-- discards the rest, ending in the cigar's own terminal marker.
--
-- How each operation counts against @n@:
--
--  * match, insertion and soft-clip runs count their length;
--  * a replaced base counts one;
--  * deletions, skips, hard clips and padding count nothing, and are kept
--    while positions remain to be taken.
--
-- A run shorter than what is left to take is kept whole: it stays labelled
-- @m@, not @n@.  A run that straddles the cut is shortened to the @n@
-- positions kept.  A non-positive @n@ keeps nothing.
takeECig :: Int -> ECig -> ECig
takeECig n es | n <= 0  = endOf es
takeECig _ WithMD       = WithMD
takeECig _ WithoutMD    = WithoutMD
takeECig n (Mat' m es)
    | n > m             = Mat' m $ takeECig (n - m) es
    | otherwise         = Mat' n $ endOf es
takeECig n (Ins' m es)
    | n > m             = Ins' m $ takeECig (n - m) es
    | otherwise         = Ins' n $ endOf es
takeECig n (SMa' m es)
    | n > m             = SMa' m $ takeECig (n - m) es
    | otherwise         = SMa' n $ endOf es
takeECig n (Rep' ns es) = Rep' ns $ takeECig (n - 1) es
takeECig n (Del' ns es) = Del' ns $ takeECig n es
takeECig n (Nop' m es)  = Nop' m $ takeECig n es
takeECig n (HMa' m es)  = HMa' m $ takeECig n es
takeECig n (Pad' m es)  = Pad' m $ takeECig n es
-- | Drops the leading @n@ read (query) positions of an extended cigar and
-- keeps everything after them.
--
-- How each operation counts against @n@:
--
--  * match, insertion and soft-clip runs count their length;
--  * a replaced base counts one;
--  * deletions, skips, hard clips and padding count nothing, and are
--    discarded while positions remain to be dropped.
--
-- A run that straddles the cut keeps its remaining @m - n@ positions,
-- followed by the untouched rest of the cigar.  A non-positive @n@ is a
-- no-op.
--
-- NOTE(review): when the cut falls exactly at the end of a run, a following
-- Del'/Nop' is kept here.  'trim_fwd_cigar' instead folds the corresponding
-- CIGAR operation into the position offset.  Confirm that 'setMD' copes
-- with an MD that starts with a deletion.
dropECig :: Int -> ECig -> ECig
dropECig n es | n <= 0  = es
dropECig _ WithMD       = WithMD
dropECig _ WithoutMD    = WithoutMD
dropECig n (Mat' m es)
    | n >= m            = dropECig (n - m) es
    | otherwise         = Mat' (m - n) es
dropECig n (Ins' m es)
    | n >= m            = dropECig (n - m) es
    | otherwise         = Ins' (m - n) es
dropECig n (SMa' m es)
    | n >= m            = dropECig (n - m) es
    | otherwise         = SMa' (m - n) es
dropECig n (Rep' _ es)  = dropECig (n - 1) es
dropECig n (Del' _ es)  = dropECig n es
dropECig n (Nop' _ es)  = dropECig n es
dropECig n (HMa' _ es)  = dropECig n es
dropECig n (Pad' _ es)  = dropECig n es
-- | Remove @l@ read positions from the back ('trim_back_cigar') or the front
-- ('trim_fwd_cigar') of a CIGAR.
--
-- Each returns the reference length removed, which is what a front trim
-- must add to the alignment position, together with the trimmed CIGAR.
-- Deletions, skips and padding exposed at the cut are dropped, and an
-- exposed insertion becomes a soft clip (see 'sanitize_cigar').
trim_back_cigar, trim_fwd_cigar :: V.Vector v Cigar => v Cigar -> Int -> ( Int, v Cigar )
trim_back_cigar cig n =
    let (off, ops) = sanitize_cigar (trim_cigar n (reverse (V.toList cig)))
    in (off, V.fromList (reverse ops))
trim_fwd_cigar cig n =
    let (off, ops) = sanitize_cigar (trim_cigar n (V.toList cig))
    in (off, V.fromList ops)
-- | Cleans up the front of a freshly trimmed CIGAR list.
--
-- Leading operations are handled as follows:
--
--  * padding is dropped;
--  * deletions and reference skips are dropped, and their length is added
--    to the offset;
--  * a leading insertion is turned into a soft clip of the same length.
--
-- The first operation of any other kind ends the scan.
sanitize_cigar :: (Int, [Cigar]) -> (Int, [Cigar])
sanitize_cigar (off, ops) = case ops of
    [] -> (off, [])
    (op :* n) : rest
        | op == Pad              -> sanitize_cigar (off, rest)
        | op == Del || op == Nop -> sanitize_cigar (off + n, rest)
        | op == Ins              -> (off, (SMa :* n) : rest)
        | otherwise              -> (off, ops)
-- | Removes @l@ read (query) positions from the front of a CIGAR list.
--
-- Returns the reference length covered by what was removed, together with
-- the remaining operations.  Only M, I and S consume read positions; any
-- other operation met before the cut is discarded outright, but its
-- reference length still counts toward the offset.
--
-- M, D and N consume the reference (SAM specification).  N is included so
-- that trimming across a reference skip moves the position correctly; this
-- matches 'sanitize_cigar', which also counts N.
--
-- The remainder may begin with D, N, P or I; 'sanitize_cigar' fixes that up.
trim_cigar :: Int -> [Cigar] -> (Int, [Cigar])
trim_cigar 0 cs = (0, cs)
trim_cigar _ [] = (0, [])
trim_cigar l ((op :* ll) : cs)
    | bad_op op = let (o, cs') = trim_cigar l cs in (o + reflen op ll, cs')
    | otherwise = case l `compare` ll of
        LT -> (reflen op l, (op :* (ll - l)) : cs)
        EQ -> (reflen op ll, cs)
        GT -> let (o, cs') = trim_cigar (l - ll) cs in (o + reflen op ll, cs')
  where
    -- Reference length of @k@ units of an operation.
    reflen op' = if ref_op op' then id else const 0
    -- Operations that do not consume read positions.
    bad_op o = o /= Mat && o /= Ins && o /= SMa
    -- Operations that consume reference positions.
    ref_op o = o == Mat || o == Del || o == Nop
-- | Predicate for 'trim_3'': accepts a stretch when every one of its
-- qualities is strictly below @q@.  The bases themselves are ignored.
trim_low_quality :: Qual -> a -> [Qual] -> Bool
trim_low_quality q _ quals = all (< q) quals
-- | Finds the best merged length for a read pair.
--
-- Every merged length in @[0, length r1 + length r2)@ is scored with
-- 'merge_score'.  A flat 6 per base of both reads seeds the search as the
-- score to beat.  Returns three values:
--
--  * the best length, or -1 when no length beats the seed;
--  * the margin of the runner-up over the best score;
--  * the margin of the seed over the best score.
find_merge :: [W.Vector Nucleotides] -> [W.Vector Nucleotides]
           -> W.Vector Nucleotides -> W.Vector Qual
           -> W.Vector Nucleotides -> W.Vector Qual
           -> (Int, Int, Int)
find_merge ads1 ads2 r1 q1 r2 q2 =
    let total    = V.length r1 + V.length r2
        baseline = 6 * total
        (best, best_len, runner_up) =
            twoMins baseline total (merge_score ads1 ads2 r1 q1 r2 q2)
    in (best_len, runner_up - best, baseline - best)
-- | Attempts to merge the two mates of a read pair into a single read.
--
-- 'find_merge' yields the merged length and two scores.  The result
-- depends on them:
--
--  * both reads empty: nothing;
--  * first score below @lowq@, or no merge found: the two mates;
--  * first score at least @highq@: the merged read, or nothing when the
--    merged length is zero;
--  * merged read more than 20 bases shorter than either mate: the merged
--    read only;
--  * otherwise: both mates and the merged read, all flagged as alternatives.
--
-- Every returned record carries the two scores in its @YM@ and @YN@ tags.
mergeBam :: Int -> Int -> [W.Vector Nucleotides] -> [W.Vector Nucleotides] -> BamRec -> BamRec -> [BamRec]
mergeBam lowq highq ads1 ads2 r1 r2
    | V.null (b_seq r1) && V.null (b_seq r2) = [ ]
    | score < lowq || mlen < 0               = [ mate1, mate2 ]
    | score >= highq && mlen == 0            = [ ]
    | score >= highq                         = [ merged ]
    | mlen < len1 - 20 || mlen < len2 - 20   = [ merged ]
    | otherwise                              = map as_alternative [ mate1, mate2, merged ]
  where
    len1 = V.length (b_seq r1)
    len2 = V.length (b_seq r2)

    sq1 = V.convert (b_seq r1)
    sq2 = V.convert (b_seq r2)
    qs1 = V.convert (b_qual r1)
    qs2 = V.convert (b_qual r2)

    (mlen, score, score2) = find_merge ads1 ads2 sq1 qs1 sq2 qs2

    with_scores br = br { b_exts = updateE "YM" (Int score) . updateE "YN" (Int score2) $ b_exts br }
    as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }

    mate1  = with_scores r1
    mate2  = with_scores r2
    merged = with_scores $ merge_into mlen (fromIntegral $ min 63 score)

    -- Pairing-related flag bits, cleared on the merged (now single) read.
    pair_bits = flagPaired .|. flagProperlyPaired .|. flagMateUnmapped
            .|. flagMateReversed .|. flagFirstMate .|. flagSecondMate

    merge_into len qmax = nullBamRec
        { b_qname = b_qname r1
        , b_flag  = flagUnmapped .|. (complement pair_bits .&. b_flag r1)
        , b_seq   = V.convert $ merged_seq len sq1 qs1 sq2 qs2
        , b_qual  = V.convert $ merged_qual qmax len sq1 qs1 sq2 qs2
        , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagMerged .|. trimmed_bit) $ b_exts r1 }
      where
        trimmed_bit = if len < len1 then eflagTrimmed else 0
{-# INLINE merged_seq #-}
-- | Sequence of the read obtained by merging a pair to length @l@.
--
-- Read 1 occupies merged positions @[0, len_r1)@.  Read 2, reverse
-- complemented, occupies @[l - len_r2, l)@; both are clipped to @[0, l)@.
-- Bases are chosen as follows:
--
--  * positions covered by one read only take that read's base, with read 2
--    complemented;
--  * in the overlap, agreeing bases are kept; otherwise the higher-quality
--    base wins, and ties go to read 2.
--
-- The overlap is @[max 0 (l - len_r2), min len_r1 l)@ in read 1, facing
-- @[max 0 (l - len_r1), min len_r2 l)@ of read 2, reversed.  Taking before
-- dropping yields exactly these slices.  Dropping first would let read 2's
-- slice run past @l@ when @len_r1 < l < len_r2@, misaligning the overlap.
merged_seq :: (V.Vector v Nucleotides, V.Vector v Qual)
           => Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Nucleotides
merged_seq l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
    [ V.take (l - len_r2) b_seq_r1
    , V.zipWith4 zz (V.drop (l - len_r2) $ V.take l b_seq_r1)
                    (V.drop (l - len_r2) $ V.take l b_qual_r1)
                    (V.reverse $ V.drop (l - len_r1) $ V.take l b_seq_r2)
                    (V.reverse $ V.drop (l - len_r1) $ V.take l b_qual_r2)
      -- Read 2 contributes its reverse complement, as 'zz' does in the overlap.
    , V.reverse $ V.map compls $ V.take (l - len_r1) b_seq_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 = n1
                               | q1 > q2         = n1
                               | otherwise       = compls n2
{-# INLINE merged_qual #-}
-- | Qualities of the read obtained by merging a pair to length @l@.  This is
-- the companion of 'merged_seq', laid out the same way.
--
-- Qualities are chosen as follows:
--
--  * positions covered by one read only keep that read's quality, with
--    read 2's in reverse order;
--  * in the overlap, where read 1's base equals the complement of read 2's,
--    the two qualities are summed and capped at @qmax@;
--  * otherwise the overlap position gets the difference of the two
--    qualities.
--
-- The overlap is @[max 0 (l - len_r2), min len_r1 l)@ in read 1, facing
-- @[max 0 (l - len_r1), min len_r2 l)@ of read 2, reversed.  Taking before
-- dropping yields exactly these slices.  Dropping first would let read 2's
-- slice run past @l@ when @len_r1 < l < len_r2@, misaligning the overlap.
merged_qual :: (V.Vector v Nucleotides, V.Vector v Qual)
            => Word8 -> Int -> v Nucleotides -> v Qual -> v Nucleotides -> v Qual -> v Qual
merged_qual qmax l b_seq_r1 b_qual_r1 b_seq_r2 b_qual_r2 = V.concat
    [ V.take (l - len_r2) b_qual_r1
    , V.zipWith4 zz (V.drop (l - len_r2) $ V.take l b_seq_r1)
                    (V.drop (l - len_r2) $ V.take l b_qual_r1)
                    (V.reverse $ V.drop (l - len_r1) $ V.take l b_seq_r2)
                    (V.reverse $ V.drop (l - len_r1) $ V.take l b_qual_r2)
    , V.reverse $ V.take (l - len_r1) b_qual_r2 ]
  where
    len_r1 = V.length b_qual_r1
    len_r2 = V.length b_qual_r2
    zz !n1 (Q !q1) !n2 (Q !q2) | n1 == compls n2 = Q $ min qmax (q1 + q2)
                               | q1 > q2         = Q $ q1 - q2
                               | otherwise       = Q $ q2 - q1
-- | Finds the best trim point for a single read.
--
-- This is 'find_merge' with no second read.  The missing read's adapter
-- list is a single empty adapter, and its sequence and qualities are empty.
-- Returns three values:
--
--  * the best length, or -1 when none beats the flat 6-per-base seed;
--  * the margin of the runner-up over the best score;
--  * the margin of the seed over the best score.
find_trim :: [W.Vector Nucleotides]
          -> W.Vector Nucleotides -> W.Vector Qual
          -> (Int, Int, Int)
find_trim ads1 r1 q1 = (best_len, runner_up - best, baseline - best)
  where
    n        = V.length r1
    baseline = 6 * n
    (best, best_len, runner_up) =
        twoMins baseline n (merge_score ads1 [V.empty] r1 q1 V.empty V.empty)
-- | Trims adapter sequence off a single read, using 'find_trim'.
--
--  * empty read: nothing;
--  * trimmed length 0 and score at least @highq@: nothing;
--  * score below @lowq@, or no trim point: the original read;
--  * score at least @highq@: the trimmed read only;
--  * otherwise: both, flagged as alternatives.
--
-- The trimmed read is built from 'nullBamRec'.  It is flagged unmapped and
-- carries the original's name, flags and tags, plus the trimmed flag.
-- Every returned record gets the two scores in @YM@ and @YN@.
trimBam :: Int -> Int -> [W.Vector Nucleotides] -> BamRec -> [BamRec]
trimBam lowq highq ads1 r1
    | V.null (b_seq r1)           = [ ]
    | mlen == 0 && score >= highq = [ ]
    | score < lowq || mlen < 0    = [ original ]
    | score >= highq              = [ trimmed ]
    | otherwise                   = map as_alternative [ original, trimmed ]
  where
    (mlen, score, score2) = find_trim ads1 (V.convert (b_seq r1)) (V.convert (b_qual r1))

    with_scores br = br { b_exts = updateE "YM" (Int score) . updateE "YN" (Int score2) $ b_exts br }
    as_alternative br = br { b_exts = updateE "FF" (Int $ extAsInt 0 "FF" br .|. eflagAlternative) $ b_exts br }

    original = with_scores r1
    trimmed  = with_scores $ nullBamRec
        { b_qname = b_qname r1
        , b_flag  = flagUnmapped .|. b_flag r1
        , b_seq   = V.take mlen (b_seq r1)
        , b_qual  = V.take mlen (b_qual r1)
        , b_exts  = updateE "FF" (Int $ extAsInt 0 "FF" r1 .|. eflagTrimmed) $ b_exts r1 }
-- | Default adapters searched for after the insert in the first read.  This
-- is the list suitable for the @fwd_ads@ / @ads1@ argument of
-- 'mergeTrimBam', 'mergeBam' and 'trimBam'.
default_fwd_adapters :: [ W.Vector Nucleotides ]
default_fwd_adapters =
    [ W.fromList (map toNucleotides adapter)
    | adapter <- [ "AGATCGGAAGAGCGGTTCAG"
                 , "AGATCGGAAGAGCACACGTC"
                 , "AGATCGGAAGAGCTCGTATG" ] ]
-- | Default adapters searched for after the insert in the second read.
-- This is the list suitable for the @rev_ads@ / @ads2@ argument of
-- 'mergeTrimBam' and 'mergeBam'.
default_rev_adapters :: [ W.Vector Nucleotides ]
default_rev_adapters =
    [ W.fromList (map toNucleotides adapter)
    | adapter <- [ "AGATCGGAAGAGCGTCGTGT"
                 , "GGAAGAGCGTCGTGTAGGGA" ] ]
-- | Penalty for merging (or trimming) a read pair to length @l@.  Lower is
-- better: 'find_merge' and 'find_trim' pick the minimum via 'twoMins'.
--
-- The score is the sum of three parts:
--
--  * a flat 6 per position of the merged length;
--  * for each read, the best (lowest) adapter match starting at read
--    position @l@, plus 6 for every read base beyond both @l@ and the end of
--    that adapter;
--  * the penalty for the overlap of the two reads ('match_reads').
--
-- An empty adapter list is treated as a list holding one empty adapter,
-- which is what 'find_trim' passes for the missing second read.  Folding
-- over a truly empty list would yield 'maxBound'.  Adding the other terms to
-- that overflows into a large negative score, which then wins every
-- comparison.
merge_score
    :: [ W.Vector Nucleotides ]
    -> [ W.Vector Nucleotides ]
    -> W.Vector Nucleotides -> W.Vector Qual
    -> W.Vector Nucleotides -> W.Vector Qual
    -> Int
    -> Int
merge_score fwd_adapters rev_adapters !read1 !qual1 !read2 !qual2 !l
    =  6 * (l `min` V.length read1)
    +  6 * max 0 (l - V.length read1)
    +  best_adapter read1 qual1 fwd_adapters
    +  best_adapter read2 qual2 rev_adapters
    +  match_reads l read1 qual1 read2 qual2
  where
    -- Lowest adapter penalty for one read.
    best_adapter rd qs ads =
        foldl' (\acc ad -> min acc (match_adapter l rd qs ad
                                    + 6 * max 0 (V.length rd - V.length ad - l)))
               maxBound
               (if null ads then [V.empty] else ads)
{-# INLINE match_adapter #-}
-- | Penalty for adapter @ad@ placed against read @rd@ (qualities @qs@) from
-- read position @off@ onwards, computed by the C routine 'prim_match_ad'.
--
-- Only @efflength@ positions are compared: as many adapter bases as fit
-- between @off@ and the end of the read.  Returns 0 when none fit, and calls
-- 'error' when @rd@ and @qs@ differ in length.
match_adapter :: Int -> W.Vector Nucleotides -> W.Vector Qual -> W.Vector Nucleotides -> Int
match_adapter !off !rd !qs !ad
    | V.length rd /= V.length qs = error "read/qual length mismatch"
    | efflength <= 0 = 0
    | otherwise
        -- NOTE(review): unsafePerformIO is sound only if prim_match_ad just
        -- reads its buffers and has no other effects; confirm in the C
        -- source.  The pointers are valid only inside the nested
        -- W.unsafeWith calls, which is where the call is made.
        = fromIntegral . unsafePerformIO $
          W.unsafeWith rd $ \p_rd ->
          W.unsafeWith qs $ \p_qs ->
          W.unsafeWith ad $ \p_ad ->
          prim_match_ad (fromIntegral off)
                        (fromIntegral efflength)
                        p_rd p_qs p_ad
  where
    -- Number of adapter positions that fit into the read after @off@.
    !efflength = (V.length rd - off) `min` V.length ad
-- | C routine behind 'match_adapter'.  Arguments, in order:
--
--  * read offset;
--  * number of positions to compare;
--  * read bases;
--  * read qualities;
--  * adapter bases.
--
-- 'match_adapter' uses the result as a penalty.  Because the import is
-- @unsafe@, the C code must not call back into Haskell.
-- NOTE(review): the scoring itself lives in the C source, not shown here.
foreign import ccall unsafe "prim_match_ad"
    prim_match_ad :: CInt -> CInt
                  -> Ptr Nucleotides -> Ptr Qual
                  -> Ptr Nucleotides -> IO CInt
{-# INLINE match_reads #-}
-- | Penalty for the overlap of two reads merged to length @l@, computed by
-- the C routine 'prim_match_reads'.
--
-- Read 1 sits at merged positions @[0, len1)@.  Read 2, reversed, ends at
-- merged position @l - 1@, so read-1 index @i@ faces read-2 index
-- @l - 1 - i@ (as in 'merged_seq').  The overlap starts at read-1 index
-- @minidx1@ and ends just before read-2 index @maxidx2@, going backwards;
-- note that @minidx1 + maxidx2 == l@.
--
-- Its length is @min len1 l + min len2 l - l@.  This never exceeds
-- @len1 - minidx1@ or @maxidx2@, so neither buffer can be overrun.  The
-- formula @min (len1 + len2 - l) l@ agrees with it for equal-length reads,
-- but over-counts whenever @l@ lies strictly between the two read lengths.
--
-- Returns 0 when the reads do not overlap.  Calls 'error' when a read and
-- its qualities differ in length.
match_reads :: Int -> W.Vector Nucleotides -> W.Vector Qual -> W.Vector Nucleotides -> W.Vector Qual -> Int
match_reads !l !rd1 !qs1 !rd2 !qs2
    | V.length rd1 /= V.length qs1 || V.length rd2 /= V.length qs2 = error "read/qual length mismatch"
    | efflength <= 0 = 0
    | otherwise
        -- NOTE(review): assumes prim_match_reads only reads its buffers.
        -- The pointers are valid only inside the nested W.unsafeWith calls.
        = fromIntegral . unsafePerformIO $
          W.unsafeWith rd1 $ \p_rd1 ->
          W.unsafeWith qs1 $ \p_qs1 ->
          W.unsafeWith rd2 $ \p_rd2 ->
          W.unsafeWith qs2 $ \p_qs2 ->
          prim_match_reads (fromIntegral minidx1)
                           (fromIntegral maxidx2)
                           (fromIntegral efflength)
                           p_rd1 p_qs1 p_rd2 p_qs2
  where
    len1 = V.length rd1
    len2 = V.length rd2
    -- First read-1 index inside the overlap.
    !minidx1 = (l - len2) `max` 0
    -- One past the last read-2 index inside the overlap.
    !maxidx2 = l `min` len2
    -- Number of overlapping positions.
    !efflength = ((len1 `min` l) + (len2 `min` l) - l) `max` 0
-- | C routine behind 'match_reads'.  Arguments, in order:
--
--  * first overlapping read-1 index;
--  * end (exclusive) of the overlapping read-2 indices;
--  * overlap length;
--  * read-1 bases and qualities;
--  * read-2 bases and qualities.
--
-- 'match_reads' uses the result as a penalty.  Because the import is
-- @unsafe@, the C code must not call back into Haskell.
-- NOTE(review): the traversal order is not visible here; the reading that
-- read 1 is walked forward and read 2 backward is inferred from
-- 'merged_seq'.  Confirm against the C source.
foreign import ccall unsafe "prim_match_reads"
    prim_match_reads :: CInt -> CInt -> CInt
                     -> Ptr Nucleotides -> Ptr Qual
                     -> Ptr Nucleotides -> Ptr Qual -> IO CInt
{-# INLINE twoMins #-}
-- | Scans @f 0 .. f (imax - 1)@ and returns three values:
--
--  * the smallest value seen;
--  * the argument that produced it;
--  * the second-smallest value.
--
-- The seed @a0@ acts as an initial best with index -1.  When no @f i@ falls
-- below it, the result is @(a0, -1, ...)@.  The second-smallest starts out
-- as 'maxBound'.
twoMins :: (Bounded a, Ord a) => a -> Int -> (Int -> a) -> (a,Int,a)
twoMins a0 imax f = scan 0 a0 (-1) maxBound
  where
    scan !i !lo1 !at1 !lo2
        | i == imax = (lo1, at1, lo2)
        | otherwise =
            let x  = f i
                i' = i + 1
            in if x < lo1 then scan i' x i lo1
               else if x < lo2 then scan i' lo1 at1 x
               else scan i' lo1 at1 lo2
-- | Stream transformer that merges read pairs and trims single reads.
--
-- Records are handled as follows:
--
--  * an unpaired record goes through 'trimBam';
--  * a paired record is combined, via 'mergeBam', with the record that
--    immediately follows it, so mates must be adjacent in the input;
--  * a paired record at the end of the stream is an error.
mergeTrimBam :: Monad m => Int -> Int -> [W.Vector Nucleotides] -> [W.Vector Nucleotides] -> Enumeratee [BamRec] [BamRec] m a
mergeTrimBam lowq highq fwd_ads rev_ads = convStream step
  where
    step = headStream >>= \rec1 ->
        if isPaired rec1
            then tryHead >>= maybe (lone_mate rec1)
                                   (return . mergeBam lowq highq fwd_ads rev_ads rec1)
            else return (trimBam lowq highq fwd_ads rec1)

    lone_mate rec1 = error $ "Lone mate found: " ++ show (b_qname rec1)