{- Dephd - decode phd (phred output) files.

   Generates seq/qual output, quality plots, or rankings.
-}

{-# LANGUAGE BangPatterns #-}

module Main where

import Control.Concurrent
import Control.Monad
import Data.Char
import Data.List (groupBy,sortBy,isPrefixOf,unfoldr)
import Data.Maybe
import System.Console.GetOpt
import System.Directory
import System.Environment (getArgs,getEnvironment)
import System.Exit
import System.IO
import System.IO.Unsafe
import System.Process
import Text.Printf
import Text.Regex
import qualified Data.ByteString.Lazy.Char8 as B
import qualified Data.ByteString.Lazy as BB

import Bio.Sequence
import Bio.Sequence.SeqData (hasqual)
import Bio.Util (countIO)

-- ------------------------------------------------------------
-- Option Handling
-- ------------------------------------------------------------

-- | Run-time configuration, built up by applying the command line
--   option handlers to 'defaultopts' (see 'getOptions').
data MyOpts = O { actions ::  [(FilePath,Sequence Nuc) -> IO ()] -- ^ Apply to each sequence
                , outputs :: [Handle]                   -- ^ Output handles that must be closed
                , filters :: (FilePath,Sequence Nuc) -> (FilePath,Sequence Nuc) -- ^ Filter sequences before processing
                , zerofilter :: (FilePath,Sequence Nuc) -> Bool -- ^ Only sequences for which this holds are processed
                , inputs  :: [String] -> IO [(FilePath,Sequence Nuc)] -- ^ Turn args into sequences
                , verbose :: Bool  -- ^ Verbose output (progress reporting)
                , libtable :: Maybe FilePath -- ^ Library table file; must be set before the dbEST output option is handled
                }

-- | Configuration in effect before any command line option is applied:
--   no actions or output handles, identity filter, every sequence is
--   kept, arguments are read as PHD files, quiet, no library table.
defaultopts :: MyOpts
defaultopts = O
  { actions    = []
  , outputs    = []
  , filters    = id
  , zerofilter = const True
  , inputs     = readFiles
  , verbose    = False
  , libtable   = Nothing
  }

-- | Parse the command line.  Returns the accumulated options, the
--   non-option arguments, and any error messages reported by 'getOpt'.
--   The option handlers are run in command line order, each one
--   transforming the record produced by its predecessor, starting from
--   'defaultopts'.
getOptions :: IO (MyOpts, [String], [String])
getOptions = do
  (handlers,rest,problems) <- fmap (getOpt Permute options) getArgs
  settings <- foldM (\acc handler -> handler acc) defaultopts handlers
  return (settings,rest,problems)

-- | Where a quality plot goes: 'J' writes a JPEG file, 'P' an EPS file,
--   and 'X' sets neither terminal nor output file (see @setplot@ in 'options').
data PlotType = J | P | X

-- | Command line option table.  Every option is a handler that
--   transforms the 'MyOpts' record accumulated so far.  Handlers run in
--   command line order (see 'getOptions'), which matters in two places:
--
--   * @-l@ must come before @-E@, because the dbEST handler reads the
--     library table while it is being processed;
--
--   * filters are composed so that the most recently given one is
--     applied first, hence @-t@ must precede @-q@ if trimming should
--     see the quality-based trim points.
--
--   Actions are consed onto the list, so the action given last on the
--   command line is run first for each sequence.
options :: [OptDescr (MyOpts -> IO MyOpts)]
options =
    [ Option ['v'] [] (NoArg (\opt -> return opt {verbose = True})) "Verbose output"
    , Option ['h'] ["help"] (NoArg (\_ -> do {putStrLn (usage []); exitWith ExitSuccess}))
                                                            "Display usage information"
    -- Output options
    , Option ['R'] ["output-ranks"]  (ReqArg setrank "file") "Set ranked output file"
    , Option ['F'] ["output-fasta"] (ReqArg setfasta "file") "Set fasta output file"
    , Option ['Q'] ["output-qual"]  (ReqArg setqual "file")  "Set quality output file"
    , Option ['E'] ["output-dbEST"] (ReqArg setest "file")   "Output file suitable for dbEST submission"
    -- -J and -P used to share the long name "output-plot", which made
    -- getOpt reject "--output-plot" as ambiguous; each now has its own.
    , Option ['J'] ["output-plot-jpeg"] (NoArg  (setplot J)) "Plot sequence quality, JPEG"
    , Option ['P'] ["output-plot-eps"]  (NoArg  (setplot P)) "Plot sequence quality, EPS"
    , Option ['X'] ["output-xplot"] (NoArg  (setplot X))   "Display quality plots"

    , Option ['l'] ["libtable"] (ReqArg (\arg opt -> return opt { libtable = Just arg }) "FILE")  "Specify a library table for ESTs"

    -- Filter options
    , Option ['t'] ["filter-trim"] (NoArg filterTrim) "Trim output sequences.\nSpecify *before* -q if you want to trim based on quality!"
    , Option ['q'] ["filter-qual"] (ReqArg filterQual "num") "Mask by quality to given threshold"
    -- Input options

    , Option [] ["input-dirs"]     (NoArg  (\opt -> return opt { inputs = readDirs }))
                                                            "Read directories containing PHD files"
    , Option [] ["input-list"]     (NoArg inputList) "Read the files listed in an index file"
    , Option ['i'] ["input-fasta-qual"] (NoArg inputFQ)  "Read a fasta- and optionally a qual file"
    , Option ['z'] ["ignore-empty-seqs"] (NoArg (\opt -> return opt { zerofilter = zeroFilter})) "Eliminate zero-length sequences from output"
    ]
    where
      -- Read the PHD files named, one per line, in a single index file.
      inputList opt = return $ opt { inputs = \args -> case args of
                                                         [arg] -> do s <- readFile arg
                                                                     readFiles $ lines s
                                                         _     -> error ("Please specify only one file for list input: "++show args)}
      -- Read a fasta file, optionally paired with a quality file; the
      -- sequence label doubles as the "file name" of each record.
      inputFQ opt = return $ opt
                  { inputs = \arg -> do ss <- case arg of [fa,q] -> readFastaQual fa q
                                                          [fa]   -> readFasta fa
                                                          []     -> error "No files specified for fasta/qual input"
                                                          _      -> error ("Too many files specified for fasta/qual input: "++show arg)
                                        return $ map (\s -> (toStr $ seqlabel s, castToNuc s)) ss }
      -- Compose a new filter in front of (i.e. applied before) the
      -- filters registered so far.
      addFilter act opt     = let f = filters opt in return $ opt { filters = f . act }
      filterTrim = addFilter trimSeq
      filterQual t = case reads t of [(t',_)] -> addFilter (\(p,s) -> (p,qualAdjust t' s))
                                     _ -> error ("Couldn't parse numeric argument to -q: '"++t++"'")
      zeroFilter = (/=0). seqlength . snd
      -- Open the output named by the argument ("-" means stdout),
      -- register the handle for closing and the action for execution.
      addAction act arg opt = let as = actions opt
                                  hs = outputs opt
                                  condOpenArg fn
                                      | fn == "-"           = return stdout
                                      | "-" `isPrefixOf` fn = error ("Refusing output file name ('"++fn
                                                                     ++ "') starting with '-'.\n"
                                                                     ++ "Use './"++fn++"' if this is what you want.")
                                      | otherwise           = openFile fn WriteMode
                              in do h <- condOpenArg arg
                                    return opt { actions = act h:as
                                               , outputs = h:hs }

      setrank  = addAction (\h -> hPutStrLn h . unwords {- columns -} . qualchk . snd)
      setfasta = addAction (\h -> hWriteFasta h . return . snd)
      setqual  = addAction (\h -> hWriteQual h . return . snd)
      -- dbEST output: the library, publication and contact records are
      -- written to the file right away, while the option is processed;
      -- the registered action then adds one EST record per sequence.
      setest arg opt = do
             h  <- openFile arg WriteMode
             lt <- getLibTable (libtable opt) h
             (c,p) <- getContPub h
             w  <- mkWriteEST lt (c,p)
             return opt { actions = (w h . snd):actions opt, outputs = h:outputs opt }
      -- Register a gnuplot action; 'how' builds the terminal/output
      -- preamble from the input file name (with the .phd suffix removed).
      setplot c opt = do
             g <- hasgp
             let how = case c of J -> \f -> "set terminal jpeg\nset output \""
                                        ++subRegex phd_rx f "" ++ ".jpg\"\n"
                                 P -> \f -> "set terminal postscript color eps\nset output\""
                                        ++subRegex phd_rx f "" ++ ".eps\"\n"
                                 X -> const ""
             if g then let as = actions opt in return $ opt { actions = plot how : as }
                  else error ("You requested quality plots, but I can't find 'gnuplot' in your search path.\n")


-- | A library table: one association list per data row, pairing each
--   column name (taken from the header line) with that row's value.
type LibTable = [[(String,String)]]

-- | Read a whitespace-separated library table from a file.
--
--   Lines starting with @#@ are comments; empty and whitespace-only
--   lines are skipped.  The first remaining line is the header naming
--   the columns, every following line is one data row.  A row with
--   fewer words than the header simply lacks the trailing columns.
--
--   Calls 'error' if the file contains no header line at all.
readLibTable :: FilePath -> IO LibTable
readLibTable f = do
  rows <- (filter (not . null) . map words . decomment . lines) `fmap` readFile f
  case rows of
    []      -> error ("Library table '"++f++"' contains no header line")
    (ch:ls) -> return $ map (zip ch) ls
  where
    decomment = filter (not . isComment)
    isComment ('#':_) = True
    isComment _       = False

-- | Classify a read according to the LibTable.
--   Every pattern is tried against the read name.  Exactly one of them
--   must match, and the library name paired with it is returned; zero
--   or several matches abort with 'error'.
classify :: [(Regex,String)] -> String -> String
classify ps str =
    case [lib | (rx,lib) <- ps, isJust (matchRegex rx str)] of
      []    -> error ("no match for "++str++" in library table")
      [lib] -> lib
      libs  -> error ("multiple matches for "++str++": "++show libs)


-- Write ESTs in (partial) GenBank/dbEST format
-- (todo: speed up by going to ByteStrings)

-- | Copy the publication and contact records to the output handle and
--   return the contact name and the publication title.
--
--   Both "contact.txt" and "publication.txt" must exist in the current
--   directory.  If either is missing, a skeleton is generated for each
--   missing file and the program aborts.  The records are written to
--   the handle (publication first, each followed by a "||" line)
--   before the NAME and TITLE fields are extracted; each field must
--   occur on exactly one line of its file.
getContPub :: Handle -> IO (String,String)
getContPub h = do
  present <- mapM doesFileExist required
  unless (and present) $ do
       mapM_ mkContPub [f | (False,f) <- zip present required]
       error ("Please fill in the required information")
  readFile "publication.txt" >>= hPutStr h
  hPutStrLn h "||"
  readFile "contact.txt" >>= hPutStr h
  hPutStrLn h "||"
  cont_name <- field "NAME" "contact.txt"
  citation  <- field "TITLE" "publication.txt"
  return (cont_name,citation)
  where
    required = ["contact.txt","publication.txt"]
    -- Value of the single line of file f that starts with "x:".
    field x f = do
      hits <- fmap (filter ((x++":") `isPrefixOf`) . lines) (readFile f)
      case hits of
        [l] -> return (dropWhile isSpace $ drop (length x+1) l)
        _   -> error ("File: "++f++" is present, but does not contain field "++show x)

-- | Write a skeleton for the named file and tell the user about it.
-- Skeletons exist for "contact.txt" (which must contain a NAME field)
-- and "publication.txt" (which must contain TITLE); any other name
-- gets an empty file.  An existing file of that name is overwritten.
mkContPub :: String -> IO ()
mkContPub f = do
  writeFile f skeleton
  putStrLn ("There was no file '"++f
    ++"', I generated one for you, please make sure it contains the appropriate data.")
  where
    skeleton = case lookup f skeletons of
                 Just template -> template
                 Nothing       -> ""
    skeletons =
      [ ("contact.txt","TYPE:\tCont\nNAME:\tobligatory!\nTEL:\nEMAIL:\nINST:\nADDR:\n")
      , ("publication.txt","TYPE:\tPub\nTITLE:\tobligatory!\nAUTHORS:\nJOURNAL:\nVOLUME:\nISSUE:\nYEAR:\nSTATUS:\n")
      ]

-- | Read the library table and write one dbEST "Lib" record per table
--   row to the handle; returns the table.
--
--   Fields are looked up case-insensitively in the row, extended by
--   "KEY: value" lines from the environment variable DBLIB.  NAME and
--   ORGANISM are mandatory, a field defined more than once is an
--   error, and a value of "?" suppresses the field.  Underscores in
--   values are written as spaces.
--
--   Without a table file the program aborts.
getLibTable :: Maybe FilePath -> Handle -> IO LibTable
getLibTable Nothing _ = error "Specify a libtable before -E"
getLibTable (Just lt) h = do
  libtab <- readLibTable lt
  env    <- getEnvironment
  let extraLibs = map split (maybe [] lines (lookup "DBLIB" env))
      -- Split "KEY: value" at the first colon.
      split str = case span (/=':') str of
        (k,v) | null k || null v -> error ("incorrect format in $DBLIB: "++str)
              | otherwise        -> (k,dropWhile isSpace $ drop 1 v)
      -- underscores represent spaces in libtable
      trsp c = if c == '_' then ' ' else c
      -- Render one field of a Lib record, or "" if it is to be omitted.
      extract fs n =
        let wanted = map toLower n
        in case [v | (k,v) <- fs++extraLibs, map toLower k == wanted] of
             []  | n `elem` ["ORGANISM","NAME"]
                             -> error (n++" is a required field - specify in "++lt++" or $DBLIB")
                 | otherwise -> ""
             [x] | x /= "?"  -> n++":\t"++map trsp x++"\n"
                 | otherwise -> ""
             _   -> error ("Multiple definitions for field \""++n++"\"")
      genlib fs = concat
        [ "TYPE:\tLib\n"
        , concatMap (extract fs) ["NAME", "TISSUE", "STAGE", "ORGANISM", "STRAIN"]
        , "||\n"
        ]
  hPutStr h (concatMap genlib libtab)
  return libtab
               
-- | Build the action that writes one dbEST "EST" record per sequence.
--
--   The library of a read is found by matching its label against the
--   regular expression in the table's Pattern column; the Name of the
--   matching row is reported (see 'classify').  Name and Pattern are
--   paired row by row, so a row lacking either column is skipped
--   instead of shifting the pairing of all later rows, and the column
--   names are matched case-insensitively, as 'getLibTable' does.
--
--   Lines from the environment variable DBEST are inserted verbatim
--   into every record.  Zero-length sequences produce no record.
--
--   NOTE(review): 'getLibTable' writes the library NAME with
--   underscores turned into spaces, whereas LIBRARY below uses the raw
--   table value - confirm which spelling dbEST expects.
mkWriteEST :: LibTable -> (String,String) -> IO (Handle -> Sequence Nuc -> IO ())
mkWriteEST libtab (cont_name,citation) = do
  extraFields <- maybe [] lines `fmap` lookup "DBEST" `fmap` getEnvironment
  let hwe h s = case seqlength s of
        0 -> return ()
        _ -> hPutStr h $ unlines $
               [ "TYPE:\tEST"
               , "STATUS:\tNew"
               , "CONT_NAME:\t"++cont_name
               , "CITATION:\t"++citation
               , "LIBRARY:\t"++lookuplibrary est_name
               , "EST#:\t"++ est_name
               ] ++ clone ++ put_id ++ hiqual s ++ polya s ++ extraFields ++
               [ "COMMENT:\tgenerated by dephd"
               , "SEQUENCE:\t"++toStr (seqdata s)
               , "||"
               ]
          where est_name = toStr (seqlabel s)
  return hwe
  where
    -- Values of the (case-insensitively) named column in one row.
    column key row = [v | (k,v) <- row, map toLower k == key]
    matchtable = [ (mkRegex r,n)
                 | row <- libtab
                 , n <- take 1 (column "name" row)
                 , r <- take 1 (column "pattern" row) ]
    lookuplibrary = classify matchtable

    -- High quality region: bounded by the first and last position
    -- whose 20-wide sliding average reaches 20.
    hiqual s
      | not (hasqual s) = []
      | otherwise =
          let as         = sliding_avg 20 (seqqual s)
              trim_left  = length $ takeWhile (<20) as
              trim_right = fromIntegral (seqlength s) - length (takeWhile (<20) $ reverse as)
          in [ "HIQUAL_START:\t" ++ show (1+min trim_left trim_right)
             , "HIQUAL_STOP:\t"  ++ show trim_right ] -- if trim_right is less, there is no high qual
    -- POLYA: "Y" or "N"
    polya s = [case findPolyA s of Just _ -> "POLYA:\tY"; Nothing -> "POLYA:\tN"]
    clone   = []   -- not generated yet
    put_id  = []   -- from annotations.csv

-- | Clip a sequence (and its quality data, if any) to the trim points
--   found in its header by 'trims'.  The trimming annotation is removed
--   from the header and "clipped: <from> <to>" is appended instead.
--   Sequences without exactly two trim points pass through unchanged.
trimSeq :: (FilePath,Sequence Nuc) -> (FilePath,Sequence Nuc)
trimSeq unchanged@(i,s@(Seq _ d mq)) =
    case trims s of
      ([t1,t2],h') -> (i,appendHeader clipped $ unwords ["clipped:",show t1,show t2])
          where clip    = B.take (fromIntegral t2-fromIntegral t1) . B.drop (fromIntegral t1)
                clipped = Seq h' (clip d) (fmap clip mq)
      _            -> unchanged

-- todo: clip only 'n's?
-- todo: add trimming info in header?
-- (Currently, nothing protects against re-trimming with the same parameters...)

-- | PolyA finding.  Each base gets a log-odds score against a uniform
--   background of 0.25.  With error probability e = 10^(-Q/10) for
--   quality Q:
--
--   * an 'A' scores      log ((1-e)/0.25) = log (4*(1-10^(-Q/10)))
--   * anything else gets log ((e/3)/0.25) = log 4 - log 3 - Q/10*log 10
--
--   Bases without quality data are treated as Q=15.  The scores are
--   accumulated with a floor of zero, and the highest accumulated
--   value decides: above 12 the result is the (start,end) of the
--   stretch leading up to it, otherwise 'Nothing'.
findPolyA :: Sequence Nuc -> Maybe (Int,Int)
findPolyA (Seq _ d mq)
    | best > 12 = Just (start+1,stop)  -- arbitrary constant alert!
    | otherwise = Nothing
  where
    quals  = maybe (repeat 15) BB.unpack mq
    scores = zipWith score (B.unpack d) quals
    score c q | toUpper c == 'A' = match q
              | otherwise        = mismatch q
    match q    = let x = fromIntegral q in log (4*(1-1/10**(x/10)))
    mismatch q = let x = fromIntegral q in log 4 - log 3 - x/10*log 10
    (start,stop,best) = findmax (scanl (\acc sc -> max 0 (acc + sc)) 0 scores)

-- | Locate the maximum of a list of scores.  Returns the index of the
--   last zero that preceded the maximum, the index of the maximum
--   itself, and its value.  Only values strictly greater than the
--   current maximum (initially 0) replace it, so the first occurrence
--   wins and a list without positive values yields (0,0,0).
findmax :: [Double] -> (Int,Int,Double)
findmax = scan 0 (0,0,0) . zip [0..]
    where scan _ best [] = best
          scan lastZero best@(_,_,top) ((ix,v):more)
              | v == 0    = scan ix best more
              | v > top   = scan lastZero (lastZero,ix,v) more
              | otherwise = scan lastZero best more

-- | Is there a @gnuplot@ executable on the search path?
hasgp :: IO Bool
hasgp = fmap isJust (findExecutable "gnuplot")

-- | 'readDirs' reads every PHD file (as judged by 'isPhdFile') found
--   directly inside the given directories; arguments that are not
--   existing directories are silently ignored.  'readFiles' reads the
--   given files as PHD files.  Both warn when there is nothing to read.
readDirs, readFiles :: [FilePath] -> IO [(FilePath,Sequence Nuc)]
readDirs dirs = do
  realDirs <- filterM doesDirectoryExist dirs
  entries  <- mapM myGetDirectoryContents realDirs
  phds     <- filterM isPhdFile (concat entries)
  mapM' myReadPhd phds
readFiles paths = mapM' myReadPhd paths

-- | Like 'mapM' in IO, but prints a warning on stderr when given an
--   empty list.
mapM' :: (a -> IO b) -> [a] -> IO [b]
mapM' action xs
  | null xs   = do hPutStrLn stderr "Warning: nothing to do!\n(Use '-h' for help)"
                   return []
  | otherwise = mapM action xs

-- ------------------------------------------------------------
-- Processing
-- ------------------------------------------------------------

-- | Quality is divided into four categories, mainly based on length of
--   a sliding average of quality greater than 20.
--   The category of a sequence is assigned by 'qualchk'; the
--   constructors are listed from worst ('Junk') to best ('Excl').
data Quality = Junk | Poor | Good | Excl deriving (Show,Eq)

-- | Entry point: parse the options, read the input sequences, run
--   them through the filters, apply every registered action to each
--   remaining sequence, and finally close the output handles.
--   Option parsing errors abort with the usage message.
main :: IO ()
main = do
  (opts,fs,errs) <- getOptions
  unless (null errs) $ error (usage errs)
  raw <- inputs opts fs
  -- progress reporting only in verbose mode
  counted <- if verbose opts
               then countIO ("processing "++show (length raw)++" sequences: ")
                        ", done.\n" 100 raw
               else return raw
  let selected = filter (zerofilter opts) (map (filters opts) counted)
  forM_ selected $ \item -> mapM_ ($ item) (actions opts)
  mapM_ hClose (outputs opts)

-- | Matches the suffix of a PHD file name: a literal ".phd", optionally
--   followed by a dot and a version number (e.g. "read1.phd.1").
--   The dots are escaped; unescaped they match any character, so that
--   names such as "fooXphd" were accepted as well.
phd_rx :: Regex
phd_rx = mkRegex "\\.phd(\\.[0-9]+)?$"

-- | A PHD file is an existing regular file whose name matches 'phd_rx'.
isPhdFile :: FilePath -> IO Bool
isPhdFile fn = fmap (&& nameMatches) (doesFileExist fn)
    where nameMatches = isJust (matchRegex phd_rx fn)

-- | List a directory, with the directory name and a slash prepended
--   to every entry.
myGetDirectoryContents :: FilePath -> IO [FilePath]
myGetDirectoryContents d = do
  entries <- getDirectoryContents d
  return [d ++ "/" ++ entry | entry <- entries]

-- | Adjust sequence content according to quality.
--   Upper case is >limit and sliding avg >limit*1.3
--   Trimming suggestions so that ends with average windows q<limit are clipped
--
--   The result carries the original quality data and a header with
--   " QTRIM: <left> <right>" appended (" QTRIM: 0 0" when the left trim
--   point lies beyond the right one).  Calls 'error' when the sequence
--   has no quality data.
qualAdjust :: Double -> Sequence a -> Sequence a
qualAdjust _ (Seq _ _ Nothing) = error "no quality data - impossible!"
qualAdjust limit sss =  Seq l_trim (B.unfoldr conv avgs) (Just q)
    -- The quality of everything that is not one of ACGT/acgt is set to 0
    -- first (presumably seqmap applies the function to each
    -- (base,quality) pair - confirm against Bio.Sequence).
    where sq@(Seq l d (Just q)) = seqmap (\(cv,qv) -> (cv,if cv `elem` "ACGTacgt" then qv else 0)) sss
          -- Three parallel streams: window-1 averages (i.e. the individual
          -- values), window-20 averages, and the sequence data itself.
          avgs = (sliding_avg 1 q, sliding_avg 20 q, d)
          -- Emit one character per position: upper case when both value and
          -- average are high, 'n' when either is very low, lower case otherwise.
          conv (a:as,s:ss,dd) = Just (if a>limit && s>limit*1.3 then toUpper (B.head dd)
                                      else if  a<4 || s<7 then 'n'
                                           else toLower (B.head dd),(as,ss,B.tail dd))
          conv ([],[],dd) | B.null dd = Nothing -- else broken invariant
          conv _ = error "internal error in 'qualAdjust/conv'"

          -- From each end, skip the positions whose window average is below
          -- limit, then the individual values below limit*1.3.
          (trim_left,trim_right) = let (vs,as,_) = avgs
                                       atrim_l = length $ takeWhile (<limit) as
                                       atrim_r = length (takeWhile (<limit) $ reverse as)
                                   in ( atrim_l + length (takeWhile (<(limit*1.3)) (drop atrim_l vs))
                                      , fromIntegral (seqlength sq) - (atrim_r + length (takeWhile (<limit*1.3) $ drop atrim_r $ reverse vs)))
          l_trim = if trim_left > trim_right then B.concat [l,B.pack " QTRIM: 0 0"]
                   else B.concat (l:map B.pack [" QTRIM: ",show trim_left," ",show trim_right])

-- | Read a PHD file lazily: the file is not touched until the result
--   is demanded, and the sequence is forced (to weak head normal form)
--   before the pair is returned.
myReadPhd :: FilePath -> IO (FilePath,Sequence Nuc)
myReadPhd f = unsafeInterleaveIO $
    readPhd f >>= \phd -> phd `seq` return (f,phd)

-- | The usage message.  Starts with the given error messages, or with
--   a one-line description of the program if there are none, followed
--   by the synopsis and the option descriptions.
usage :: [String] -> String
usage errs = usageInfo (banner ++ synopsis) options
  where
    banner
      | null errs = "dephd: analyze phd files (phred output)\n"
      | otherwise = concat errs++"\n"
    synopsis = "Usage: dephd -[RFQPX] [phdfile..]\n"
            ++ "       dephd -[RFQPX] --input-dirs [phddir..]\n"

-- | Align columns: the first column is padded on the right, all
--   others on the left, each to the width of its longest entry.
columns :: [[String]] -> [String]
columns rows = map (pad widths) rows
    where widths = case collens rows of
                     []     -> []
                     (w:ws) -> w : map negate ws

-- | Pad each string to the corresponding width and join the results
--   with single spaces.  A negative width right-aligns (pads on the
--   left), a non-negative one left-aligns.  Strings longer than their
--   width are left as they are.
pad :: [Int] -> [String] -> String
pad widths = unwords . zipWith fit widths
    where fit w cell
              | w < 0     = replicate (negate w - length cell) ' ' ++ cell
              | otherwise = cell ++ replicate (w - length cell) ' '

-- | Width of every column: the length of the longest entry, taken over
--   the rows that are long enough to have that column.
collens :: [[String]] -> [Int]
collens rows =
    case filter (not . null) rows of
      []       -> []
      nonEmpty -> maximum [length c | (c:_) <- nonEmpty]
                    : collens [cs | (_:cs) <- nonEmpty]

-- ugly hack: the first element, or 0 for an empty list
myhead :: [Int] -> Int
myhead []    = 0
myhead (x:_) = x

-- | Report quality estimates for a sequence, as a list of columns:
--   the label, the average quality (one decimal), the longest stretch
--   of qualities above 15, the longest stretch above 30, the longest
--   stretch where the 20-wide sliding average is above 20, and the
--   resulting 'Quality' category.
qualchk :: Sequence a -> [String]
qualchk s = [toStr $ seqlabel s
            ,printf "%.1f" avgqual]
            ++ map show [s15, s30, a20] ++[show qtot]
    where
      s15 = stretch 15 qs
      s30 = stretch 30 qs
      a20 = stretch 20 . sliding_avg 20 . seqqual $ s
      qtot | a20 < 75 && s15 < 50   = Junk
           | a20 < 150              = Poor
           | a20 < 250 || s30 < 150 = Good
           | otherwise              = Excl

      qs = map (fromIntegral . ord) . toStr . seqqual $ s
      avgqual = sum qs / fromIntegral (seqlength s) :: Double
      -- Length of the longest run of values above i.  'groupBy' puts
      -- the element that starts a group into it unconditionally and
      -- then collects the following values above i, so every group is
      -- one longer than the run it holds.  The sentinel i (which is not
      -- above i) takes that role for the first group; without it a run
      -- starting at the very first position was counted one short.
      -- The sentinel also guarantees a non-empty list for 'maximum'.
      stretch i = maximum . map (subtract 1 . length)
                  . groupBy (const (>i)) . (i:)

-- | Plot the quality of a sequence in the background: runs 'plot' in a
--   new thread and returns that thread's id.
bgplot :: (FilePath -> String) -> (FilePath,Sequence Nuc) -> IO ThreadId
bgplot term item = forkIO (plot term item)

-- | Feed the quality data to gnuplot.
--
--   The first argument builds the terminal/output preamble from the
--   file name.  Empty sequences are reported on stderr and skipped.
--   Whatever gnuplot prints on its stdout and stderr is relayed to our
--   stderr once it has finished, followed by a message if it exited
--   with a failure code.
--
--   gnuplot's output pipes are drained by helper threads while the
--   script is being written; waiting for the process first and reading
--   afterwards can deadlock as soon as gnuplot fills a pipe buffer.
plot :: (FilePath -> String) -> (FilePath,Sequence Nuc) -> IO ()
plot term (f,s) = case seqlength s of
    0 -> hPutStrLn stderr ("cannot plot empty sequence: "++f)
    _ -> do     (i,o,e,p) <- runInteractiveCommand "gnuplot -persist"
                outBox <- drain o
                errBox <- drain e
                hPutStr i (mkGnuplot term (f,s))
                hClose i
                out <- takeMVar outBox
                err <- takeMVar errBox
                x <- waitForProcess p
                hPutStr stderr out
                hPutStr stderr err
                case x of ExitSuccess   -> return ()
                          ExitFailure j -> hPutStr stderr (errmsg++show j)
    where errmsg = "'gnuplot' failed for "++f++" with exit code "
          -- Read a handle to its end in a separate thread; the MVar is
          -- filled with the complete contents once EOF is reached.
          drain h = do box <- newEmptyMVar
                       _ <- forkIO (do txt <- hGetContents h
                                       length txt `seq` putMVar box txt)
                       return box

-- | Build a GNUplot script for graphing quality.  After the preamble
--   (terminal setup, axis labels, one vertical marker per trim point,
--   title and plot command) come two inline data sets, each terminated
--   by an "e" line: the individual quality values and their 20-wide
--   sliding average.
mkGnuplot :: (FilePath -> String) -> (FilePath,Sequence Nuc) -> String
mkGnuplot term (f,s) = header ++ rawData ++ "e\n" ++ avgData ++ "e\n"
    where quals   = seqqual s
          rawData = unlines [show (ord c) | c <- toStr quals]
          avgData = unlines (map show (sliding_avg 20 quals))
          -- one numbered arrow per trim point
          marker i c = "set arrow "++show i++" from "++show c++",0 to "++show c ++",20 nohead\n"
          markers = concat (zipWith marker [(1::Int)..] (fst (trims s)))
          header = concat
                   [ term f
                   , "set xlabel \"position\"\nset ylabel \"quality\"\n"
                   , markers
                   , "set title \""++toStr (seqlabel s)++"\"\nset yrange [0:100]\nplot \"-\" t \"qual\" with points,\"-\" t \"avg\" with lines lt 3, 20 with lines lt 2 t \"thresh\"\n"
                   ]

-- | Look for trimming information.  Phred outputs "TRIM:" in the
--   sequence header, followed by trimming information.  Lucy outputs
--   just a bunch of numbers, the latter two are trimming information.
--   Returns the trim points, and the header with trimming info removed.
--
--   A header counts as Lucy-style when it has more than two words and
--   everything after the first word consists of digits.  Otherwise
--   "TRIM:" and "QTRIM:" annotations are looked up (see 'get_trim');
--   if both are present, the narrower interval is returned.  The list
--   of trim points is empty when no trimming information is found.
--
--   The word count is checked before the rest of the header is
--   inspected, so an empty header yields no trim points instead of
--   failing on the tail of an empty list.
trims :: Sequence Nuc -> ([Int],SeqData)
trims s
    | isLucy    = (map (read . B.unpack) $ reverse $ take 2 $ reverse ws
                  ,B.unwords $ reverse $ drop 2 $ reverse ws)
    | otherwise = (mintrims pt dt, B.unwords rest2)
    where ws = B.words $ seqheader s
          isLucy = length ws > 2 && all (all isDigit . B.unpack) (drop 1 ws)
          (pt,rest1) = get_trim (B.pack "TRIM:") ws
          (dt,rest2) = get_trim (B.pack "QTRIM:") rest1
          -- intersection of two intervals; a missing one is ignored
          mintrims [a,b] [c,d] = [max a c,min b d]
          mintrims [] x = x
          mintrims x _  = x

-- | Extract the two integers following the first occurrence of the key
--   in a list of header words.  Returns them together with the header
--   words minus the key and the two numbers.
--
--   If the key is absent, is followed by fewer than two words, or the
--   two words are not integers, the result is an empty list and the
--   unchanged header (the original code called 'read' and crashed on
--   non-numeric words).
get_trim :: B.ByteString -> [B.ByteString] -> ([Int],[B.ByteString])
get_trim key sh = case break (== key) sh of
                    (before,_:x:y:rest)
                        | Just a <- num x, Just b <- num y -> ([a,b],before ++ rest)
                    _ -> ([],sh)
    where num w = case reads (B.unpack w) of
                    [(v,"")] -> Just v
                    _        -> Nothing

-- | A sample sequence without quality data whose header carries both
--   a "TRIM:" and a "QTRIM:" annotation - presumably meant for trying
--   out 'trims' and 'trimSeq' interactively; nothing visible in this
--   file refers to it.
test_seq :: Sequence Nuc                                     
test_seq = Seq (B.pack "foo TRIM: 1 10 QTRIM: 4 15") (B.pack "1234567890abcdefghij") Nothing

-- | Sliding window average over quality data, one value per position.
--   The first argument is the full window width.  At the start the
--   window holds only half of it (at least one value) and grows by one
--   per step up to the full width; then it slides; at the end it
--   shrinks again.  The output is cut to the length of the input.
sliding_avg :: Int -> QualData -> [Double]
sliding_avg i q = let k = max 1 (min i (fromIntegral $ B.length q) `div` 2) -- initial window size
                      sum1 = fromIntegral $ sum $ map ord $ take k $ toStr $ q -- sum over the initial window
                  in take (fromIntegral $ B.length q) $ unfoldr getAvg (sum1,k,B.drop (fromIntegral k) q,q)
    -- State: (sum over the window, number of values in the window,
    --         data not yet entered into the window, data from the left
    --         edge of the window onwards).
    where getAvg :: (Int,Int,QualData,QualData) -> Maybe (Double,(Int,Int,QualData,QualData))
          getAvg (!a,!n,!q1,!q2) | B.null q1 && B.null q2 = Nothing
                                 -- nothing left to add: shrink from the left
                                 | B.null q1 = let x = a-ord (B.head q2) in Just (fromIntegral a/fromIntegral n,(x, n-1, q1, B.tail q2))
                                 -- window not yet full: grow on the right
                                 | n<i       = let x = a+ord (B.head q1) in Just (fromIntegral a/fromIntegral n,(x, n+1, B.tail q1, q2))
                                 -- full window: add on the right, drop on the left
                                 | otherwise = let x = a+ord (B.head q1)-ord (B.head q2) 
                                               in Just (fromIntegral a/fromIntegral n, (x, n, B.tail q1, B.tail q2))

