-----------------------------------------------------------------------------
-- Type:        Types
-- 
-- Part of `Typing Haskell in Haskell', version of November 23, 2000
-- Copyright (c) Mark P Jones and the Oregon Graduate Institute
-- of Science and Technology, 1999-2000
-- 
-- This program is distributed as Free Software under the terms
-- in the file "License" that is included in the distribution
-- of this software, copies of which may be obtained from:
--             http://web.cecs.pdx.edu/~mpj/thih/
--
-----------------------------------------------------------------------------

-- |
{-# OPTIONS_GHC -XTemplateHaskell #-}
module Syntax.Type (
    -- * Data type 'Type'  
    Type (..),
     
    -- ** Predicates
    Pred(Pred), predClass, predMember,
    -- ** Useful type synonyms
    TyCxt, QualTy,
    -- ** Constructors
    appT, arrowCon, arrowT, varT, conT, forallT, listCon, listT, tupleCon, tupT, infixT, sectionType,
    foldAppT, mkPred,
    -- ** Deconstructors
    tyArgs, tyCtor, unArrowT,  unfoldAppTargs,
    -- ** Modifiers
    quantify, propCxt, fixType,
    -- ** Inspectors
    allPreds,isFunT, isHOT, isListT, dataName, 
    
    sameTy,
    -- * Class for typed things
    Typed(..),
    -- ** Inspectors on 'Typed' things 
    hasFunT,  hasHOT, hasListT
    
    
)where 

import Data.List (foldl', nub, intersect, intersectBy)
import Control.Monad
import Control.Monad.Error (Error, MonadError, strMsg, throwError)
import Data.Function (on)
import Data.Maybe (fromMaybe)

import Syntax.Name
--import qualified Language.Haskell.TH.Syntax as TH( Type(..) )

import Syntax.Class.Term (Term(..), getVarNames, Size(sizeS))
import Syntax.Class.Subst

import Debug.Trace


-- | Internal representation of (possibly qualified) types.
--
-- NOTE: the order of the constructors determines the derived 'Ord' instance;
-- do not reorder them casually.
data Type = ForallT TyCxt Type -- ^ a type qualified by a context of class predicates
 | VarT Name                   -- ^ a type variable
 | ConT Name                   -- ^ a type constructor, e.g. @''[]@ or @''(->)@
 | AppT Type Type              -- ^ application of a type to a type argument
 deriving (Show, Ord, Eq)



-- Type synonyms for our convenience. They are only aliases of 'Type', so the
-- compiler does not enforce them; they at least spell out the implicit
-- assumptions we make when hacking on 'Type'.

-- | A qualified type. This is only a naming convention: the value is assumed
-- to be of the shape @ForallT _ _@, which is not checked.
type QualTy = Type
-- | A class predicate such as @Eq a@: the class name and the constrained type.
data Pred = Pred {predClass :: Name, predMember :: Type} deriving (Show, Ord, Eq)
-- | A type context: a list of class predicates.
type TyCxt = [Pred]

-- | The class predicate of class @cls@ on the type variable named @v@.
mkPred :: Name -> String -> Pred
mkPred cls = Pred cls . varT

--------------------------------------------------------------------------------
-- Inspecting Types
--------------------------------------------------------------------------------

-- | A type is a function type when splitting it at its top-level arrows
-- yields at least two components (arguments plus a result).
isFunT :: Type -> Bool
isFunT ty = case unArrowT ty of
    _ : _ : _ -> True
    _         -> False

-- | Higher-order type: some component of the top-level arrow chain is itself
-- a function type, e.g. @(a -> b) -> c@.
isHOT :: Type -> Bool
isHOT ty = any isFunT (unArrowT ty)

-- | Is the type a list type @[t]@? A qualification ('ForallT') is looked
-- through, so @Eq a => [a]@ (as built by 'appT' / 'quantify') also counts as
-- a list type. This is consistent with 'isFunT' and 'dataName'.
isListT :: Type -> Bool
isListT (ForallT _ t)     = isListT t
isListT (AppT (ConT n) _) = n == listTypeName
isListT   _               = False

-- | The name of the outermost type constructor of a type. Qualifications and
-- type arguments are looked through. If the head of the type is a type
-- variable, this is reported via 'throwError'.
dataName :: (Error e, MonadError e m) => Type -> m Name
dataName ty = case ty of
    ForallT _ inner -> dataName inner
    AppT fun _      -> dataName fun
    ConT name       -> return name
    VarT _          -> throwError (strMsg "dataName: Variable!")
--------------------------------------------------------------------------------
-- Typed Things
--------------------------------------------------------------------------------

-- | Things that carry a type.
class Typed t where
    -- | The type of the given thing.
    typeOf :: t -> Type

-- | Does the thing have a function type? See 'isFunT'.
hasFunT :: (Typed t) => t -> Bool
hasFunT x = isFunT (typeOf x)

-- | Does the thing have a higher-order type? See 'isHOT'.
hasHOT :: (Typed t) => t -> Bool
hasHOT x = isHOT (typeOf x)

-- | Does the thing have a list type? See 'isListT'.
hasListT :: (Typed t) => t -> Bool
hasListT x = isListT (typeOf x)
--------------------------------------------------------------------------------
-- Auxiliary functions for Type as Term
--------------------------------------------------------------------------------

-- | Drop the first argument of a function type, i.e. the type left after
-- applying a function of this type to one argument: @a -> b -> c@ becomes
-- @b -> c@. The context of a qualified type is re-attached to the result
-- (restricted to the variables still present, see 'quantify').
-- Calls 'error' on anything that is not a (possibly qualified) function type.
sectionType :: Type -> Type
sectionType (AppT (AppT a _) e2) | a == arrowCon = e2
sectionType (ForallT cxt t) = quantify (sectionType t) cxt
sectionType t = error $ "Types.sectionType: no function type " ++ show t

--argumentType = head . unArrowT

-- | A type variable with the given name.
varT :: String -> Type
varT = VarT . mkName
-- | A type constructor with the given name.
conT :: Name -> Type
conT = ConT

-- | The function-arrow type constructor @(->)@.
arrowCon :: Type
arrowCon = ConT ''(->)

-- | Split a type at its top-level arrows, which nest to the right:
-- @a -> b -> c@ gives @[a, b, c]@, and a non-function type gives a singleton.
-- Under a context, the body is split and every component is re-qualified
-- with that context (keeping only the predicates that concern it, see
-- 'quantify').
unArrowT :: Type -> [Type]
unArrowT ty = case ty of
    AppT (AppT con dom) cod | con == arrowCon -> dom : unArrowT cod
    ForallT cxt body                          -> map (`quantify` cxt) (unArrowT body)
    _                                         -> [ty]

-- | Build a right-nested function type from the argument types followed by
-- the result type: @arrowT [a, b, c]@ is @a -> b -> c@.
-- Calls 'error' on an empty list.
arrowT :: [Type] -> Type
arrowT ts = case ts of
    []       -> error "Types.arrowT: empty list of types"
    [t]      -> t
    t : rest -> apArrowT t (arrowT rest)

-- | The function type @dom -> cod@ (a plain 'AppT'; contexts are not floated).
apArrowT :: Type -> Type -> Type
apArrowT dom cod = (arrowCon `AppT` dom) `AppT` cod

-- | Qualify a type with predicates given as (class name, variable name)
-- pairs. Predicates on variables that do not occur in the type are dropped
-- (see 'quantify').
forallT :: [(Name, String)] -> Type -> Type
forallT ps t = quantify t [mkPred cls var | (cls, var) <- ps]

-- TODO: DANGER !!! correctness of types is not checked !!!

-- | The list type constructor @[]@.
listCon :: Type
listCon = ConT ''[]

-- | The list type @[t]@ for the element type @t@.
-- NOTE(review): this uses the raw 'AppT' rather than 'appT' (unlike 'tupT'),
-- so the 'ForallT' of a qualified element type stays nested inside the list
-- type instead of being floated outwards.
listT :: Type -> Type
listT = AppT listCon
-- TODO: DANGER !!! correctness of types is not checked !!!

-- | The type constructor of @n@-tuples.
-- NOTE(review): this uses 'tupleDataName'. If that follows Template Haskell,
-- it names the tuple /data/ constructor rather than the /type/ constructor
-- ('tupleTypeName'). Confirm that the names compare as intended.
tupleCon :: Int -> Type
tupleCon = ConT . tupleDataName

-- | The tuple type over the given component types. The contexts of qualified
-- components are floated outwards by 'appT'.
tupT :: [Type] -> Type
tupT l = foldAppT (tupleCon $ length l) l
-- TODO: DANGER !!! correctness of types is not checked !!!
-- | The type of an infix application @x `op` y@. The first argument is the
-- type of the operator and must be a (possibly qualified) function type of at
-- least two arguments; the result is what remains after stripping two
-- argument positions. The two operand types are currently ignored, i.e. they
-- are not checked against the operator's argument types.
infixT :: Type -> Type -> Type -> Type
infixT opTy _lhsTy _rhsTy = sectionType (sectionType opTy)

-- | Apply a type to an argument type, keeping qualification at the outermost
-- level. When either side is a 'ForallT', both contexts are stripped, the
-- bare types are applied, and the combined context is re-attached with
-- 'quantify'. Otherwise the result is a plain 'AppT'.
appT :: Type -> Type -> Type
appT fun arg
    | qualified fun || qualified arg
    = quantify (AppT (rmCxt fun) (rmCxt arg)) (tyCxt fun ++ tyCxt arg)
    | otherwise
    = AppT fun arg
  where
    qualified (ForallT _ _) = True
    qualified _             = False

-- | Apply a type to several argument types, left to right, via 'appT':
-- @foldAppT f [a, b]@ is @appT (appT f a) b@. The accumulator is forced at
-- each step.
foldAppT :: Type -> [Type] -> Type
foldAppT acc []           = acc
foldAppT acc (arg : rest) = let acc' = appT acc arg
                            in acc' `seq` foldAppT acc' rest

-- | Disassemble a type into its head (type constructor or variable) followed
-- by its arguments: @T a b@ gives @[T, a, b]@. Under a qualification, the
-- body is split and every component is re-qualified with the context (see
-- 'quantify'). The result is never empty.
unfoldAppT :: Type -> [Type]
unfoldAppT = go []
  where
    go args ty = case ty of
        ForallT cxt inner -> map (`quantify` cxt) (go args inner)
        AppT fun arg      -> go (arg : args) fun
        _                 -> ty : args

-- | Normalise a qualified type by re-running 'quantify' on its body and
-- context. This drops predicates on absent variables, and the 'ForallT'
-- itself when nothing remains. Other types are returned unchanged.
fixType :: Type -> Type
fixType ty = case ty of
    ForallT cxt body -> quantify body cxt
    _                -> ty

    
-- | The arguments of a type application without its head (see 'unfoldAppT').
-- 'drop' 1 instead of the partial 'tail'; equivalent here because
-- 'unfoldAppT' never returns an empty list.
unfoldAppTargs :: Type -> [Type]
unfoldAppTargs = drop 1 . unfoldAppT


--------------------------------------------------------------------------------
-- Instance Declarations
--------------------------------------------------------------------------------


-- | 'Type' viewed as a term. Applications ('AppT') are binary nodes, type
-- variables ('VarT') are the term variables, and a context ('ForallT') is
-- pushed down onto the subterms.
instance Term Type where
    -- Two type variables always agree at the root, whatever their names;
    -- constructors agree only if their names are equal.
    sameSymAtRoot (VarT _) (VarT _)           = True
    sameSymAtRoot (ConT n1) (ConT n2)         = n1 == n2

    -- Qualified types: the bodies must agree at the root, and every predicate
    -- of c1 must also occur in c2. NOTE: this is not symmetric in its
    -- arguments (c1 = [Eq a], c2 = [Eq a, Ord a] gives True, swapped gives False).
    sameSymAtRoot (ForallT c1 t1) (ForallT c2 t2) = sameSymAtRoot t1 t2 && ((c1 `intersect` c2) == c1)
    -- A qualified type against an unqualified one is NOT looked through; it
    -- falls through to False below. The alternatives that would do so are disabled:
--    sameSymAtRoot (ForallT c1 t1) t2          = sameSymAtRoot t1 t2
--    sameSymAtRoot t1 (ForallT _ t2)           = sameSymAtRoot t1 t2
    -- Applications agree when their heads (first element of 'unfoldAppT')
    -- are equal. The number of arguments is not compared here, and head
    -- variables are compared by name (derived 'Eq').
    sameSymAtRoot t1@(AppT _ _) t2@(AppT _ _) = on (==) (head . unfoldAppT) t1 t2
    sameSymAtRoot _ _                         = False

    -- Immediate subterms: the two sides of an application. Under a context,
    -- the body's subterms are each re-qualified with that context.
    subterms (ForallT c t) = map (flip quantify c) (subterms t)
    subterms (AppT l r)    = [l, r]
    subterms _             = []
    
    -- The extra case for AppT is also covered in the general case, but
    -- spelling it out here speeds things up considerably by avoiding the
    -- unfoldAppT.
    -- In the general case 'zipWith' compares only the common prefix if the
    -- subterm lists differ in length.
    equal (AppT l1 l2) (AppT r1 r2) = equal l1 r1 && equal l2 r2
    equal s t = sameSymAtRoot s t && and (on (zipWith equal) subterms s t)

    -- Build a node with the same top symbol as the first argument over the
    -- given subterms. Under a context, the original context c is NOT reused.
    -- Instead the contexts of the new subterms are collected, the subterms
    -- are rebuilt without them, and the collected context is re-attached.
    root (ForallT c t) = \ts -> quantify (root t (map rmCxt ts)) (nub $ concatMap tyCxt ts)
    -- NOTE: the lambda pattern is partial; it fails unless exactly two
    -- subterms are supplied.
    root (AppT _ _)    = \[l, r] -> appT l r
    -- Leaves (variables, constructors) ignore the subterms.
    root t             = const t

    -- Only 'VarT' is a term variable.
    getVar (VarT n) = Just n
    getVar _        = Nothing

    -- The first argument is ignored; variables are always built with 'VarT'.
    toVar _ = VarT
    

-- | Substitution on types. Under a context, the substitution is applied to
-- the body AND to the members of the predicates. 'quantify' then drops the
-- predicates whose variables no longer occur in the result (for example those
-- mapped to ground types).
instance Substitutable Type where
  apply s (VarT u)      = fromMaybe (VarT u) (lookupS u s)
  apply s (AppT l r)    = appT (apply s l) (apply s r)
  -- The context must be substituted too. Otherwise a constrained variable
  -- that is renamed (a |-> b) loses its predicates, because 'quantify' keeps
  -- only predicates on variables that still occur in the body. 'rmCxt' keeps
  -- predicate members unqualified when a variable is mapped to a qualified type.
  apply s (ForallT c t) = quantify (apply s t) (map applyPred c)
    where
      applyPred p = p { predMember = rmCxt (apply s (predMember p)) }
  apply _ t             = t

-- | Size of a type: one per variable or constructor occurrence. Contexts do
-- not count. Contributions are combined by composing the counting functions.
instance Size Type where
    sizeS ty = case ty of
        ForallT _ body -> sizeS body
        AppT fun arg   -> sizeS fun . sizeS arg
        VarT _         -> (+ 1)
        ConT _         -> (+ 1)

--
-- Type Auxiliaries
--
-- | Qualify a type with additional predicates. The predicates are merged with
-- the type's existing context, and only those mentioning a variable of the
-- body are kept, without duplicates. The 'ForallT' is omitted when no
-- predicate remains. Constructors ('ConT') are returned unchanged.
quantify :: Type -> TyCxt -> Type
quantify ty extra = case ty of
    ForallT cxt body -> qualify body (cxt ++ extra)
    AppT _ _         -> qualify ty extra
    VarT _           -> qualify ty extra
    _                -> ty
  where
    -- keep the predicates on variables of the body; drop an empty context
    qualify t preds =
        case nub (allPreds (getVarNames t) preds) of
            []   -> t
            kept -> ForallT kept t

-- | The context of a qualified type; @[]@ for an unqualified one.
tyCxt :: Type -> TyCxt
tyCxt ty = case ty of
    ForallT cxt _ -> cxt
    _             -> []

-- | Strip the outermost context from a type, if there is one.
rmCxt :: Type -> Type
rmCxt ty = case ty of
    ForallT _ body -> body
    _              -> ty

-- | Push the context of a qualified application one level down onto both
-- sides; each side keeps only the predicates on its own variables (see
-- 'quantify'). Any other type, including a qualified variable, is returned
-- unchanged.
propCxt :: Type -> Type
propCxt ty = case ty of
    ForallT cxt (AppT l r) -> AppT (quantify l cxt) (quantify r cxt)
    _                      -> ty

-- | The predicates of a context that constrain at least one of the given
-- variable names, in their original order.
allPreds :: [Name] -> TyCxt -> TyCxt
allPreds ns cxt = [p | p <- cxt, isIn ns p]

-- | Does the predicate's member type mention any of the given variable names?
isIn :: [Name] -> Pred -> Bool
isIn ns (Pred _ member) = any (flip elem ns) (getVarNames member)


-- | Disassemble a type into its arguments: everything after the head in
-- 'unfoldAppT'. 'drop' 1 is total, so no special case for an empty list is
-- needed.
tyArgs :: Type -> [Type]
tyArgs = drop 1 . unfoldAppT

-- | Disassemble a type into its type constructor: the head of 'unfoldAppT'.
tyCtor :: Type -> Type
tyCtor t = case unfoldAppT t of
    ctor : _ -> ctor
    -- unreachable: unfoldAppT always returns at least the head
    []       -> error "Types.tyCtor: unfoldAppT returned an empty list"


--------------------------------------------------------------------------------
-- Comparing Types
--------------------------------------------------------------------------------

-- | Do two typed things have equal types, according to 'equal' on 'Type'?
sameTy :: (Typed t) => t -> t -> Bool
sameTy x y = equal (typeOf x) (typeOf y)

