{-# OPTIONS_GHC -XFlexibleContexts -XTemplateHaskell #-}
module Syntax.UnifyTy (

    Unify(..), AUnify(..),
    checkMatch, specialise,

) where

import Control.Monad (foldM)
import Data.Foldable (foldlM)
import Control.Monad.Error
import Control.Monad.State
import Data.Maybe (maybeToList)
import Data.List (transpose, intersect)
import Data.Function (on)

import Syntax.Context
import Syntax.Type
import Syntax.Name
import Syntax.Class hiding (merge)
import Syntax.Class.Subst as Subst (merge)
import Syntax.Ppr

import Debug.Trace
--import Igor2.Logging

-- | Type substitutions: 'Subst' instantiated at 'Type'.
type TySubst = Subst Type


-- | Succeed with @()@ iff the type of @s@ matches the type of @t@ (both
--   obtained with 'typeOf').  The substitution computed by 'match' is
--   discarded; any failure raised by 'match' propagates.
checkMatch s t = do
    _ <- (match `on` typeOf) s t
    return ()

-- | @specialise t1 t2@ matches @t1@ against the last component of
--   @unArrowT t2@ (presumably the result type of the arrow type @t2@ --
--   confirm) and applies the resulting substitution to the whole of @t2@.
--   NOTE(review): 'last' is partial; this assumes 'unArrowT' never yields @[]@.
specialise :: (Error e, MonadError e m) => Type -> Type -> C m Type
specialise t1 t2 = do
    s <- match t1 (last (unArrowT t2))
    return (apply s t2)

--------------------------------------------------------------------------------
-- Matching Types

-- | @match t1 t2@ returns the substitution @s@ such that @apply s t2 == t1@,
--   i.e. @t2@ is the (more general) pattern that is instantiated to @t1@.
--   May fail inside 'C' if @t2@ is /not/ more general than @t1@.
--
--   Clause order matters: a (possibly constrained) variable in the pattern
--   is handled before any structural decomposition is attempted.
instance Unify Type where

    -- A bare variable in the pattern yields @n <~ t@ (presumably the
    -- singleton substitution binding n to t).
    match t (VarT n)               = return $ n <~ t
    -- A constrained variable yields the same binding, but only after 'check'
    -- confirms that the binding is satisfiable in its context c.
    match t (ForallT c (VarT n))   = check (n <~ t) c >> return (n <~ t)
    -- A quantified application (in the pattern or in the target) is matched
    -- by pushing the context into both components with 'quantify' and
    -- retrying on the resulting plain application.
    match t (ForallT c (AppT l r)) = match t (AppT (quantify l c) (quantify r c))
    match (ForallT c (AppT l r)) t = match (AppT (quantify l c) (quantify r c)) t
    -- Applications are matched component-wise and the two substitutions are
    -- combined with 'Subst.merge' (presumably failing on conflicting
    -- bindings -- confirm).
    match (AppT l r) (AppT l' r')  = do sl <- match l l'
                                        sr <- match r r'
                                        Subst.merge sl sr
    -- Equal type constructors match with the empty substitution.
    match (ConT n) (ConT n')
        | n == n'                  = return nullSubst
    -- Every other combination is a mismatch.
    match t1 t2                    = throwError . strMsg $ "Type " ++ show t1 ++
                                              " does not match " ++
                                              show t2 ++ "."

-- | Two predicates match when they name the same class and the type of the
--   first matches the type of the second (via 'matches').
matchesPred :: Monad m => Pred -> Pred -> C m Bool
matchesPred (Pred n1 t1) (Pred n2 t2)
    | n1 == n2  = matches t1 t2
    | otherwise = return False
            
{- | Check that the substitution @s@ is satisfiable in the type context @c@.
     The context is first reduced to head-normal forms ('reduce').  Then, for
     every binding @(n, t)@ of @s@, each predicate of the reduced context that
     'allPreds' reports for @n@ must be satisfied by @t@ ('satisfies').
     Throws an error in the monad when some predicate is not satisfied.
-}
check :: (Error e, MonadError e m) => TySubst -> TyCxt -> C m ()
check s c = do
    reduced <- reduce c
    ok      <- allM (bindingHolds reduced) (assocs s)
    if ok
        then return ()
        else throwError . strMsg $ "Substitution " ++ show s ++
                                   " not satisfieable in " ++ show c
  where
    -- A binding n |-> t holds when t still satisfies every predicate that
    -- the context places on the variable n.
    bindingHolds cxt (n, t) = allM (satisfies t) (allPreds [n] cxt)
    
{- | Decide whether a type satisfies a predicate that is in head-normal form:

      * a type variable ('VarT') satisfies every predicate;
      * a quantified variable @ForallT c (VarT _)@ satisfies a predicate iff
        its context @c@ entails it ('entails');
      * any other type satisfies a predicate iff it matches at least one
        instance of the predicate's class ('getInstances', 'matches').
-}
satisfies :: (Error e, MonadError e m) => Type -> Pred -> C m Bool
satisfies ty p = case ty of
    VarT _             -> return True
    ForallT c (VarT _) -> c `entails` p
    _                  ->
        getInstances (predClass p) >>= anyM (\inst -> matches inst ty)


--
-- Context Reduction
-- Bring every predicate into head-normal form ('toHnfs'), then drop the
-- predicates that are entailed by the others ('simplify'), e.g. '(Eq a, Eq a)'
-- becomes 'Eq a', or '(Eq a, Ord a)' becomes 'Ord a' (given that 'Eq' is a
-- superclass of 'Ord' in the class environment).
--
reduce :: (Error e, MonadError e m) => TyCxt -> C m TyCxt
reduce ps = do
    hnfs <- toHnfs ps
    simplify hnfs
    
-- | Remove redundant predicates.  A predicate is dropped when it is entailed
--   by the remaining ones, meaning those already kept plus those not yet
--   visited.  The surviving predicates come back in reverse input order.
simplify :: Monad m => TyCxt -> C m TyCxt
simplify = go []
  where
    go kept []         = return kept
    go kept (q : rest) = do
        redundant <- entails (kept ++ rest) q
        go (if redundant then kept else q : kept) rest

-- | Convert every predicate to head-normal form ('toHnf') and concatenate
--   the resulting predicate lists in input order.
toHnfs :: (Error e, MonadError e m) => TyCxt -> C m TyCxt
toHnfs = liftM concat . mapM toHnf

-- | Rewrite a single predicate into a list of predicates in head-normal form.
--   A predicate already in head-normal form is returned unchanged.  Otherwise
--   the premises of an applicable instance ('byInst') are reduced recursively.
--   Throws @"context reduction"@ when no instance applies.
toHnf  :: (Error e, MonadError e m) => Pred -> C m TyCxt
toHnf p
    | inHnf p   = return [p]
    | otherwise = do
        premises <- byInst p
        case premises of
            Just ps -> toHnfs ps
            Nothing -> throwError (strMsg "context reduction")
                            
-- | A predicate is in head-normal form when its member type is headed by a
--   type variable, i.e. it is a 'VarT', possibly applied ('AppT') to
--   arguments.
inHnf :: Pred -> Bool
inHnf = headIsVar . predMember
  where
    headIsVar ty = case ty of
        VarT _   -> True
        AppT f _ -> headIsVar f
        _        -> False
    
-- 
-- Entailment 
--  
                           
-- | @entails ctx p@ decides whether the predicates in @ctx@ entail @p@:
--
--   * by superclasses: @p@ matches ('matchesPred') some predicate derivable
--     from @ctx@ through 'bySuper'; or
--   * by instances: some instance applies to @p@ ('byInst') and all of its
--     premises are, recursively, entailed by @ctx@.
--
--   The instance-based search is attempted only when the superclass check
--   fails.  Its result cannot turn a 'True' into anything else, so running it
--   unconditionally (as before) was wasted, recursive work.
entails :: Monad m => TyCxt -> Pred -> C m Bool
entails ctx p = do
     bySup <- mapM bySuper ctx >>= anyM (`matchesPred` p) . concat
     if bySup
        then return True
        else byInst p >>= maybe (return False) (allM (entails ctx))

-- | All predicates derivable from @p@ through the superclass relation.  The
--   result starts with @p@ itself and is followed, depth-first, by the closure
--   of each direct superclass reported by 'getSuperClasses', all on the same
--   type.
--   NOTE(review): there is no cycle detection, so a cyclic superclass relation
--   would make this loop.
bySuper :: (MonadReader Context m) => Pred -> m [Pred]
bySuper p@(Pred n t) = do
    supers  <- getSuperClasses n
    derived <- mapM (\sc -> bySuper (Pred sc t)) supers
    return (p : concat derived)

-- | Look for an instance of @p@'s class whose head matches @p@'s type.
--   Every instance from 'getInstances' is tried with 'tryInst'.  The result
--   is the instantiated premise list of the first instance that applies
--   (picked by 'msum'), or 'Nothing' when none applies.
byInst :: Monad m => Pred -> C m (Maybe [Pred])
byInst (Pred n t) = liftM msum (getInstances n >>= mapM (tryInst t))

-- | Try to apply one instance to the type @t@.
--
--   * For an instance @ForallT c t'@ the type is matched against @t'@.  On
--     success, the instance context @c@ (with the matching substitution
--     applied) is returned as the premises that must still hold.
--   * For any other instance the type is matched against it directly.  On
--     success, an empty premise list is returned.
--
--   Any error raised by 'match' is caught and turned into 'Nothing'.
tryInst :: Monad m => Type -> Type -> C m (Maybe [Pred])
tryInst t (ForallT c t') =
    safeCatchErrorC (liftM (Just . instantiate) (match t t'))
                    (\_ -> return Nothing)
  where
    instantiate u = map (\(Pred n ty) -> Pred n (apply u ty)) c
tryInst t inst =
    safeCatchErrorC (liftM (const (Just [])) (match t inst))
                    (\_ -> return Nothing)

--------------------------------------------------------------------------------
-- Antiunifying Types


-- | Anti-unification of a non-empty list of types: build a common
--   generalisation of all of them.
instance AUnify Type where
    -- An empty list has no generalisation.
    aunify [] = error "aunify: empty list"
    -- t' is t with 'propCxt' applied to every element (presumably this
    -- distributes quantifier contexts over subterms -- confirm).
    aunify t = let t' = map propCxt t in
        -- All types share the same root ('sameRoots'):
        if sameRoots t'
        then let l = (map subterms t') in
             -- if any of them has no subterms, the first element of t' is
             -- returned as is;
             if [] `elem` l then return . head $ t'
             -- otherwise the subterms are anti-unified position by position
             -- ('transpose') and the shared root is rebuilt around them.
             -- NOTE(review): the root is rebuilt from the original list
             -- ('roots t') while roots were compared on t'; confirm that
             -- 'propCxt' never changes the root, or use 'roots t''.
             else mapM aunify (transpose l) >>= return . (roots t)
        -- Roots differ: generalise by a variable.  'bindVar' reuses the
        -- variable already recorded for an identical image.
        else bindVar t
    
-- | Return the variable already recorded for the image @img@ in the current
--   map ('getMap'), or create and record a fresh one ('mkVar').
bindVar :: (Error e, MonadError e m) => [Type] -> AU m Type Type
bindVar img = do
    known <- getMap
    case lookup img known of
        Just v  -> return v
        Nothing -> mkVar img
     
-- | Create a fresh type variable for the image @vimg@.  Its name is @'a'@
--   followed by the counter from 'getCnt', and it is quantified over the
--   context computed by 'cmpCxt'.  It is recorded with 'putImg' before being
--   returned.
mkVar :: (Error e, MonadError e m) => [Type] -> AU m Type Type
mkVar vimg = do
    cnt <- getCnt
    let vnm = 'a' : show cnt
    cxt <- lift (cmpCxt vnm vimg)
    let var = quantify (varT vnm) cxt
    putImg var vimg
    return var
      
-- | Compute the context for the fresh variable @vnm@ that generalises the
--   types in @vimg@.  For each type, collect the predicates @vnm@ would have
--   to carry ('collectPreds'), keep those common to all types ('intersect'),
--   and reduce the result.
--   NOTE(review): 'foldl1' is partial, so @vimg@ must be non-empty.
cmpCxt :: (Error e, MonadError e m) => String -> [Type] -> C m TyCxt
cmpCxt vnm vimg = do
    insts   <- allInstances
    perType <- mapM (collectPreds vnm insts) vimg
    let common = foldl1 intersect perType
    reduce common
    
-- | Predicates that a fresh variable named @vnm@ must carry in order to stand
--   for type @t@:
--
--   * a quantified variable contributes, for every predicate in its context,
--     that class and its superclasses ('bySuper'), rebuilt with 'mkPred' on
--     @vnm@;
--   * a bare type variable contributes nothing;
--   * any other type contributes one 'mkPred' on @vnm@ for every entry of
--     @classInsts@ that has at least one type matching @t@ (entries are
--     presumably class names with their instance types -- confirm).
collectPreds :: Monad m => String -> [(Name, [Type])] -> Type -> C m [Pred]
collectPreds vnm classInsts t = case t of
    ForallT c (VarT _) ->
        liftM concat (mapM (\p -> bySuper (mkPred (predClass p) vnm)) c)
    VarT _ ->
        return []
    _ ->
        liftM (map (\entry -> mkPred (fst entry) vnm))
              (filterM (\entry -> anyM (`matches` t) (snd entry)) classInsts)
 
--------------------------------------------------------------------------------
--  Auxiliaries
                        
-- | Monadic 'all'.  @allM p xs@ runs @p@ on the elements of @xs@ from left to
--   right and yields 'True' iff every result is 'True' ('True' for @[]@).
--   Evaluation stops at the first 'False', so the effects of @p@ on the
--   remaining elements are not performed.  The previous 'foldlM' version
--   always ran @p@ on every element.
allM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
allM _ []       = return True
allM p (x : xs) = do
    ok <- p x
    if ok then allM p xs else return False

-- | Monadic 'any'.  @anyM p xs@ runs @p@ on the elements of @xs@ from left to
--   right and yields 'True' iff some result is 'True' ('False' for @[]@).
--   Evaluation stops at the first 'True', so the effects of @p@ on the
--   remaining elements are not performed.  The previous 'foldlM' version
--   always ran @p@ on every element.
anyM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
anyM _ []       = return False
anyM p (x : xs) = do
    found <- p x
    if found then return True else anyM p xs


