module Igor2.RuleDevelopment.Matching where

import Prelude hiding ((<$>))

import qualified Data.Map as M
import Data.List hiding (group)
import Data.Function (on)
import Data.Maybe (catMaybes, isJust, fromJust)
import Control.Monad
import Control.Monad.Error
import Control.Monad.Trans
import Control.Arrow ((***))

import Syntax hiding (merge)
import qualified Syntax.Class.Subst as Subst (merge)
import Igor2.Ppr
import Igor2.Config
import Igor2.Logging
import qualified Igor2.Data.Rules as Rules
import Igor2.Data.IgorMonad
import Igor2.Data.CallDependencies

import Utils.BranchAndBound (bnb, Bag)
import qualified Utils.BranchAndBound as Bag (empty, insert)
import qualified Utils.BranchAndBound as BnB (
    NodeValue(LowerBound, Solution),
    NodeOps(NodeOps, keyOf, children, valueOf))

-- | Try to introduce a function call into the given rule, once for every
--   function that is a candidate callee.
--
--   The candidate table starts from the scope that 'allowedMaxCall' yields for
--   the rule's own name and is then overridden, in this order, by
--
--     * every background function, tagged with @Just GT@, and
--     * the rule's own function and every other target, tagged with @Just LT@.
--
--   Later entries overwrite earlier ones with the same name. Each entry of the
--   resulting table is handed to 'tryCall'; all results are concatenated.
introduceMatchings :: CallDep -> CovrRule -> IM [(CovrRules,[Call])]
introduceMatchings cd cr = do
    bknames  <- background
    tgtnames <- targets
    _        <- usePara
    let self     = name cr
        -- scope of current hypo, sub function calls, and previous background
        -- calls
        scope    = allowedMaxCall self cd
        -- calls to background knowledge
        bgkcalls = [ (n, Just GT) | n <- bknames ]
        -- self call and calls to parallel synthesized targets
        tgtcalls = [ (n, Just LT) | n <- self : tgtnames ]
        -- Currently disabled:
        --   let paracalls = if p then tag (Just LT) $ calledBy (name cr) cd else []
        -- If we use paramorphisms, the mediating function should not match
        -- against their caller, because this is trivial.
        --
        -- bgkcalls and tgtcalls may overwrite calls in scope, as well as each
        -- other.
        allcalls = M.toList (overrideWith (bgkcalls ++ tgtcalls) scope)

    waypointS $ text "Computing matchings"
    logIN ( text "Try for:" <^> pretty cr)
    logDE ( text "Call Dependencies:" <^> pretty cd )
    logDE ( text "scope:" <^> pretty scope )
    logDE ( text "bgkcalls:" <^> pretty bgkcalls )
    logDE ( text "tgtcalls:" <^> pretty tgtcalls )
    logDE ( text "Allowed calls:" <^> pretty allcalls)

    perCallee <- mapM (tryCall cr) allcalls
    return (concat perCallee)
  where
    -- Insert the entries left to right, so that later entries win.
    overrideWith entries base = foldl (\m (k, v) -> M.insert k v m) base entries
-- | Try to introduce a call in the given 'CovrRule' to the function 'Name'.
--
--   A tag of 'Nothing' means the call is not allowed and yields no result.
--   For @Just o@, direct calls (see 'directCalls') are tried first; if there is
--   at least one, only the one with the smallest call relation is used.
--   Otherwise the result of 'indirectCalls' is returned.
tryCall :: CovrRule -> (Name, Maybe Ordering) -> IM [(CovrRules,[Call])]
tryCall cr (n, allowed) = case allowed of
    Nothing -> do
        logDE (indent 2 $ text "(-) Not allowed to call" <+> (squotes.text.show $ n))
        return [] -- no call allowed
    Just o -> do
        dcp <- directCalls cr (n,o)
        case dcp of
            [] -> indirectCalls cr (n,o)
            _  -> do
                logIN (linebreak <>
                        text "Direct call to" <+> (squotes.pretty $ n) <+>
                        text "possible!")
                -- take the smallest callrel
                direct <- makeDirectCall cr n (minimumBy (compare `on` fst) dcp)
                return [direct]

{-
    ArgList: a list with one entry per argument of the target function
    (callee). Within one call from indirectCalls all ArgLists have the same
    length, namely the number of arguments of the callee, and the i-th entry
    always corresponds to the i-th argument.

    FunList: a collection of IO examples which will be used to create a new
    function via addIO (in mkIndirectCall).
-}
type ArgList a = [a]
type FunList = [Rule]

{-| Compute all indirect calls from the provided 'CovrRule' to the function
    'Name', i.e. calls whose arguments have to be synthesised by
    subfunctions.

    No call is produced when all caller examples or all target examples have
    a variable as rhs, or when at least one caller example has no admissible
    matching at all. Otherwise one indirect call is built per matching, where
    the matchings come from 'bestMatchings' (greedy mode) or 'allMatchings'.
-}
indirectCalls :: CovrRule -> (Name, Ordering) -> IM [(CovrRules,[Call])]
indirectCalls cr (n,o) = do
    cllrs <- breakupM cr
    tgtrs <- liftM (getAll n) getEvidence

    logIN (linebreak <>
            text "No direct call to" <+> (squotes.pretty $ n) <+>
            text "possible!" <+>
            text "Need to compute full matchings (C vs T)")
    logDE (align (text "Caller:" <+> pretty cllrs) <$>
            align (text "Target:" <+> pretty tgtrs))

    -- This is important: due to polymorphism, functions like 'last' in the
    -- background knowledge would match everything. For those functions only
    -- a direct match should be allowed.
    if varOnly cllrs || varOnly tgtrs
        then tooGeneral
        else do
            ios <- makeIOMatrix cllrs o tgtrs
            if all (not . null) ios
                then combine ios
                else insufficient
  where
    -- True iff every rule in the list has a plain variable as rhs.
    varOnly rs = all (isVar . rhs . crul) rs

    tooGeneral = do
        logIN (text "Caller or Target has variable on rhs." <+>
               text "Would compute too many matchings.")
        return []

    insufficient = do
        logIN (text "Insufficient matchings!" <+>
               text "At least one argument had no I/Os at all!")
        return []

    -- Turn the IO matrix into matchings and build one call per matching.
    combine m = do
        logDE (text "Matchings" <$> pretty (map (map pretty) m))
        greedyMatch <- doGreedyMtch
        context     <- gets igor_ctx
        let enumerate
                | greedyMatch = bestMatchings context
                | otherwise   = allMatchings
        mapM (mkIndirectCall cr n) (enumerate m)

-- | Error monad used while computing matchings; failures are reported as a
--   plain 'String' message in 'Left'.
--
-- TODO: Maybe replace this by (Monad m, MonadError m e) => m
type MatchingErrorMonad = Either String

{-| Computes all possible matchings from an IO matrix, where each inner list of
    type @[(Ordering, ArgList Rule)]@ holds the alternatives for one caller IO
    example.

    Every matching picks exactly one alternative from each inner list, i.e. the
    matchings are the cross product of all inner lists ('sequence'). For one
    matching

      * the 'Ordering' is the largest call relation among the picked
        alternatives ('LT' if nothing was picked), and
      * the picked 'ArgList's are transposed, so that the i-th entry collects
        the rules for the i-th argument; every entry is an open 'FunList'
        ('Left').

    An empty matrix yields the single matching @(LT, [])@ instead of failing in
    'maximum'; this agrees with the initial history used by 'bestMatchingsBnb'.
-}
allMatchings :: [[(Ordering, ArgList Rule)]] -> [(Ordering, ArgList (Either FunList TExp))]
allMatchings = map combine . sequence
  where
    -- 'LT' is the least 'Ordering', so for a non-empty selection this fold is
    -- exactly the maximum of the call relations.
    combine picked =
        ( foldr max LT (map fst picked)
        , map Left (transpose (map snd picked))
        )

{-| Like 'allMatchings', but restricted to the matchings selected by the
    branch-and-bound search in 'bestMatchingsBnb' (intended: those with a
    maximum number of closed lggs of the inner 'FunList's).
-}
bestMatchings :: Context -> [[(Ordering, ArgList Rule)]] -> [(Ordering, ArgList (Either FunList TExp))]
bestMatchings = bestMatchingsBnb

-- | A node of the branch-and-bound search performed by 'bestMatchingsBnb'.
data MatchingBnbNode =
    BnbNode
        -- The largest call relation chosen so far, together with, per callee
        -- argument, the rules chosen so far (extended via 'multiCons').
        { nodeHistory :: (Ordering, ArgList FunList)
        -- Per callee argument, the lgg of the rules chosen so far (updated via
        -- 'nextLggs'); empty for the initial node.
        , nodeLggs :: (ArgList Rule)
        -- The rows of the IO matrix that have not been processed yet; a node
        -- with no residual rows is treated as a solution.
        , residualIOMatrix :: [[(Ordering, ArgList Rule)]]
        }

{-
    Branch-and-bound search over the IO matrix. Every level of the search tree
    consumes one row of the matrix (one caller IO example) and branches over
    the alternatives in that row. A node is keyed by the number of lggs that
    are still open.

    Either-usage in the result:
        - Left: open rule, list of IO examples
        - Right: closed rule, rhs corresponding to the caller's lhs
-}
bestMatchingsBnb :: Context -> [[(Ordering, ArgList Rule)]] -> [(Ordering, ArgList (Either FunList TExp))]
bestMatchingsBnb context ioMatrix = let
        -- Number of per-argument lggs that are still open.
        keyOfBnbNode :: MatchingBnbNode -> Int
        keyOfBnbNode = length . filter Rules.isOpen . nodeLggs

        -- One child per alternative in the next unprocessed row. A node
        -- without residual rows is a leaf and has no children.
        childrenOfBnbNode :: MatchingBnbNode -> [MatchingBnbNode]
        childrenOfBnbNode (BnbNode { residualIOMatrix = [] }) = []
        childrenOfBnbNode (BnbNode
            { nodeHistory = hist
            , nodeLggs = lggs
            , residualIOMatrix = (xs:xxs)
            }) =
                map (\(o, rs) ->
                    let cmInnerNode = do
                            -- keep the largest call relation and prepend the
                            -- chosen rules to the per-argument histories
                            let nhist = (max o *** multiCons rs) hist
                            nlggs <- nextLggs lggs rs
                            return (BnbNode
                                { nodeHistory = nhist
                                , nodeLggs = nlggs
                                , residualIOMatrix = xxs
                                })
                    in case cmInnerNode `withC` context of
                        -- a failing anti-unification aborts the search
                        Left err -> error err
                        Right x -> x
                ) xs

        -- Leaves are solutions; inner nodes only provide a lower bound.
        valueOfBnbNode :: MatchingBnbNode -> BnB.NodeValue Int
        valueOfBnbNode x@(BnbNode { residualIOMatrix = []  }) = BnB.Solution   (keyOfBnbNode x)
        valueOfBnbNode x@(BnbNode { residualIOMatrix = _:_ }) = BnB.LowerBound (keyOfBnbNode x)

        bnbNodeOps :: BnB.NodeOps Int Int MatchingBnbNode
        bnbNodeOps = BnB.NodeOps
            { BnB.keyOf = keyOfBnbNode
            , BnB.children = childrenOfBnbNode
            , BnB.valueOf = valueOfBnbNode
            }

        -- Nothing chosen yet: least call relation, no lggs, full matrix left.
        initialNode = BnbNode
            { nodeHistory = (LT, [])
            , nodeLggs = []
            , residualIOMatrix = ioMatrix
            }

        initialQueue :: Bag Int MatchingBnbNode
        initialQueue = Bag.insert (keyOfBnbNode initialNode, initialNode) Bag.empty

        solutions :: [MatchingBnbNode]
        solutions = bnb bnbNodeOps initialQueue

        -- An argument whose lgg is closed is represented by the rhs of that
        -- lgg; otherwise its collected IO examples are kept.
        saveClosedLggs :: ArgList FunList -> ArgList Rule -> ArgList (Either FunList TExp)
        saveClosedLggs = zipWith (\ios lgg -> if Rules.isClosed lgg then Right (rhs lgg) else Left ios)

    -- bnb must not return nodes whose residual matrix is non-empty, because
    -- those are inner nodes rather than leaves; the pattern below silently
    -- skips any such node.
    in [ (ord, saveClosedLggs rules lggs)
       | BnbNode { nodeHistory = (ord, rules)
                 , nodeLggs = lggs
                 , residualIOMatrix = []
                 }
            <- solutions
       ]

{-| multiCons is a helper function for bestMatchingsBnb.

    It prepends the i-th element of the first argument to the i-th list of the
    second argument, i.e. it behaves like @zipWith (:)@ for arguments of equal
    length. If the second argument runs out first, it is padded with @[]@. The
    second argument must not be longer than the first (this results in a
    pattern match failure).

    The padding is only intended for the case where the second argument is
    @[]@ altogether. Otherwise, both lists should always have the same length.
-}
multiCons :: [a] -> [[a]] -> [[a]]
multiCons []     []   = []
multiCons (x:xs) rows = case rows of
    []         -> [x]     : multiCons xs []
    (row:rest) -> (x:row) : multiCons xs rest

{-| nextLggs is a helper function for bestMatchingsBnb.

    It combines, argument by argument, the lggs computed so far with the next
    list of rules via 'lggRules'.

    As the type anti-unifier doesn't handle "forall a. a" as a neutral element
    during anti-unification, bestMatchingsBnb starts with an empty list of
    lggs; in that case the rules are returned unchanged.
-}
nextLggs :: ArgList Rule -> ArgList Rule -> C MatchingErrorMonad (ArgList Rule)
nextLggs lggs rs
    | null lggs = return rs
    | otherwise = zipWithM pairLgg rs lggs
  where
    pairLgg r lgg = lggRules [r, lgg]

-- | Monadic 'concatMap': run the action on every element, from left to right,
--   and concatenate the resulting lists.
concatMapM :: (Monad m) => (a -> m [b]) -> [a] -> m [b]
concatMapM f xs = do
    chunks <- mapM f xs
    return (concat chunks)

{-| Compute all direct calls from the provided 'CovrRule' to the function 'Name',
    i.e. calls to function 'Name' where the arguments for the call can directly
    be constructed from the lhs pattern of the caller.

    The 'Ordering' argument is the maximal call relation that is allowed.
    Every result pairs an argument pattern that passed 'testPat' with the
    relation obtained by comparing that pattern against the caller's lhs.
-}
directCalls :: CovrRule -> (Name,Ordering) -> IM [(Ordering,LHS)]
directCalls cr (n,maxcall) = do

     -- get the covered IOs of the caller ...
     cvdIOs    <- breakupM $ cr
     -- ... and all evidence for the callee
     allIOs   <- liftM (getAll n) getEvidence
     -- get the smallest (by size of lhs) covered IO example
     -- NOTE(review): 'minimumBy' fails on an empty list; this relies on the
     -- caller covering at least one IO example -- confirm.
     let scio  =  minimumBy (compare `on` (size.lhs.crul)) cvdIOs

     comp     <- getPatComparison
     -- get for the smallest covered example all admissible target IOs w.r.t.
     -- the given maximal call relation
     let tgtrs = admissibleIOs comp scio allIOs
     -- the LHSs of the admissible target rules, renamed w.r.t. the variables
     -- of the smallest IO example; targets whose rhs does not match are
     -- dropped
     tgtlhs   <- mkNormalPats scio tgtrs
     -- get the substitution with which the smallest IO matches the pattern of
     -- its covering rule
     subs     <- lift . lift $ (matchesLhs `on` crul) scio cr
                `catchError`  -- should not happen here; falls back to the empty substitution
                (\_e -> return nullSubst)

     -- simple conversion for our convenience (names to variables)
     let subs' = map (\(n,t) -> (toVar t n,t)) (assocs subs)
     -- generate all patterns
     let pats =  (map . buildPat) subs' tgtlhs
     -- keep only the patterns that pass the test against all covered IOs
     cpats  <-  filterM (testPat n cvdIOs (lhs.crul $ cr)) pats
     -- pair every remaining pattern with its relation to the caller's lhs
     return $ map ((,) =<< flip comp (lhs.crul $ cr)) $ cpats
     where
     -- keep the IOs whose lhs, compared with the lhs of 'io', does not exceed
     -- 'maxcall'
     admissibleIOs cmp io = filter $ (maxcall >=) . flip (on cmp (lhs.crul)) io
     -- 'normalPat' of 'io' against every given IO, dropping the failures
     mkNormalPats io   = lift . fmap catMaybes . mapM (on ((lift .) . normalPat) crul io)

-- | @testPat n cllrs lp rps@ checks whether calling @n@ with the argument
--   patterns @rps@ (given as right-hand sides over the caller pattern @lp@)
--   reproduces the rhs of every caller example in @cllrs@.
testPat :: Name -> [CovrRule] -> LHS -> LHS -> IM Bool
testPat n cllrs lp rps = do
    verdicts <- mapM (reproduces . crul) cllrs
    return (and verdicts)
  where
    -- one rule per argument of the callee, all sharing the caller pattern
    candidates = map (rule lp) rps

    reproduces io = do
        evaluated <- (applyC (mapM (matchEval (lhs io)) candidates) :: IM (Either String [TExp]))
        outcome   <- case evaluated of
            Left _     -> return Nothing
            Right args -> evalIO n args
        return $ case outcome of
            Nothing -> False
            Just r  -> not (anySubterm isWild r) && (r == rhs io)
{- In principle, the condition should be just r == rhs io. However, (==) on
 - TExp is not transitive (apparently because terms containing wildcards
 - compare as equal). Since we need transitivity here, we additionally require
 - the absence of wildcards in r. -}

-- Rebuilds every term of the list from the given (variable, term) pairs.
buildPat subs tgtlhss = map rebuild tgtlhss
  where
    rebuild t = case find ((t ==) . snd) subs of
        -- the term is already bound in the substitution: take its variable
        Just (v, _) -> v
        -- otherwise keep the top constructor symbol and rebuild the subterms
        -- recursively
        Nothing     -> root t (buildPat subs (subterms t))

{-| @normalPat r1 r2@ returns the lhs of @r2@ after applying the substitution
    that results from matching the rhs of @r1@ against the rhs of @r2@. This is
    a kind of normalized lhs pattern of @r2@ w.r.t. @r1@ (variable renaming).

    If the matching fails, the result is 'Nothing'.
-}
normalPat :: Monad m => Rule -> Rule -> C m (Maybe [TExp])
normalPat cll tgt =
    (do sub <- matchesRhs cll tgt
        return (Just (applyL sub (lhs tgt))))
      `safeCatchErrorC` (\_ -> return Nothing)

-- | Apply a monadic function to whichever side of the 'Either' is present,
--   keeping the side.
mapBothM :: Monad m => (a -> m a') -> (b -> m b') -> Either a b -> m (Either a' b')
mapBothM onLeft onRight = either (liftM Left . onLeft) (liftM Right . onRight)

-- | Apply a monadic function to a 'Left' value; a 'Right' value is returned
--   unchanged without running any action.
mapLeftM :: Monad m => (a -> m a') -> Either a b -> m (Either a' b)
mapLeftM f = either (liftM Left . f) (return . Right)

-- | The 'Left' value, if there is one.
leftToMaybe :: Either a b -> Maybe a
leftToMaybe = either Just (const Nothing)

-- | All 'Left' values of the list, in their original order.
catLefts :: [Either a b] -> [a]
catLefts es = [ x | Left x <- es ]

-- | Build an indirect call from the rule @cr@ to the function @tgtn@ out of
--   one matching.
--
--   Every open argument ('Left', a list of IO examples) becomes a new function
--   added via 'addIO' and is called on the caller's lhs; every closed argument
--   ('Right') is used as a term as it is. The result holds the modified caller
--   rule together with the initial rules of the new functions, plus the call
--   to the target followed by the auxiliary calls.
mkIndirectCall :: CovrRule -> Name -> (Ordering, ArgList (Either FunList TExp)) -> IM (CovrRules, [Call])
mkIndirectCall cr tgtn ios = do
    -- subargsio contains one entry for each argument of tgtn; an open entry
    -- holds one rule for each of cr's IO examples.
    let (callrel, subargsio) = ios
    -- a subfn is either the name of a newly added function, or a closed term
    subfns <- (mapM (mapLeftM (addIO . rules)) subargsio :: IM [Either Name TExp])
    let callerLhs = lhs (crul cr)
        -- type of an argument: of the first example's rhs, or of the term
        argType arg = typeOf (either (rhs . head) id arg)
        toCall sub arg = case sub of
            Left fn    -> mkCall fn (argType arg) callerLhs
            Right term -> term
        subcalls = zipWith toCall subfns subargsio
        cr'      = modifycrul cr (mkCallAt (Body Root) tgtn subcalls)
        newfns   = catLefts subfns
    subinis <- mapM coverAll newfns
    let call  = (name cr, tgtn, callrel)           -- call to target
        calls = [ (tgtn, fn, EQ) | fn <- newfns ]  -- auxiliary calls
    logDE (text "Call possible:" <^> pretty cr' <+> pretty call)
    return (covrRules (cr' : subinis), call : calls)


-- | Build a direct call from the rule @cr@ to the function @tgtn@, using the
--   given argument pattern and call relation. The result holds the modified
--   caller rule and the single call that was introduced.
makeDirectCall :: CovrRule -> Name -> (Ordering,LHS) -> IM (CovrRules,[Call])
makeDirectCall cr tgtn (o,pat) = do
    logDE (text "Call possible:" <^> pretty cr' <+> pretty call)
    return (covrRules [cr'],[call])
  where
    call = (name cr, tgtn, o)
    cr'  = modifycrul cr (mkCallAt (Body Root) tgtn pat)

{- | For n rules of the target (callee) function (t = [t1 .. tn]) and m rules of
     the caller function (c = [c1 .. cm]), the cross-product (t * c) is
     generated as a list of columns, one column per caller rule:

        [[c1t1  [c2t1  ... [cmt1
         ,c1t2  ,c2t2  ... ,cmt2
         ,...   ,...   ... ,...
         ,c1tn] ,c2tn] ... ,cmtn]]

     where each 'citj' is the result of @abduceIO patcomp ci o tj@. Entries for
     which 'abduceIO' yields 'Nothing' are dropped, so a column may be shorter
     than n (or even empty).
-}
makeIOMatrix :: [CovrRule] -> Ordering -> [CovrRule] -> IM [[(Ordering, ArgList Rule)]]
makeIOMatrix cllrs o tgtrs = do
    patcomp <- getPatComparison
    lift (mapM (column patcomp) cllrs)
  where
    -- All admissible matchings of one caller rule against every target rule.
    column :: (LHS -> LHS -> Ordering) -> CovrRule -> LM [(Ordering, ArgList Rule)]
    column cmp c = liftM catMaybes (mapM (\t -> abduceIO cmp c o t) tgtrs)

-- | @abduceIO comp cll o tgt@ abduces one IO pair for each argument of 'tgt',
--   if admissible.
--
--   The call relation is @comp@ applied to the lhs of 'tgt' and the lhs of
--   'cll'. Calling 'tgt' from 'cll' is not admissible if this relation is
--   greater than the maximal relation @o@; in that case, and also if the rhs
--   of 'cll' does not match the rhs of 'tgt', the result is 'Nothing'.
--
--   Otherwise the result holds the call relation and one rule per argument of
--   'tgt': the lhs is the lhs of 'cll', the rhs is the respective lhs argument
--   of 'tgt' under the matching substitution. Each IO pair will be used in a
--   different function (added in mkIndirectCall), each of which computes the
--   respective argument of the function 'tgt' belongs to.
abduceIO :: (LHS -> LHS -> Ordering) -> CovrRule -> Ordering -> CovrRule -> LM (Maybe (Ordering, ArgList Rule))
abduceIO comp cll maxcallrel tgt
    -- check whether matching is allowed, not allowed if maxcallrel < callrel
    | maxcallrel < callrel = do
        logDE (indent 2 $
                text "(-) Discarded Match" <+> pretty tgt <+>
                text (show callrel) <+> pretty cll <+> text "not allowed" <+>
                text (show maxcallrel))
        return Nothing
    | otherwise = do
        -- match on rhss; a failing match is not an error, just no result
        matched <- lift $ (liftM Just $ (matchesRhs `on` crul) cll tgt) `catchError`
                          (\_ -> return Nothing)
        logDE (indent 2 $
                text (if isJust matched then "(+)" else "(-)") <+>
                text "Try       Match" <+> pretty tgt <+>
                text (show callrel) <+> pretty cll <+>
                text "allowed" <+> text (show maxcallrel) <+>
                text "Match?" <+> (bool $ isJust matched))
        case matched of
            Nothing -> return Nothing
            Just s  -> do
                let sameVarName (TVarE i1 _) (TVarE i2 _) = i1 == i2
                    sameVarName _            _            = False
                    tgtvars = nub $ concatMap getVars (lhs (crul tgt))
                    bthvars = nub (getVars (rhs (crul cll))) `intersect` nub (getVars (rhs (crul tgt)))
                    -- variables of tgt's lhs that occur neither in both rhss
                    -- nor in the substitution
                    unaffectedvars = deleteFirstsBy sameVarName (tgtvars \\ bthvars) $ map (\(n,t) -> toVar t n) (assocs s)
                    -- replace all unaffected variables by wildcards
                    -- TODO: the following line may be slow, because merge is linear already
                    merged = foldM Subst.merge nullSubst (s : [ n <~ TWildE n t | TVarE n t <- unaffectedvars]) :: Either String (Subst TExp)
                case merged of
                    -- an inconsistent substitution is treated as "no match"
                    -- instead of failing on an irrefutable pattern
                    Left err -> do
                        logDE (indent 2 $
                                text "(-) Could not merge substitutions:" <+> text err)
                        return Nothing
                    Right s' -> do
                        let lhss' = lhs.crul $ cll
                        -- new rhss are the substituted lhss of tgt
                        let rhss' = map (apply s') (lhs.crul $ tgt)
                        -- return the calling relation and a list of rules,
                        -- i.e. one IO pair for each argument of the called
                        -- function
                        return $ Just (callrel, map (rule lhss') rhss')
  where
    callrel = on comp (lhs . crul) tgt cll
