{-# OPTIONS_GHC #-}
module Igor2.Data.Rules (
    
    Rule, rule, rhs, lhs, ithArg, butIthArg, butIthArgs,
    OpenPos,
    RulePos(..), ruleVarPos, sameSymAt,
    hasHO, freeVars, hasFreeVars, openPositions, lggRules, matchLhss,
    hasVarAt, hasCtorAt, matchesLhs, matchesRhs, matchEval, matchEvals,
    Rules, mkRules, rules, subrule, mkCallAt,
    mkCall, buildCall, doesCallTo, replaceInAll,
    isOpen, isClosed,

    insertM, deleteM, replaceM,
    LHS, RHS, hypos2decs,

    mapRuleRhsM, mapRuleLhsM, mapRuleM,

    module Syntax.Specification,
    -- hypos2decs, rules2decs,
    -- tPat2TExp,
    
--    module Syntax.Terms,    
--    module Syntax.Unifier       
    
    )where

import Prelude
 
import Data.List ( foldl', transpose, (\\), nub, sortBy, sort, partition
                 , isPrefixOf, deleteBy)
import qualified Data.List as L

import Data.Set (Set)
import Data.MySet (deleteM, insertM, replaceM)
import qualified Data.Set as S ( empty, fromList, toList, toAscList, map, null
                               , insert, delete, size)

import Data.Maybe
import Data.Function (on)
 
import Data.Util 

import Syntax hiding (sameSymAt, mapTermM, mapTerm)
import qualified Syntax as T (sameSymAt, mapTermM, mapTerm)
import Control.Monad.Error

import qualified Language.Haskell.TH as TH

import qualified Control.Monad.State as CMState
import qualified Data.Set

import Igor2.Logging
import Syntax.Specification (Equation(..), FunBind(..), mkEq, mkFB, fName)
import Igor2.Ppr

--------------------------------------------------------------------------------
-- Datatype Rule
--------------------------------------------------------------------------------


-- | Left-hand side of a rule: its argument patterns, one term per argument.
type LHS = [TExp]
-- | Right-hand side (body) of a rule.
type RHS = TExp

-- | Converts a specification 'Equation' into a 'Rule'.
--   Every variable named @undefined@ on the right-hand side (as visited by
--   'T.mapTerm') becomes a wildcard 'TWildE' with the same name and type.
--   Only 'UnGuardEq' equations are handled; any other 'Equation' constructor
--   (if one exists) is a pattern-match failure.
mkRule :: Equation -> Rule
mkRule (UnGuardEq args body) = rule args (T.mapTerm undefinedToWildcard body)
  where
    undefinedToWildcard term = case term of
        TVarE name ty
            | TH.nameBase name == "undefined" -> TWildE name ty
        _ -> term

-- | Converts a list of equations into a (duplicate-free) set of rules,
--   using 'mkRule' on each equation.
mkRules :: [Equation] -> Rules
mkRules eqs = rules [ mkRule eq | eq <- eqs ]

-- | The @i@-th (zero-based) argument pattern of a rule.
--   Like '!!', this calls 'error' when @i@ is out of range.
ithArg :: Int -> Rule -> TExp
ithArg i = (!! i) . lhs

-- | All argument patterns of a rule except the @i@-th (zero-based) one,
--   in their original order.
butIthArg :: Int -> Rule -> [TExp]
butIthArg i = butIthArgs [i]

-- | All argument patterns of a rule except those at the given zero-based
--   indices, in their original order.
--
--   Arguments are dropped by /position/. Removing them by value (list
--   difference) would delete the first equal term instead, which drops the
--   wrong argument and reorders the result whenever the same term occurs
--   more than once among the arguments. Out-of-range indices are ignored.
butIthArgs :: [Int] -> Rule -> [TExp]
butIthArgs idxs = map snd . filter ((`notElem` idxs) . fst) . zip [0 ..] . lhs

-- | A rule: argument patterns together with a right-hand side.
--   The data constructor 'R' is not exported; build rules with 'rule'.
data Rule = R
    { lhs :: LHS  -- ^ argument patterns, one per function argument
    , rhs :: RHS  -- ^ right-hand side
    }
    deriving (Eq, Ord)

-- | Pretty-prints a rule by turning it back into an equation with 'mkEq'.
--   NOTE: TH.PprLib is needed to get the correct variable names on lhs and rhs.
instance Pretty Rule where
    pretty = pretty . asEquation
      where
        asEquation r = mkEq (lhs r) (rhs r)

-- | Shows a rule as @<shown argument list> = <shown rhs>@.
instance Show Rule where
    show r = concat [show (lhs r), " = ", show (rhs r)]
        
-- | The type of a rule is the function type built with 'arrowT' from the
--   types of its arguments followed by the type of its right-hand side.
instance Typed Rule where
    typeOf r = arrowT (map typeOf (lhs r ++ [rhs r]))
        
-- | Builds a 'Rule' from its argument patterns and right-hand side.
--   Currently a plain wrapper around the hidden constructor; no checks are made.
rule :: LHS -> TExp -> Rule
rule args body = R args body

-- | Monadically maps a term transformation over the right-hand side of a
--   rule (via 'T.mapTermM'); the arguments are left untouched.
mapRuleRhsM :: Monad m => (TExp -> m TExp) -> Rule -> m Rule
mapRuleRhsM f r@(R _ body) = do
    newBody <- T.mapTermM f body
    return r { rhs = newBody }

-- | Monadically maps a term transformation over every argument pattern of a
--   rule (via 'T.mapTermM'), left to right; the rhs is left untouched.
mapRuleLhsM :: Monad m => (TExp -> m TExp) -> Rule -> m Rule
mapRuleLhsM f r@(R args _) = do
    newArgs <- mapM (T.mapTermM f) args
    return r { lhs = newArgs }

-- | Monadically maps a term transformation over the whole rule: first the
--   argument patterns, then the right-hand side.
mapRuleM :: Monad m => (TExp -> m TExp) -> Rule -> m Rule
mapRuleM f r = do
    withNewArgs <- mapRuleLhsM f r
    mapRuleRhsM f withNewArgs

-- TODO: this type may be unnecessary.
-- | A position inside a 'Rule': either inside one of its argument patterns
--   or inside its right-hand side.
data RulePos
    = Arg Int Position  -- ^ position inside the argument with this zero-based index
    | Body Position     -- ^ position inside the right-hand side
    deriving (Eq, Ord, Show)

-- | Renders @Arg i p@ as the text "Arg" immediately followed by the index,
--   then the position; @Body p@ as "Body" followed by the position.
instance Pretty RulePos where
    pretty rp = case rp of
        Arg i pos -> text "Arg" <> int i <+> pretty pos
        Body pos  -> text "Body" <+> pretty pos

-- | An open position is a pair of an expression (which should be a variable)
--   and a single 'Position' in a term at which that expression occurs open,
--   i.e. unbound by the rule's arguments ('openPositions' pairs each free
--   variable with its positions in the right-hand side).
--   This only makes sense in combination with a 'Rule'.
type OpenPos = (TExp,Position)

-- | Returns the free variables of a rule (as expressions): the variables that
--   occur on the right-hand side but in none of the argument patterns.
--   Each variable is listed once, in the order of its first occurrence in the
--   result of 'getVars' on the right-hand side.
freeVars :: Rule -> [TExp]
freeVars r = filter (`notElem` boundVars) (nub (getVars (rhs r)))
  where
    boundVars = concatMap getVars (lhs r)

-- | Whether the right-hand side mentions a variable not bound by the arguments.
hasFreeVars :: Rule -> Bool
hasFreeVars r = not (null (freeVars r))

-- | All open positions of a rule: for every free variable (in the order of
--   'freeVars'), each of its positions in the right-hand side as returned by
--   'getPos', paired with the variable itself.
openPositions :: Rule -> [OpenPos]
openPositions r =
    [ (var, pos) | var <- freeVars r, pos <- getPos (rhs r) var ]

 -- | A rule is open if it is not closed
isOpen :: Rule -> Bool
isOpen = not . isClosed

-- | A rule is closed if 'openPositions' finds no open position in it.
isClosed :: Rule -> Bool
isClosed r = null (openPositions r)

-- | Computes all positions in a rule at which there is a variable.
--   The result holds one list per argument (first argument first), followed
--   by one final list for the right-hand side.
ruleVarPos :: Rule -> [[RulePos]]
ruleVarPos r =
    zipWith argVarPositions [0 ..] (lhs r) ++ [map Body (varPositions (rhs r))]
  where
    argVarPositions i arg = map (Arg i) (varPositions arg)
    -- getVarPos yields pairs whose second component lists positions
    varPositions term = concatMap snd (getVarPos term)

-- | Builds a smaller rule from a subterm at the given position:
--
--   * @Body p@: keeps the arguments and uses the rhs subterm at @p@ as new rhs.
--   * @Arg i p@: uses the subterm at @p@ of argument @i@ as the single argument
--     and keeps the original rhs (errors, via '!!', if @i@ is out of range).
--
--   Returns 'Nothing' when 'subtermAt' finds no subterm at @p@.
subrule :: RulePos -> Rule -> Maybe Rule
subrule pos (R args body) = case pos of
    Body p  -> fmap (R args) (subtermAt body p)
    Arg i p -> fmap (\sub -> R [sub] body) (subtermAt (args !! i) p)

-- DEAD CODE
--replaceInLhs :: Rule -> RulePos -> TExp -> Rule
--replaceInLhs r (Arg i pos) t = 
--    let (pb,(p:pa)) = splitAt i (lhs r)
--    in r{lhs= concat [pb, [(substitute t pos p)],pa]}
--

-- | Replaces the subterm at a right-hand-side position of the rule with the
--   given term (via 'substitute'); the arguments are left untouched.
--
--   Only 'Body' positions are meaningful here. An 'Arg' position used to be a
--   bare non-exhaustive-pattern failure; it now fails with a descriptive error.
replaceInRhs :: Rule -> RulePos -> TExp -> Rule
replaceInRhs r (Body pos) t = r { rhs = substitute t pos (rhs r) }
replaceInRhs _ (Arg _ _) _ =
    error $ "Igor2.Data.Rules.replaceInRhs : The given position is "
         ++ "not a position on the right-hand side of the rule"

-- | Replaces the subterm at a right-hand-side position of a rule with a call
--   of @n@ on @args@ (built with 'mkCall'). The call's result type is the type
--   of the subterm being replaced.
--
--   Only 'Body' positions are supported; an 'Arg' position fails with a
--   descriptive error instead of a non-exhaustive-pattern failure.
--   The result type is still computed lazily (as before). If the position
--   has no subterm, a descriptive error replaces the former bare 'fromJust'
--   failure, raised only when that type is forced.
mkCallAt :: RulePos -> Name -> [TExp] -> Rule -> Rule
mkCallAt (Body pos) n args r =
    replaceInRhs r (Body pos) (mkCall n resultType args)
  where
    resultType = maybe noSubterm typeOf (subtermAt (rhs r) pos)
    noSubterm  = error $ "Igor2.Data.Rules.mkCallAt : There is no subterm at "
                      ++ "the given position of the right-hand side"
mkCallAt (Arg _ _) _ _ _ =
    error $ "Igor2.Data.Rules.mkCallAt : The given position is "
         ++ "not a position on the right-hand side of the rule"

-- DEAD CODE            
--callLevel :: Rule -> [Name] -> Int
--callLevel = (flip countCalls) . rhs

-- | Whether the right-hand side of the rule is a higher-order application,
--   as decided by 'isHOApp'.
hasHO :: Rule -> Bool
hasHO r = isHOApp (rhs r)
-- DEAD CODE        
--ruleSubtermAt :: RulePos -> Rule -> Either (Maybe TExp) (Maybe TExp)
--ruleSubtermAt p@(Arg _ _) r = Left $ lhsSubtermAt p r
--ruleSubtermAt p@(Body _) r  = Right $ rhsSubtermAt p r
--
-- DEAD CODE
--lhsSubtermAt :: RulePos -> Rule -> Maybe TExp
--lhsSubtermAt (Arg i p) r = subtermAt ((lhs r) !! i) p
--lhsSubtermAt _ _ = error $ "Data.Rules.lhsSubtermAt : The given position is" ++ 
--                           "not a position on the left-hand side of the rule"  
--
-- DEAD CODE
--rhsSubtermAt :: RulePos -> Rule -> Maybe TExp
--rhsSubtermAt (Body p) r  = subtermAt (rhs r) p
--rhsSubtermAt _ _ = error $ "Data.Rules.rhsSubtermAt : The given position is" ++ 
--                           "not a position on the right-hand side of the rule"

-- | Whether the rule has a variable at the given position (via 'varAtPos').
--   For @Arg i _@, '!!' errors if @i@ is out of range.
hasVarAt :: Rule -> RulePos -> Bool
hasVarAt (R args body) pos = case pos of
    Arg i p -> varAtPos (args !! i) p
    Body p  -> varAtPos body p

-- | The negation of 'hasVarAt': 'True' whenever the symbol at the position is
--   not a variable (NOTE(review): that may include non-constructor symbols).
hasCtorAt :: Rule -> RulePos -> Bool
hasCtorAt r = not . hasVarAt r
   
-- | Returns 'True' if both 'Rule's have the same symbol at the specified
--   position, as decided by 'T.sameSymAt' on the corresponding argument
--   (for 'Arg') or right-hand side (for 'Body').
sameSymAt :: RulePos -> Rule -> Rule -> Bool
sameSymAt pos r1 r2 = case pos of
    Arg i p -> T.sameSymAt p (lhs r1 !! i) (lhs r2 !! i)
    Body p  -> T.sameSymAt p (rhs r1) (rhs r2)

-- | Anti-unifies a list of rules into their least general generalisation.
--   The rules are handled column-wise via 'lggL': first one column per
--   argument position, then the column of right-hand sides as the last one.
lggRules :: (Monad m, MonadError e m, Error e) => [Rule] -> C m Rule
lggRules rs = do
    -- ATTENTION:
    -- The rhss may contain existentially quantified variables (EQVs)!
    -- According to Igor's matching policy, they may be replaced by a term
    -- occurring in another AU-image which is identical except the EQVs.
    -- Example:
    -- let cx be a constant, vx a variable, and ex an EQV.
    -- If the image [c1, v1, c2] is replaced by v2, so is [c1,v2,e1],
    -- because the context of e1 is the same as of c2 modulo variable
    -- renaming.
    -- Since EQVs occur only on rhss, we need to antiunify the lhss first.
    generalised <- lggL (argColumns ++ [rhsColumn])
    return $ rule (init generalised) (last generalised)
  where
    -- one list per argument position, holding that argument of every rule
    argColumns = transpose (map lhs rs)
    rhsColumn  = map rhs rs

-- DEAD CODE
--matchRules :: Rule -> Rule -> LM Bool
--matchRules (R l1 r1) (R l2 r2) = matches2 (l1,r1) (l2,r2) id
     
-- DEAD CODE
--isMoreSpecific :: Rule -> Rule -> Bool
--isMoreSpecific a b = (&&) (matchLhss a b) (matchRhs a b)
--
-- DEAD CODE
--isMoreGeneral :: Rule -> Rule -> Bool
--isMoreGeneral = flip isMoreSpecific

-- | Matches the given arguments against the rule's argument patterns with
--   'matchL' and applies the resulting substitution to the rule's rhs.
--   A failed match surfaces as whatever 'matchL' signals in @C m@.
matchEval :: (Error e, MonadError e m) => LHS -> Rule -> C m TExp
matchEval args r = do
    subst <- matchL args (lhs r)
    return (apply subst (rhs r))

-- | Tries 'matchEval' with every rule in list order and combines the attempts
--   with 'msum'. The outcome therefore follows the 'MonadPlus' choice of
--   @C m@ (presumably the first rule that matches — confirm).
matchEvals :: (MonadPlus m, Error e, MonadError e m) => LHS -> [Rule] -> C m TExp
matchEvals l rs = msum [ matchEval l r | r <- rs ]

-- | Matches the argument patterns of two rules with 'matchL' and returns the
--   resulting substitution.
matchesLhs :: (Error e, MonadError e m) => Rule -> Rule -> C m (Subst TExp)
matchesLhs r1 r2 = matchL (lhs r1) (lhs r2)

-- | Matches the right-hand sides of two rules with 'match' and returns the
--   resulting substitution.
matchesRhs :: (Error e, MonadError e m) => Rule -> Rule -> C m (Subst TExp)
matchesRhs r1 r2 = match (rhs r1) (rhs r2)

-- | Whether the argument patterns of two rules match, as decided by 'matchesL'.
matchLhss :: Monad m => Rule -> Rule -> C m Bool
matchLhss r1 r2 = matchesL (lhs r1) (lhs r2)

-- | Whether the right-hand sides of two rules match, as decided by 'matches'.
matchRhs :: Monad m => Rule -> Rule -> C m Bool
matchRhs r1 r2 = matches (rhs r1) (rhs r2)

-------------
-- Auxiliaries
-------------

-- | Builds a call of @n@ on the given argument terms: a 'TConE' head typed
--   as the function from the argument types to @resty@, applied to the
--   arguments with 'foldTAppE'.
mkCall :: Name -> Type -> [TExp] -> TExp
mkCall n resty ps = foldTAppE (TConE n funTy) ps
  where
    funTy = arrowT (map typeOf ps ++ [resty])
    -- TODO Should be a TVarE, but then matchEval and rule simplification
    -- may not work properly

-- | For a named rule, returns the pair (call, body): the call is @n@ applied
--   to the rule's argument patterns (typed with the rhs type, via 'mkCall'),
--   and the body is the rule's right-hand side.
--   (Added the missing top-level signature; it equals the inferred type.)
buildCall :: (Name, Rule) -> (TExp, TExp)
buildCall (n, cr) = (mkCall n (typeOf (rhs cr)) (lhs cr), rhs cr)

                           
-- | Whether the name occurs (as a 'TVarE' or 'TConE' symbol) anywhere in the
--   right-hand side of the rule.
doesCallTo :: Rule -> Name -> Bool
doesCallTo r n = mentions (rhs r)
  where
    mentions term = case term of
        TVarE name _ -> name == n
        TConE name _ -> name == n
        TLitE _ _    -> False
        TWildE _ _   -> False
        TAppE f x    -> mentions f || mentions x
        -- TCondE i t e _ -> any mentions [i, t, e]
 
-- DEAD CODE  
--countCalls :: [Name] -> TExp -> Int
--countCalls ns e = cc (0::Int) ns e
--    where
--    cc :: Int -> [Name] -> TExp -> Int
--    cc c ns (TVarE n _)       = if isCall n ns then c + 1 else c
--    cc c ns (TLitE _ _)       = c
--    cc c ns (TWildE _ _)      = c
--    cc c ns (TConE n _)       = c
--    cc c ns (TAppE a1 a2)     = (c+) $ on (+) (cc 0 ns) a1 a2
----    cc c ns (TCondE i t e _)  = (c+).sum $ map (cc 0 ns) [i,t,e]
--
-- DEAD CODE  
--isCall :: Name -> [Name] -> Bool
--isCall n ns 
--    | n `elem` ns                 = True
--    | "fun" `isPrefixOf` (show n) = True
--    | otherwise                   = False
    
-- | Converts hypotheses (lists of named rule sets) into lists of function
--   bindings. Within each hypothesis, functions whose shown name does not
--   start with "fun" come first, followed by those that do; each group is
--   sorted.
hypos2decs :: [[(Name,Rules)]] -> [[FunBind]]
hypos2decs = map (map toFunBind . targetsFirst)
  where
    targetsFirst named =
        let (targets, helpers) = partition isTarget named
        in  sort targets ++ sort helpers
    isTarget (n, _) = not ("fun" `isPrefixOf` show n)
    toFunBind (n, rs) = mkFB n [ mkEq (lhs r) (rhs r) | r <- S.toList rs ]
    
--simplify :: (MonadError e m) => [Name] -> [(Name,Rules)] -> C m [(Name,Rules)]
--simplify blckl l = 
--    case partition isRec l of
--        (r,[])     -> return r
--        (rec,nrec) -> let r = (getOneWhich hasNoCalls nrec)
--                               `mplus`
--                              (getOneWhich hasOneCall nrec)
----                               `mplus`
----                              (getOneWhich hasTwoCall nrec)
--                      in maybe (return $ rec ++ nrec) (simplify_ rec nrec) r
--                              
--    where
--    simplify_ rec nrec r = liftM (simplify blckl) $ uncurry replaceInAll (buildCall r) (rec ++ (deleteBy ((==) `on` fst) r nrec))
--    hasNoCalls = ((==0).(countCalls blckl).rhs.head.S.toList.snd)
--    hasOneCall = ((==1).(countCalls blckl).rhs.head.S.toList.snd)
--    hasTwoCall = ((==2).(countCalls blckl).rhs.head.S.toList.snd)
--    buildCall (n,r) = let cr = head . S.toList $ r
--                      in (mkCall n (typeOf.rhs $ cr) (lhs cr), rhs cr)
--    isRec (n,r)     = (elem n blckl) || ((>1).S.size $ r)

--getOneWhich :: (a -> Bool) -> [a] -> Maybe a    
--getOneWhich _ []    = Nothing
--getOneWhich f (x:xs)
--    | f x        = Just x
--    | otherwise  = getOneWhich f xs

-- | Builds the rule @cc = ct@ and rewrites the right-hand side of every rule
--   of every named rule set with it via 'replaceCall'; names are kept and
--   each set is rebuilt with 'rules'.
replaceInAll :: (Monad m) => TExp -> TExp -> [(Name, Rules)] -> C m [(Name, Rules)]
replaceInAll cc ct = mapM rewriteSet
  where
    rewriteSet (n, rs) = do
        rewritten <- mapM repCall (S.toList rs)
        return (n, rules rewritten)
    repCall r = do
        newRhs <- replaceCall (rule [cc] ct) (rhs r)
        return r { rhs = newRhs }

-- | Rewrites a right-hand side with the given rule, trying the whole term
--   before its parts.
--
--   Variables, literals, wildcards and constructor symbols are returned
--   unchanged. Any other term is first evaluated as a whole with 'matchEval'
--   (i.e. matched as the single argument of the rule). If that fails, the
--   error is caught by 'safeCatchErrorC' and the rewrite is tried recursively
--   on 'subterms' of the term, reassembled with 'root'.
--   NOTE(review): presumably 'subterms'/'root' split a term into its immediate
--   children and rebuild it; a successful 'matchEval' result is not rewritten
--   further.
replaceCall :: (Monad m) => Rule -> RHS -> C m RHS
replaceCall _ t@(TVarE _ _)       = return t
replaceCall _ t@(TLitE _ _)       = return t
replaceCall _ t@(TWildE _ _)      = return t
replaceCall _ t@(TConE n _)       = return t
replaceCall r t = matchEval [t] r `safeCatchErrorC` \_ -> repSubterms t
    where
    -- fallback: rewrite the subterms and rebuild the node around them
    repSubterms :: Monad m => RHS -> C m RHS
    repSubterms t = liftM (root t) $ mapM (replaceCall r) (subterms t)


--------------------------------------------------------------------------------
-- Datatype Rules
--------------------------------------------------------------------------------    

-- | 'Rules' are an ordered collection of 'Rule's with no duplicates
--   (a 'Set'; see 'ruleAtIndex' for positional access).
type Rules = Set Rule

-- | Builds a 'Rules' collection from a list; duplicate rules are dropped.
rules :: [Rule] -> Rules
rules rs = S.fromList rs

-- | The rule at the given zero-based index in ascending order.
--   UNSAFE: calls 'error' (via '!!') when the index is out of range.
ruleAtIndex :: Int -> Rules -> Rule
ruleAtIndex i = (!! i) . S.toAscList
