{- | Annotate sequences with GO terms, via BLASTX hits to UniProt -}

module Main where

import Data.ByteString.Lazy.Char8 (unpack)
import qualified Data.ByteString.Lazy.Char8 as B
import qualified Data.Map as M
import Data.Map (Map)
import Control.Monad
import Data.Maybe
import System.IO
import System.Directory (doesDirectoryExist, createDirectory)
import Data.Array.Unboxed
import Text.Printf

import qualified Bio.Alignment.BlastData as BD
import Bio.Alignment.BlastData (BlastResult)

import Bio.Alignment.BlastXML as X -- new BlastFlat exports this too
import Data.List (intersperse,sortBy,partition,nub)
import Bio.Sequence.GeneOntology
import Bio.Alignment.BlastFlat 

import Annotate
import Options
import Html

-- Program entry point: parse the options, load the GO and KEGG tables, then
-- read every BLAST XML input and emit the report in the requested format.
main :: IO ()
main = do
  opts <- getOptions
  when (null (inputs opts)) $ fail "No input files specified!"
  (goterms, goanns) <- read_go opts
  keggs <- read_kegg opts

  let table       = tabulate opts goterms goanns keggs
      -- results of all input files, concatenated in command-line order
      loadResults = liftM concat (mapM X.readXML (inputs opts))

  case format opts of
    Csv  -> loadResults >>= csvize opts table
    Html -> loadResults >>= htmlize opts table

-- ---------------------------------------------------------------------
-- Implement CSV output
-- ---------------------------------------------------------------------

-- | Build CSV or HTML output from a header row and a function that turns a
--   blast record into zero or more output rows.
csvize, htmlize :: Options -> ([String], BlastRecord -> [[String]]) -> [BlastResult] -> IO ()
csvize opts (header,writer) brs = do
    records <- count opts 10 "Generating output: " (concatMap results brs)
    let rows = header : concatMap writer records
    output opts (unlines (map csvLine rows))
  where
    -- one output line: the fields of a row separated by commas
    csvLine     = concat . intersperse "," . map stripCommas
    -- fields are not quoted; commas inside a field are simply dropped
    stripCommas = filter (/= ',')

-- Write the HTML report: make sure the output directory 'htmldir' exists
-- (reusing it, with a warning in verbose mode, if it is already there), then
-- write "index.html" in the current directory with a header, one entry per
-- blast record (via 'mkHtml') and a footer.
htmlize opts (header,writer) brs =
    do exists <- doesDirectoryExist htmldir
       if exists
          then when (verb opts) $
                 hPutStrLn stderr ("Warning: reusing existing directory '" ++ htmldir ++ "'")
          else createDirectory htmldir
       brs' <- count opts 10 "Generating output: " $ concatMap results brs
       withFile "index.html" WriteMode $ \h -> do
         hPutStr h (htmlheader firstResult header)
         -- only the effects matter here, so do not build a result list
         mapM_ (mkHtml h writer) brs'
         hPutStr h htmlfooter
    where
      -- The page header is built from the first blast result.  Kept lazy so
      -- that the failure (now with a readable message instead of
      -- "Prelude.head: empty list") only triggers if the header really
      -- inspects it.
      firstResult = case brs of
                      (b:_) -> b
                      []    -> error "htmlize: no BLAST results to build the HTML header from"

-- | Build the column header and the matching record writer from the options
--   and the already-loaded GO definitions, GO annotations and KEGG annotations.
tabulate :: Options -> GoDefinitions -> GoAnnotations -> KoAnnotations -> ([String], BlastRecord -> [[String]])
tabulate opts gds pts kos = (header, writer gds pts kos)
  where
    header = hitCols ++ goCols ++ keggCols

    -- columns that are always present
    hitCols = ["Query","from","to","Target","Description","from","to","ident","bitscore","E-value","direction"]

    -- GO columns appear only when a GO annotation source was given
    goCols
      | isJust (goann opts) = "GO terms" : goDetailCols
      | otherwise           = []

    -- CSV output of all hits has no "Indirect" column
    goDetailCols = case select opts of
      All -> case format opts of
               Html -> ["Indirect","Hierarchy"]
               _    -> ["Hierarchy"]
      _   -> ["Indirect","Hierarchy"]

    keggCols
      | null (ko opts) = []
      | otherwise      = ["KEGG"]

    -- HTML always uses the top-hit writer; CSV follows the 'select' option
    writer = case format opts of
      Html -> showTop
      Csv  -> case select opts of
                All -> showAll
                Top -> showTop
                Reg -> showReg

-- A 'Writer' takes the GO definitions, the GO annotations and the KEGG
-- annotations, and yields a function that generates one or more output
-- records (each a list of column strings) from a BlastRecord.
type Writer = GoDefinitions -> GoAnnotations -> KoAnnotations -> (BlastRecord -> [[String]])

-- The three available writers; 'tabulate' picks one of them based on the
-- output format and the 'select' option.
showAll, showTop, showReg :: Writer

-- Emit one output row for every flattened hit of the record.
-- nb: the record is wrapped in a singleton list because flatten needs a
-- *list* of BlastRecord
showAll gds pts ks rec = [ row bf | bf <- flatten [rec] ]
    where
      row bf = showFlat bf ++ goCols bf ++ keggCols bf
      -- GO columns are emitted only when GO annotations were loaded
      goCols bf
        | M.null pts = []
        | otherwise  = [showGo gds pts [bf], showGoHier gds pts [bf]]
      -- likewise for the KEGG column
      keggCols bf
        | M.null ks = []
        | otherwise = [showKegg ks [bf]]

-- Flatten the single record and hand its hits to showTop'.
showTop gds pts ks rec = showTop' gds pts ks (flatten [rec])

-- Produce one output row per run of consecutive hits sharing a query.  The
-- row describes the first hit (merged with the directly following hits
-- against the same subject); the remaining hits of the run only contribute
-- to the "indirect" GO, hierarchy and KEGG columns.
showTop' :: GoDefinitions -> GoAnnotations -> KoAnnotations -> [BlastFlat] -> [[String]]
showTop' gds pts ks = go
    where
      go []         = []
      go (top:rest) = row : go otherQueries
          where
            -- all further hits for this query
            (sameQuery, otherQueries)    = span (\h -> query top == query h) rest
            -- of those, the leading hits against the same subject as the top hit
            (sameSubject, otherSubjects) = span (\h -> subject top == subject h) sameQuery
            everyHit = top : otherSubjects

            row = showFlat (merge (top:sameSubject)) ++ goCols ++ keggCols
            goCols
              | M.null pts = []
              | otherwise  = [ showGo gds pts [top]
                             , showGo gds pts otherSubjects
                             , showGoHier gds pts everyHit ]
            keggCols
              | M.null ks = []
              | otherwise = [showKegg ks everyHit]

-- Merge hits against the same subject into one: the first hit is kept, with
-- its query and subject coordinates widened to span every following hit
-- that has exactly the same E-value as the first one.  Hits with a
-- different E-value are ignored.
merge :: [BlastFlat] -> BlastFlat
merge []     = error "needs at least one blast hit to generate output"
merge (x:xs) = x { q_from = minimum (map q_from grp)
                 , q_to   = maximum (map q_to   grp)
                 , h_from = minimum (map h_from grp)
                 , h_to   = maximum (map h_to   grp)
                 }
    where
      grp = x : filter (\y -> e_val y == e_val x) xs

-- Cut the hits of the record into regions (lists of hits that 'overlap' on
-- the query) and render every region with showTop'.
showReg gds pts ks = concatMap (showTop' gds pts ks) . select_region . flatten . return
    where
      -- Take the hits that share the first hit's query, keep those pointing
      -- in the same direction as the first hit and cut them into regions;
      -- the hits in the other direction are processed next, ahead of the
      -- remaining queries.
      select_region [] = []
      select_region (x:xs) = let (this' ,next) = span ((query x==).query) xs
                                 (this'',next') = partition ((dir x==).dir) this'
                                 -- seed hit: x merged with the leading run of same-subject hits
                                 x' = merge (x:takeWhile ((subject x==).subject) this'')
                                 -- NOTE(review): the hits folded into x' are still part of this'',
                                 -- so they show up again in the sorted list below -- confirm intended
                                 this = sortBy (compare `on` q_from) $ this''
                             in regions [x'] (q_from x', q_to x') this ++ select_region (next'++next)

      -- regions acc (a,b) hits: acc holds the current region in reverse order
      -- and (a,b) is the query interval it covers so far.  An overlapping hit
      -- extends the region and widens the interval; any other hit closes the
      -- region and starts a new one.
      regions :: [BlastFlat] -> (Int,Int) -> [BlastFlat] -> [[BlastFlat]]
      regions x (a,b) (y:ys) = if overlap (a,b) y then regions (y:x) (min a (q_from y), max b (q_to y)) ys
                               else reverse x : regions [y] (q_from y, q_to y) ys
      regions x (_,_) []     = [reverse x]

      -- A hit overlaps if it starts before the midpoint of the interval, or
      -- ends before b plus that midpoint.
      -- NOTE(review): the second test looks very permissive compared to a
      -- usual overlap check -- confirm this is the intended criterion
      overlap (a,b) s = q_from s < (a+b) `div` 2 || q_to s < b + (a+b) `div` 2
      -- local helper: compare two values after projecting both through g
      f `on` g  = \x y -> f (g x) (g y)

-- | generate one record (i.e. line of output): query id, query range,
--   subject id, subject description, subject range, identity, bit score,
--   E-value and direction of the hit.
showFlat :: BlastFlat -> [String]
showFlat bf  = [ firstWord $ query bf
               , show $ q_from bf, show $ q_to bf
               , firstWord $ subject bf
                 -- description: whatever follows the first word of the subject line
               , unpack $ B.dropWhile (==' ') $ B.dropWhile (/=' ') $ subject bf
               , show $ h_from bf, show $ h_to bf
               , show (fst . identity $ bf), show (bits bf), show (e_val bf)
               , dir bf
               ]
    where
      -- First whitespace-separated word, or "" for an empty or all-blank
      -- name (the previous 'head . words' crashed on such input).
      firstWord s = case words (unpack s) of
                      (w:_) -> w
                      []    -> ""

-- Direction of a hit as "Fwd" or "Rev": for strand information the hit is
-- forward when both strands agree, for frame information it is reverse when
-- the frame direction is Minus.
dir :: BlastFlat -> String
dir bf = case aux bf of
    Strands p p'
      | p == p'    -> "Fwd"
      | otherwise  -> "Rev"
    Frame d _
      | d == Minus -> "Rev"
      | otherwise  -> "Fwd"

-- | Render KEGG information from a set of UniProtAcc's: the subject of every
--   hit is reduced with 'chop' and looked up in the KEGG annotations; the
--   entries that are found are shown separated by spaces, subjects without
--   an annotation are skipped.
showKegg :: KoAnnotations -> [BlastFlat] -> String
-- 'mapMaybe' keeps only the successful lookups.  The previous
-- 'concatMap (flip M.lookup ks)' relied on the old monad-polymorphic
-- Data.Map.lookup and does not type-check against a Maybe-returning lookup.
showKegg ks fs = unwords $ map show $ mapMaybe (flip M.lookup ks . chop . subject) fs

-- | Render GO information from a set of UniProtAcc's: every distinct GO term
--   annotated to one of the hits' subjects is shown, followed by its class
--   and description when the term has a definition.
showGo :: GoDefinitions -> GoAnnotations -> [BlastFlat] -> String
showGo gds pts hits = concatMap render terms
    where
      -- distinct terms, in order of first appearance
      terms = nub (concatMap (extractGOs pts . subject) hits)
      render gt = show gt ++ describe (M.lookup gt gds)
      describe (Just (GoDef _ str cls, _)) = " ("++show cls++": "++unpack str++") "
      describe Nothing                     = " "

-- | Similar to 'showGo', but retain bit scores and extract hierarchical
--   information by recursive lookups.  Every term reachable from a hit's
--   annotations gets that hit's bit score; a term reached from several hits
--   keeps the highest score, and terms are listed by decreasing score.
showGoHier :: GoDefinitions -> GoAnnotations -> [BlastFlat] -> String
showGoHier gds pts hits = concatMap render ranked
    where
      ranked = sortBy byScoreDesc (M.toList (M.unionsWith max (map scored hits)))

      -- highest score first
      byScoreDesc (_,a) (_,b) = compare (negate a) (negate b)

      -- all terms reachable from one hit, each tagged with the hit's bit score
      scored :: BlastFlat -> Map GoTerm Double
      scored b = M.fromList [ (t, bits b)
                            | g <- nub (extractGOs pts (subject b))
                            , t <- collect gds g ]

      render (gt,x) = show gt ++ printf " [%.1f]" x ++ describe (M.lookup gt gds)
      describe (Just (GoDef _ str cls, _)) = " ("++show cls++": "++unpack str++") "
      describe Nothing                     = " "

-- Starting from a term, recursively gather the term itself followed by
-- everything reachable through the term lists stored with its definition.
-- A term without a definition contributes only itself; a term reachable
-- along several paths is listed once per path.
collect :: GoDefinitions -> GoTerm -> [GoTerm]
collect gds t = t : maybe [] (concatMap (collect gds) . snd) (M.lookup t gds)

-- Look up a UniProtAcc (reduced with 'chop') in the GO annotations and turn
-- the stored numeric ids into GoTerms; an accession without annotations
-- gives the empty list.
extractGOs :: GoAnnotations -> UniProtAcc -> [GoTerm]
extractGOs pts acc = case M.lookup (chop acc) pts of
    Nothing  -> []
    Just ids -> [ GO (fromIntegral i) | i <- elems ids ]
