
-- | Term substitutions, application of substitutions to terms, and 
--   modifications of substitutions.
module Syntax.Class.Subst (
    -- * The class
    Substitutable(..),
    
    -- * The data type 
    -- ** Definition 
    Subst,
    -- ** Constructors
    nullSubst, (<~),
    
    -- * Functions over substitutions
    (@@), merge,
    lookupS, assocs,

) where

import Syntax.Class.Term
import Syntax.Name

import Data.List (nub, intersect, union)
import Data.Function (on)
import qualified Data.Map as M
import qualified Data.Set as S

import Control.Monad.Error (Error, MonadError, strMsg, throwError)

-- | The class of term types to which a substitution can be applied.
class Term t => Substitutable t where
  -- | Apply a substitution to a single term.
  apply :: Subst t -> t -> t

  -- | Apply a substitution to every term in a list.  The default
  --   implementation maps 'apply' over the list.
  applyL :: Subst t -> [t] -> [t]
  applyL = map . apply
  
-- | A substitution: a finite map from variable 'Name's to the terms that
--   replace them.
newtype Subst t = Subst
  { unSubst :: M.Map Name t  -- ^ The underlying map of replacements.
  }
  deriving (Show)

-- | The empty substitution: it has no binding for any name.
nullSubst :: Substitutable t => Subst t
nullSubst = Subst { unSubst = M.empty }

-- | Build a substitution holding exactly one replacement: the given name
--   is bound to the given term.
(<~) :: Substitutable t => Name -> t -> Subst t
name <~ term = Subst $ M.singleton name term


infixr 4 @@

-- | Compose two substitutions.  The result is intended to satisfy
--
-- > apply s1 . apply s2 == apply (s1 @@ s2)
--
--   Every term bound by the right operand has the left operand applied to
--   it; bindings of the left operand are kept only for names the right
--   operand does not bind (the underlying map union is left-biased).
(@@) :: Substitutable t => Subst t -> Subst t -> Subst t
outer @@ inner = Subst (rewritten `M.union` unSubst outer)
  where
    -- Bindings of the right operand, pushed through the left operand.
    rewritten = M.map (apply outer) (unSubst inner)


-- | Parallel composition of substitutions.  The two substitutions must
--   agree on every name bound by both of them; when they do, the result
--   is the union of their bindings, which is meant to make the merge
--   independent of argument order.
--
--   When they disagree on some shared name, the computation fails via
--   'throwError' with @strMsg "merge fails."@.
merge :: (Substitutable t, Monad m, Error e, MonadError e m)
      => Subst t -> Subst t -> m (Subst t)
merge s1 s2
  | agree     = return (Subst (M.union (unSubst s1) (unSubst s2)))
  | otherwise = throwError (strMsg "merge fails.")
  where
    -- Names that have a binding in both substitutions.
    shared = M.keys (M.intersection (unSubst s1) (unSubst s2))

    agree = all agreesOn shared

    -- Both substitutions must send the variable for this name to equal terms.
    agreesOn v = apply s1 (var v) == apply s2 (var v)

    -- 'head' is only forced when 'shared' is non-empty, which implies that
    -- s1 has at least one binding, so it cannot fail here.
    -- NOTE(review): presumably 'toVar' builds a variable term from a name,
    -- taking the first bound term of s1 as its first argument -- confirm
    -- against Syntax.Class.Term.
    var = toVar . snd . head . assocs $ s1

-- | Look up the term bound to a name in a substitution.  Yields 'Nothing'
--   when the substitution has no binding for that name.
lookupS :: Name -> Subst t -> Maybe t
lookupS n = M.lookup n . unSubst

-- | All bindings of a substitution as @(name, term)@ pairs, in ascending
--   order of name.
assocs :: Subst t -> [(Name, t)]
assocs = M.assocs . unSubst
