{-# OPTIONS_GHC   #-}
module Igor2.Data.CallDependencies( 

        CallDep, Call, 
        noCalls, addCall, annexFun, tryAddCall,
        admissible, allowedMaxCall, cycles, loops, annexCandidates,
        calledBy
        
       )where

import Control.Monad.Error
import Data.Maybe (maybeToList, fromJust, mapMaybe)
import Data.Function (on)

import Data.Map (Map)
import qualified Data.Map as M

import Data.List(unfoldr, (\\), deleteBy)

import Data.Graph.Inductive hiding (nodes)
--(
--    Graph, Gr, Node, LNode, LPath(..), LEdge, DynGraph,
--    lbft, noNodes, run_, insMapNodeM, 
--    insMapEdgeM, delMapEdgeM, grev, equal, labEdges, 
--    )
import qualified Data.Graph.Inductive as G

import Control.Arrow
--import Data.Graph

import Language.Haskell.TH hiding (match)
import Control.Monad (liftM2)
import Data.List (delete, maximumBy, sortBy, groupBy)
import Igor2.Logging
import Igor2.Ppr hiding (group)

{-
 | Computing the transitive closure of calling dependencies. 
 
   Basic Idea:
   If function 'f' calls function 'g', then 'f' depends on 'g' (f -> g). The
   argument(s) of a call can decrease, remain the same, or increase in size, so
   the dependency is of kind LT, EQ or GT (-LT->, -EQ->, -GT->).
   
   Calling dependencies are transitive, so if f -> g and g -> h then also 
   f -> h. The kind of the transitive dependency has the maximal type of all 
   compound dependencies with ordering LT < EQ < GT. 
   
   If a calling dependency f -> g already exists, only the following calls
   from g back to f are allowed:
   
   f -GT-> g  => g is not allowed to call f
   f -EQ-> g  => g -LT-> f
   f -LT-> g  => g -LT-> f or g -EQ-> f
   
   Dependencies are stored in a graph with labelled nodes and edges of type
   (Gr Name Ordering) from the "functional graph library" (fgl) by Martin
   Erwig:
       https://hackage.haskell.org/package/fgl
       
   To retrieve a certain Node, the Name of the calling function and the graph-
   internal node number are stored in a Map, whereas a Node in a graph is simply
   an Int value, i.e. the nth Node has the value n.
  
-}          

-- Not nice but necessary: an upper bound on the number of calling cycles
-- (e.g. f -> g -> f) that 'admissible' accepts in the call graph once the new
-- call has been added.
_MaxCycles :: Int
_MaxCycles = 5

-- | A call @(caller, callee, o)@: 'caller' calls 'callee', and @o@ states
--   whether the call's argument(s) decrease ('LT'), stay the same ('EQ') or
--   increase ('GT') in size.
type Call = (Name, Name, Ordering)

-- | The call graph: nodes are labelled with function names, edges (caller to
--   callee) with the 'Ordering' of the call.
newtype Calls = Calls{ unC :: G.Gr Name Ordering} deriving (Eq)

-- | Shows the underlying fgl graph.
instance Show Calls where
    show (Calls gr) = show gr
-- | One document line per line of the 'show' output, aligned; only the first
--   empty line (if any) is dropped.
instance Pretty Calls where
    pretty calls = align (vcat [ text ln | ln <- delete [] (lines (show calls)) ])
                         
-- | Call dependencies: the call graph plus a map from each function name to
--   its node number in that graph.
data CallDep = CD
    { callings :: Calls
      -- ^ the call graph
    , nodes :: M.Map Name Node
      -- ^ the graph node of each function occurring in 'callings'
    }
    deriving(Show)

-- | Two call dependencies are equal when their call graphs are equal and they
--   map the same names to the same nodes.
instance Eq CallDep where
    CD cs1 ns1 == CD cs2 ns2 =
        cs1 == cs2 && M.toAscList ns1 == M.toAscList ns2

-- | Pretty-prints the call graph, or @<no_deps>@ when that rendering is empty.
instance Pretty CallDep where
    pretty cd
        | null (show doc) = text "<no_deps>"
        | otherwise       = doc
        where
        doc = pretty (callings cd)
    
-- | Call dependencies without any function or call.
noCalls :: CallDep
noCalls = CD { callings = Calls G.empty, nodes = M.empty }

-- | Add a call to the call dependencies, failing with 'error' (and the message
--   produced by 'tryAddCall') if the call is not admissible.
addCall :: Call -> CallDep -> CallDep
addCall c cd =
    case tryAddCall cd c of
      Left msg  -> error msg
      Right cd' -> cd'
       
-- | Annex the function 'nn': every pair of calls @s -l1-> nn@ and
--   @nn -l2-> t@ is bridged by a new call @s -(max l1 l2)-> t@, and 'nn' is
--   then removed from the call dependencies. If 'nn' is not part of the call
--   dependencies, they are returned unchanged.
--   NOTE(review): unlike 'addCallUnsafe', no existing @s -> t@ edge is deleted
--   before the bridging edge is inserted, so the graph may end up with
--   parallel edges - confirm this is intended.
annexFun :: Name -> CallDep -> CallDep
annexFun nn cd =
    case asNode nn (nodes cd) of
      Nothing -> cd
      Just n  ->
        let gr       = unC (callings cd)
            ins      = inn gr n
            outs     = out gr n
            -- outgoing edges are visited last-to-first, incoming ones in order
            bypasses = [ bridge a b | b <- reverse outs, a <- ins ]
        in remFunUnsafe nn (addEdgesUnsafe bypasses cd)
    where
    bridge :: LEdge Ordering -> LEdge Ordering -> LEdge Ordering
    bridge (s, _, l1) (_, t, l2) = (s, t, max l1 l2)
     
-- | Add the call if it is 'admissible' in the given call dependencies;
--   otherwise throw an error (built with 'strMsg') naming the rejected call.
tryAddCall ::(Error e, MonadError e m) => CallDep -> Call -> m CallDep
tryAddCall cd c
    | admissible c cd = return (addCallUnsafe c cd)
    | otherwise       = throwError (strMsg msg)
    where
    msg = "Call " ++ show c ++ " not admissible in " ++ show cd

-- | Returns 'True' if the call @(f, g, o)@ is admissible in the given call
--   dependencies, 'False' otherwise:
--
--   * a direct recursive call (@f == g@) is admissible only if its arguments
--     decrease ('LT');
--   * any other call must keep the number of cycles of the call graph (with
--     the call added) at or below '_MaxCycles', and every existing dependency
--     @g -x-> f@ (see 'dependencies') must be compatible with the new
--     @f -o-> g@: only (o,x) = (LT,LT), (LT,EQ) or (EQ,LT) are allowed.
admissible :: Call -> CallDep -> Bool
admissible c@(f,g,o) cd@(CD cs ns)
    | f == g    = o == LT
    | otherwise = withinCycleBound && all (compatible o) backDeps
    where
    withinCycleBound = cycles (addCallUnsafe c cd) <= _MaxCycles
    -- existing dependencies of g on f (empty if either is unknown)
    backDeps = maybe [] (\(nf, ng) -> dependencies ng nf (unC cs))
                        (liftM2 (,) (asNode f ns) (asNode g ns))
    compatible LT (_,LT) = True  -- f -LT-> g  and  g -LT-> f
    compatible LT (_,EQ) = True  -- f -LT-> g  and  g -EQ-> f
    compatible EQ (_,LT) = True  -- f -EQ-> g  and  g -LT-> f
    compatible _  _      = False -- f -GT-> g: g must not call f

-- | Number of cycles in the call graph, as enumerated by 'cyclesIn'.
cycles :: CallDep -> Int
cycles cd = length (cyclesIn (unC (callings cd)))

-- | Counts the loops in the call graph, i.e. edges from a node to itself
--   (direct recursive calls).
loops :: CallDep -> Int
loops cd = ufold (\ctx acc -> acc + selfEdges ctx) 0 (unC (callings cd))
    where
    selfEdges ctx = length [ s | s <- suc' ctx, s == node' ctx ]

-- | Given a function name, return the names of all functions which call this
--   function _directly_ (one entry per matching call edge, in the order of
--   'G.labEdges'). Returns @[]@ if the function is unknown.
calledBy :: Name -> CallDep -> [Name]
calledBy n cdp = maybe [] callersOf (asNode n (nodes cdp))
    where
      calls = callings cdp
      callersOf target =
          mapMaybe (`fromNode` calls)
                   [ src | (src, dst, _) <- G.labEdges (unC calls), dst == target ]

{- | Computes all functions in the call graph which may be annexable, i.e.
     functions with at most one outgoing call edge (no call at all, or a call
     to exactly one function). All functions/nodes with indegree 0, i.e.
     those which are not called by any function (the target functions), are
     removed from the result. Functions without outgoing calls are listed
     first, followed by those with exactly one.
     NOTE(review): whether a self-recursive call counts as an incoming edge
     depends on 'G.inn''; if it does, a self-recursive target function is not
     removed here - confirm this is intended.
-}
annexCandidates :: CallDep -> [Name]
annexCandidates cd = mapMaybe (`fromNode` calls) (foldr delete candidates targets)
    where
    calls      = callings cd
    gr         = unC calls
    ctxs       = [ (n, G.context gr n) | n <- G.nodes gr ]
    -- nodes whose list of outgoing edges satisfies the predicate
    byOut p    = [ n | (n, c) <- ctxs, p (G.out' c) ]
    candidates = byOut null ++ byOut isSingleton
    -- the target functions, which are the only ones with indegree 0
    targets    = [ n | (n, c) <- ctxs, null (G.inn' c) ]
    isSingleton [_] = True
    isSingleton _   = False

allowedMaxCall :: Name
                -- ^ the name of the calling function
                -> CallDep
                -- ^ existing call dependencies
                -> Map Name (Maybe Ordering)
                -- ^ pairs of a name @f@ and an ordering @o@: @f@ is a function
                --   which (transitively) calls the given function according
                --   to the given 'CallDep', and @o@ is the maximal allowed
                --   change of the argument size when the given function calls
                --   @f@. For @o@ the following changes are allowed:
                --      Just GT --> GT, EQ, LT
                --      Just EQ --> EQ, LT
                --      Just LT --> LT
                --      Nothing --> not allowed to call
                --   Functions that do not call the given function are absent.
allowedMaxCall n (CD cs ns) = M.fromList (concatMap allowance strongest)
    where
    -- (fn, o) means fn -o-> n: fn (transitively) calls n. The paths are
    -- taken in the reversed call graph, starting at n.
    callers = case asNode n ns of
        Nothing    -> []
        Just start -> compress (pathsFrom start (grev (unC cs)))
    -- keep only the maximal ordering per calling node
    strongest = [ maximumBy (compare `on` snd) grp
                | grp <- groupBy ((==) `on` fst) (sortBy (compare `on` fst) callers) ]
    allowance (nd, o) = case fromNode nd cs of
        Nothing -> []
        Just fn -> [(fn, limit o)]
    limit GT = Nothing   -- fn -GT-> n: n must not call fn
    limit EQ = Just LT   -- fn -EQ-> n: n -LT-> fn
    limit LT = Just EQ   -- fn -LT-> n: n -LT-> fn or n -EQ-> fn

-- | Get the graph node of a function name, if the name is in the map.
asNode :: Name -> M.Map Name Node -> Maybe Node
asNode = M.lookup

-- | Get the function name labelling a graph node, if the node is in the graph.
fromNode :: Node -> Calls -> Maybe Name
fromNode nd calls = G.lab (unC calls) nd
{-
 |- Unsafe functions to modify calling dependencies
-}

-- | Adds a call without any admissibility check. Both functions are added as
--   nodes if necessary, and any existing direct edge caller -> callee is
--   replaced by one carrying the new ordering (the old label is discarded,
--   even if it was greater).
--   NOTE(review): an earlier comment claimed the call is only added if it is
--   greater than any existing (possibly transitive) call; the code does not do
--   that - confirm which behaviour is intended.
addCallUnsafe :: Call -> CallDep -> CallDep
addCallUnsafe (caller, callee, lbl) cd0 = CD (Calls gr') ns
    where
    (src, cd1)               = addFunUnsafe caller cd0
    (dst, CD (Calls gr) ns)  = addFunUnsafe callee cd1
    gr'                      = G.insEdge (src, dst, lbl) (G.delEdge (src, dst) gr)
    
-- | Return the node of a function, first adding it as a new node (with a
--   number from 'G.newNodes') if it is not present yet. The node number is
--   recorded in the name map for later retrieval.
addFunUnsafe :: Name -> CallDep -> (G.Node, CallDep)
addFunUnsafe fn cd@(CD (Calls gr) ns) =
    maybe inserted (\known -> (known, cd)) (M.lookup fn ns)
    where
    [fresh]  = G.newNodes 1 gr
    inserted = ( fresh
               , CD (Calls (G.insNode (fresh, fn) gr)) (M.insert fn fresh ns) )

-- | Remove a function: its node is deleted from the call graph (via
--   'G.delNode') and its entry from the name map. Unknown names leave the
--   call dependencies unchanged.
remFunUnsafe :: Name -> CallDep -> CallDep
remFunUnsafe n cd@(CD (Calls cs) ns) =
    case M.lookup n ns of
      Nothing -> cd
      Just nd -> CD (Calls (G.delNode nd cs)) (M.delete n ns)

-- | Insert the labelled edges into the call graph without any check; the
--   name map is left untouched.
addEdgesUnsafe :: [LEdge Ordering] -> CallDep -> CallDep
addEdgesUnsafe es cd = cd { callings = Calls (insEdges es (unC (callings cd))) }

--------------------------------------------------------------------------------
-- stolen from graphalyze 
--------------------------------------------------------------------------------
-- | A group of labelled nodes (here: the nodes of one cycle).
type LNGroup a = [LNode a]
-- | A group of plain nodes (here: the nodes of one cycle).
type NGroup = [Node]
 -- | Find all cycles in the given graph; each cycle is the list of its
-- labelled nodes (the start node is not repeated at the end).
cyclesIn   :: (DynGraph g) => g a b -> [LNGroup a]
cyclesIn gr = [ addLabels gr grp | grp <- cyclesIn' gr ]

-- | Find all cycles in the given graph, returning just the nodes. The nodes
--   are decomposed out of the graph one at a time ('findCycles'), so each
--   cycle is reported once, through the first of its nodes to be decomposed.
cyclesIn' :: (DynGraph g) => g a b -> [NGroup]
cyclesIn' gr = concat (unfoldr findCycles gr)

-- | Decompose an arbitrary node ('matchAny') out of the graph, and return the
--   cycles through it together with the remaining graph; 'Nothing' once the
--   graph is empty. Serves as the step function for 'unfoldr'.
findCycles :: (DynGraph g) => g a b -> Maybe ([NGroup], g a b)
findCycles gr
    | isEmpty gr = Nothing
    | otherwise  = Just (cyclesFor decomp, snd decomp)
    where
      decomp = matchAny gr
      
 -- | Find all cycles through the decomposed node: the paths of 'pathTree' that
-- return to their first node and have more than one element, with the
-- repeated final node dropped.
cyclesFor :: (DynGraph g) => GDecomp g a b -> [NGroup]
cyclesFor (ctx, rest) =
    [ init path | path <- pathTree (Just ctx, rest), isCycle path ]
    where
      isCycle p = not (single p) && head p == last p
      
-- | Obtain the labels for a list of 'Node's.
--   It is assumed that each 'Node' is indeed present in the given graph
--   ('fromJust' fails otherwise).
addLabels    :: (Graph g) => g a b -> [Node] -> [LNode a]
addLabels gr ns = [ (n, fromJust (lab gr n)) | n <- ns ]

 -- | 'True' if and only if the list has exactly one element.
single     :: [a] -> Bool
single xs = case xs of
    [_] -> True
    _   -> False

-- | Find all possible paths from the decomposed node, avoiding loops,
--   cycles, etc. Every path starts with that node; a path whose last element
--   equals its first one closes a cycle (see 'cyclesFor').
--
--   There is deliberately no early exit when the remaining graph is empty.
--   A node with a self-loop may be matched out of an otherwise empty graph,
--   and it must still yield the path @[n, n]@. Otherwise self-loops on the
--   last-decomposed node are silently lost, and the number of cycles depends
--   on node numbering. Termination comes from 'makeLeaf': every visited node
--   is re-inserted without outgoing edges.
pathTree             :: (DynGraph g) => Decomp g a b -> [NGroup]
pathTree (Nothing,_) = []
pathTree (Just ct,g)
    | null sucs = [[n]]
    | otherwise = [n] : map (n:) (concatMap (subPathTree g') sucs)
    where
      n = node' ct
      sucs = suc' ct
      -- Avoid infinite loops by not letting it continue any further: n is
      -- put back as a leaf, so a path can visit it at most once more.
      ct' = makeLeaf ct
      g' = ct' & g
      subPathTree gr n' = pathTree $ match n' gr

-- | Remove all outgoing edges of the context, as well as any incoming edge
--   from the node itself, so that re-inserting it cannot create an edge (n,n).
makeLeaf           :: Context a b -> Context a b
makeLeaf (preds, n, a, _) = ([ e | e@(_, src) <- preds, src /= n ], n, a, [])

-- -----------------------------------------------------------------------------           
-- -----------------------------------------------------------------------------


-- HACKS
-- fgl is rather undocumented and the manual is outdated, so I often only got 
-- some notion of what a function does, by 'trial and error'

-- The dependencies of node 'src' on node 'dst': every path from 'src' that
-- reaches 'dst' (see 'pathsFromTo'), collapsed by 'compress' into a pair
-- (dst, l), where l is the maximal edge label along that path.
dependencies :: (Ord a, Graph gr) => Node -> Node -> gr a1 a -> [(Node, a)]
dependencies src dst gr = compress (pathsFromTo src dst gr)


-- Collapse each labelled path into a single labelled node: the first node of
-- the path (see 'accumulate') with the maximal label along the path. Every
-- path must be non-empty ('foldl1').
compress  :: (Ord a) => [[LNode a]] -> [LNode a]
compress = map (foldl1 accumulate)

-- Merge two LNodes: keep the node of the left one, take the larger label.
accumulate :: Ord a => (LNode a) -> (LNode a) -> (LNode a)
accumulate (n, l1) (_, l2) = (n, l1 `max` l2)
         

-- All paths starting in 'fn' (from 'pathsFrom') whose reached node is 'tn'.
-- A path lists its reached node first, so the head of the path is checked;
-- empty paths are skipped.
pathsFromTo :: (Graph gr) => Node -> Node -> gr a l -> [[LNode l]]
pathsFromTo fn tn g =
    [ path | path@((reached, _) : _) <- pathsFrom fn g, reached == tn ]
    

-- Paths from 'n' to the nodes reachable from it, derived from fgl's 'lbft'
-- (apparently the labelled breadth-first tree rooted at 'n', with one path
-- per reachable node, listed from that node back to 'n'). The starting node
-- is stripped off by 'unPath'.
pathsFrom :: (Graph gr) => Node -> gr a l -> [[LNode l]]
pathsFrom n = unPath . lbft n

-- | The labelled nodes of an 'LPath'.
unLP :: LPath t -> [LNode t]
unLP (LP labelledNodes) = labelledNodes

-- Turn the result of 'lbft' (an LRTree l, i.e. [LPath l]) into plain paths.
-- Each LPath ends with the starting node, which is dropped ('init'); the
-- first LPath presumably belongs to the starting node itself and is dropped
-- entirely ('tail').
unPath :: [LPath l] -> [[LNode l]]
unPath lps = tail [ init (unLP lp) | lp <- lps ]


-- Testing: sample function names and a small call graph built from them.
-- The calls are added from the last list element to the first.
f, g, h, i, j :: Name
f = mkName "f"
g = mkName "g"
h = mkName "h"
i = mkName "i"
j = mkName "j"

c1 :: CallDep
c1 = foldr addCall noCalls
    [ (f, j, LT), (j, f, LT), (g, f, LT), (h, i, GT), (g, h, EQ), (f, g, LT) ]

