{- | Link sequences (read from the BLAST XML files) to the desired annotations.
     (Currently supports GO and KEGG annotations.)
 -}
module Annotate where

import Data.ByteString.Lazy.Char8 (ByteString,unpack,pack)
import qualified Data.ByteString.Lazy.Char8 as B
import Data.Map  (Map)
import qualified Data.Map as M
import Data.Set  (Set)
import qualified Data.Set as S
import Data.List (nub)
import Data.Maybe

import Control.Monad
import Control.Arrow (second)

import Bio.Sequence.GeneOntology
import Bio.Sequence.KEGG
import qualified Bio.Alignment.BlastData as BD
import Bio.Alignment.BlastData (BlastResult)
import Bio.Alignment.BlastFlat 
import Bio.Alignment.BlastXML as X -- new BlastFlat exports this too
import Data.Array.Unboxed
import Data.Word

import Options

-- | GO definitions indexed by term: every 'GoTerm' maps to its 'GoDef'
--   paired with a list of further 'GoTerm's (presumably the related/parent
--   terms of the definition -- TODO confirm against the OBO reader).
type GoDefinitions = Map GoTerm (GoDef,[GoTerm])
-- | GO annotations per protein: a UniProt accession maps to an unboxed
--   array of 16-bit GO term numbers (filled in by 'protTerms').
type GoAnnotations = Map UniProtAcc (UArray Int Word16)
-- | KEGG annotations per protein: a UniProt accession maps to a 'KO'
--   (filled in by 'read_kegg').
type KoAnnotations = Map UniProtAcc KO

{- | Read GO definitions and GO annotations (GOA).

   When no annotation file is configured ('goann'), both results are empty;
   the term definitions are only useful together with the annotations.
   Otherwise:

   1. if a definition file is configured ('godef'), build the map of GO
      terms to their definitions (otherwise that map stays empty)
   2. collect all protein names from the BLAST XML input (see 'getProts')
   3. build the map of proteins to GO terms from the annotation file,
      restricted to the collected proteins (see 'protTerms')

   Each intermediate table is forced right after it is built, so the input
   files are consumed step by step instead of on first use.
-}
read_go :: Options -> IO (GoDefinitions, GoAnnotations)
read_go opts = maybe (return (M.empty, M.empty)) withAnnotations (goann opts)
  where
    -- Annotation file present: load definitions, proteins, annotations.
    withAnnotations goa = do
      defs     <- maybe (return M.empty) loadDefinitions (godef opts)
      proteins <- getProts opts
      S.member B.empty proteins `seq` return ()   -- strictness
      anns     <- protTerms proteins `fmap`
                    (count opts 1000 "reading GO annotations: " =<< readGOA goa)
      M.lookup B.empty anns `seq` return ()       -- strictness, again
      return (defs, anns)

    -- Read the OBO file and index its definitions by GO term.
    loadDefinitions obo = do
      defs <- goTerms `fmap`
                (count opts 1000 "reading GO terms: " =<< readObo obo)
      M.lookup (GO 0) defs `seq` return ()        -- strictness hack
      return defs
  

-- ----------------------------------------------------------------------
-- Build tables of annotation information
-- ----------------------------------------------------------------------

-- | Index a list of GO definitions by term: every entry is stored under
--   the 'GoTerm' found in its 'GoDef'.
goTerms :: [(GoDef,[GoTerm])] -> GoDefinitions
goTerms defs = M.fromList (map keyed defs)
    where keyed entry@(GoDef term _ _,_) = (term,entry)

-- | Read every BLAST XML file listed in 'inputs' and collect the set of
--   UniProt accessions that occur as hit subjects (shortened with 'chop').
getProts :: Options -> IO (Set UniProtAcc)
getProts opts = do
  perFile <- forM (inputs opts) $ \path -> do
    blastResults <- X.readXML path
    let allHits = concatMap BD.hits (concatMap BD.results blastResults)
    count opts 100 ("reading accessions from '"++path++"': ") allHits
  return (S.fromList [ chop (BD.subject hit) | hit <- concat perFile ])

-- | Convert "UniProt_ABCDEF Blah blah blah" to "ABCDEF": take the first
--   whitespace-separated word and keep what follows its first underscore.
--   Warning: only tested on this format, may break!
--
--   A word without an underscore yields the empty string, and so does a
--   subject that contains no word at all (empty or whitespace only), rather
--   than crashing on the missing first word.
--   The result is copied so it does not keep the input buffer alive.
chop :: ByteString -> ByteString
chop subject = case B.words subject of
  []      -> B.empty
  (tok:_) -> B.copy . B.drop 1 . B.dropWhile (/='_') $ tok

-- | Using the set of UniProt accessions of interest, scan the associations
--   and build, for every accession in the set, an unboxed array of its
--   distinct GO term numbers (in order of first appearance).
--
--   Requires the annotations to be grouped by protein: consecutive entries
--   with the same accession form one group.  The groups do not have to be
--   sorted by accession; the map is built with 'M.fromList', so unsorted
--   groups still give a valid map.  If an accession occurs in several
--   separate groups, only the last group is kept.
--
--   NOTE(review): term numbers are narrowed to 'Word16' with 'fromIntegral',
--   so a number that does not fit in 16 bits would wrap around -- confirm
--   that the GO ids in use stay within that range.
protTerms :: (Set UniProtAcc) -> [Annotation] -> GoAnnotations
protTerms ps as = M.fromList . map toArray . partitions . filter isMember $ as
    where -- one group of annotations -> (accession, array of distinct terms)
          toArray (Ann up (GO gt1) _:xs) = let gs = nub (gt1:[ gt | Ann _ (GO gt) _ <- xs ])
                                               a  = listArray (0,length gs-1) $ map fromIntegral gs
                                           -- force the array before the pair is handed out;
                                           -- the key is rebuilt so it does not share the input buffer
                                           in a `seq` (pack . unpack $ up,a)
          toArray [] = error "cannot read an empty list of annotations!"
          -- keep only annotations of proteins we actually saw
          isMember (Ann up _ _) = up `S.member` ps
          -- split into maximal runs of consecutive entries with equal accession
          partitions xs@(Ann x _ _:_) = let (p1,pps) = span (\(Ann y _ _) -> y==x) xs
                                        in p1:partitions pps
          partitions [] = []


{- | Read KEGG annotations.
   Organisms are specified as abbreviated KEGG prefixes (the 'ko' option).
   For every prefix the files $prefix_ko.list and $prefix_uniprot.list are
   read and joined on their first column; the resulting per-organism tables
   are combined with 'M.unions'.
 -}
read_kegg :: Options -> IO KoAnnotations
read_kegg opts = fmap M.unions (mapM loadOrganism (ko opts))
  where
    -- Build (and force) the accession-to-KO table of a single organism.
    loadOrganism org = do
      joined  <- joinFiles org
      counted <- count opts 1000 ("parsing KEGG information for '"++org++"'.. ") joined
      let table = M.fromList counted
      M.lookup (B.pack "") table `seq` return ()  -- strictness hack
      return table

    -- Parse both list files of an organism and join them; the files seem
    -- to be lexicographically sorted, which the join depends on.
    joinFiles org = do
      geneToUp <- fmap (map (second decodeUP)) (genReadKegg (org++"_uniprot.list"))
      geneToKo <- fmap (map (second decodeKO)) (genReadKegg (org++"_ko.list"))
      return (joinOnKey geneToUp geneToKo)

    -- Merge join on the first components, returning the paired second
    -- components.  Entries whose key is missing on the other side are
    -- skipped.
    -- NOTE(review): on a key match both lists advance by one element, so
    -- repeated keys are paired position by position and surplus entries on
    -- either side are dropped -- confirm this is intended.
    joinOnKey :: [(ByteString,UniProtAcc)] -> [(ByteString,KO)] -> [(UniProtAcc,KO)]
    joinOnKey (u:us) (k:ks) =
      case compare (fst u) (fst k) of
        LT -> joinOnKey us (k:ks)
        GT -> joinOnKey (u:us) ks
        EQ -> (snd u,snd k) : joinOnKey us ks
    joinOnKey [] _ = []
    joinOnKey _ [] = []

