{-# OPTIONS_GHC #-}
module Igor2.Data.Hypotheses (

    Hypo,  Hypos,
    hypo, -- hypos, 
    open, clsd, rating, callings,
    isFinished,
    
    openRules, closedRules, allRules, allBindings, simplifiedBindings,
    developH, 

    module Igor2.Data.CallDependencies,
    module Igor2.Data.Rules

    )where

import Prelude hiding ((<$>))

import Data.Util (showAsSet)

import Syntax
import Igor2.Ppr

import Control.Monad (foldM)
import Control.Monad.Error
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader (mapReaderT)
import Control.Arrow (first, second)

import Data.List (foldl', partition, (\\), isPrefixOf)

import Data.Set (Set)
import qualified Data.Set as S

import Data.Map (Map)
import qualified Data.Map as M

import Data.Function (on)
import Data.Maybe (isJust, fromJust)

import Igor2.Data.Rules hiding (sameSymAt, isOpen, isClosed)
import Igor2.Data.IOData

import Igor2.Data.CallDependencies -- (CallDep, Call, tryAddCall, noCalls, cycles)

import Igor2.Data.Rateable
import Igor2.Logging

--------------------------------------------------------------------------------
-- | Hypothesis
--------------------------------------------------------------------------------


data Hypo = HH
    { open     :: CovrRules       -- ^ open (not yet finished) covering rules
    , clsd     :: Map Name Rules  -- ^ closed rules, keyed by their function name
    , callings :: CallDep         -- ^ call dependencies between the rules
    , rating   :: RatingData      -- ^ cached rating, refreshed by 'updateRating'
    }
    deriving (Eq, Show)


{- ----------
 | Instance declarations of data type 'Hypo'
-} ----------

                    
-- | Render a hypothesis as a "Hypo" header with its rating, followed by one
-- entry per binding from 'allBindings'.
instance Pretty Hypo where
    pretty h = text "Hypo" <+> pretty (rating h)
               <$> vcat (map pretty (allBindings h))
    
-- | Build a hypothesis from a list of 'CovrRule's and compute its rating.
--
-- The rules are added one by one with 'unsafeExtend', starting from a
-- hypothesis with no rules, no calls ('noCalls') and 'emptyRatingData'.
hypo :: (Monad m) => [CovrRule] -> C m Hypo
hypo rs = updateRating (foldl' addRule emptyHypo rs)
  where
    emptyHypo        = HH S.empty M.empty noCalls emptyRatingData
    addRule acc rule = unsafeExtend rule acc

-- | A hypothesis is finished once it has no open rules left.
isFinished :: Hypo -> Bool
isFinished h = S.null (open h)

-- | All closed rules of a hypothesis, merged across every function name.
closedRules :: Hypo -> Rules
closedRules = S.unions . M.elems . clsd

-- | The plain rules ('crul') underlying the open 'CovrRule's of a hypothesis.
openRules :: Hypo -> Rules
openRules = S.map crul . open

-- | Union of the closed and the open rules of a hypothesis.
allRules :: Hypo -> Rules
allRules h = closedRules h `S.union` openRules h

-- | All bindings of a hypothesis as (function name, rules) pairs.
--
-- Closed rules appear under their function name. Open rules appear under
-- the function name prefixed with @'*'@, so the open part of a function is
-- listed as a separate binding.
allBindings :: Hypo -> [(Name,Rules)]
allBindings h = M.toList (foldl' addOpen (clsd h) (S.toList (open h)))
  where
    addOpen acc cr = M.alter (insertM (crul cr)) (starred cr) acc
    starred        = mkName . ('*' :) . show . name

-- | Like 'allBindings', but open rules are filed under their plain function
-- name, i.e. merged into the binding that holds that function's closed rules.
-- NOTE(review): not exported and not used elsewhere in this module.
bindings :: Hypo -> [(Name,Rules)]
bindings h = M.toList (foldl' addOpen (clsd h) (S.toList (open h)))
  where
    addOpen acc cr = M.alter (insertM (crul cr)) (name cr) acc

-- | Simplify the bindings of a hypothesis by folding single-rule helper
-- functions into the rules that call them.
--
-- Starts from 'callings' and 'allBindings' of the hypothesis. The local loop
-- @sbnds@ repeatedly takes the names returned by 'annexCandidates' (minus a
-- blacklist) and splits them with @isInjectable@. It rewrites the bindings
-- with @inject@ for every injectable name, updates the call dependencies with
-- 'annexFun', and blacklists the non-injectable names. It stops when no
-- non-blacklisted candidate is left. The result is the final call
-- dependencies together with the rewritten bindings.
simplifiedBindings :: (Monad m) =>
                     Hypo -> C m (CallDep, [(Name, Rules)])
simplifiedBindings h = sbnds [] (callings h) (allBindings h)
    where
    
    -- blckl: candidate names already rejected as non-injectable
    -- cd:    current call dependencies
    -- bnds:  current (function name, rules) bindings
    sbnds blckl cd bnds = 
        case (annexCandidates $ cd) \\ blckl of
        -- per the original note, 'annexCandidates' yields the outermost
        -- functions of the call graph (constant functions, or those with a
        -- single call) -- NOTE(review): confirm in Igor2.Data.CallDependencies
            []      -> return (cd,bnds)
            l@(_:_) -> let (is,nis) = partition (isInjectable bnds) l in 
                       (foldM inject bnds is) >>= 
                         sbnds (blckl++nis) (foldl (flip annexFun) cd is)
            -- inject the injectable names into the other bindings, update
            -- the CallDep with 'annexFun' for each of them, and add the
            -- non-injectable names to the blacklist before recursing
            
    -- A candidate name n is injectable only if ALL of these hold:
    isInjectable bnds n = let fbnds = lookup n bnds in and $ 
        [ not $ "*" `isPrefixOf` (show n)
        -- the name does not start with '*', i.e. it is not one of the
        -- starred names 'allBindings' uses for open (unfinished) rules
        , maybe False id (liftM ((==1).S.size) fbnds)
        -- there is a binding with this name (if not, it is presumably a
        -- background-knowledge function) and it holds exactly one rule
        , maybe False (not . (any hasHO) . S.toList) fbnds
        -- none of the function's own rules is higher-order ('hasHO')
        , not $ any hasHO $ filter (flip doesCallTo n) $ 
          concatMap (S.toList . snd) $ bnds
        -- no rule that calls n (per 'doesCallTo') is higher-order
        ]
        
    -- Remove the binding named l, turn its single rule into a call with
    -- 'buildCall', and apply it to all remaining bindings via 'replaceInAll'.
    -- The two 'head' calls rely on 'isInjectable' having established (on the
    -- bindings at the start of this round) that l has a binding with exactly
    -- one rule.
    inject bnds l = let (r,rs) = partition ((l==).fst) $ bnds
                        cll = buildCall . (second $ head . S.toList) . head $ r
                    in uncurry replaceInAll cll rs
   


        
{-|
  Develop a hypothesis by replacing one of its 'CovrRule's.

  This is the only way to change a hypothesis. The rule @rf@ is removed from
  @h@ and the rules in @rfs@ are added (see 'unsafeModify'). Every 'Call' in
  @calls@ is added to the call dependencies with 'tryAddCall', and the rating
  is recomputed with 'updateRating'.

  If one of the calls is not admissible, the failure is reported inside the
  monad @m@ (presumably through its 'MonadError' instance; confirm in
  'tryAddCall').
-}
developH :: (Monad m, Error e, MonadError e m) =>
           CovrRule           -> -- the CovrRule that was changed
           Hypo               -> -- the hypothesis to develop
           (CovrRules,[Call]) -> -- ( the CovrRules resulting from the change
                                -- , a list of Calls )
           C m Hypo             -- the resulting hypothesis in the C monad
developH rf h (rfs, calls) = do
    -- Only the call dependencies of the modified hypothesis are needed here.
    -- Reading them through the 'callings' selector avoids a positional
    -- pattern that would shadow the 'open'/'clsd' selectors with unused names.
    let h' = unsafeModify rf rfs h
    cdps' <- foldM tryAddCall (callings h') calls
    updateRating h' { callings = cdps' }


-- Internal modifiers for data type 'Hypo' (none of them updates the rating)

-- | Add a 'CovrRule' to a hypothesis without re-rating it. Open rules go into
-- the open set; closed rules are filed under their function name.
unsafeExtend :: CovrRule -> Hypo -> Hypo
unsafeExtend rf h@(HH openSet closedMap _ _)
    | isOpen rf = h { open = S.insert rf openSet }
    | otherwise = h { clsd = M.alter addRule (name rf) closedMap }
  where
    addRule = insertM (crul rf)

-- | Remove a 'CovrRule' from a hypothesis without re-rating it. Open rules
-- are deleted from the open set; closed ones are removed from their
-- function's binding via 'deleteM'.
unsafeShrink :: CovrRule -> Hypo -> Hypo
unsafeShrink rf h@(HH openSet closedMap _ _)
    | isOpen rf = h { open = S.delete rf openSet }
    | otherwise = h { clsd = M.alter dropRule (name rf) closedMap }
  where
    dropRule = deleteM (crul rf)

-- | Replace a 'CovrRule' in a hypothesis by a set of new ones. The old rule
-- is removed with 'unsafeShrink', then every new rule is added with
-- 'unsafeExtend'. The rating is not updated.
--
-- Uses 'S.foldr' rather than the deprecated compatibility function 'S.fold'.
-- The containers docs describe 'S.fold' as an equivalent of 'S.foldr', so the
-- result is unchanged.
unsafeModify :: CovrRule    -- ^ replace/update old rule
             -> CovrRules   -- ^ with new rules
             -> Hypo        -- ^ in hypothesis
             -> Hypo
unsafeModify rold newrs h =
    let shrnk = unsafeShrink rold h in
    S.foldr unsafeExtend shrnk newrs

-- | Recompute the rating of a hypothesis with 'rate' and store it in the
-- 'rating' field.
updateRating :: (Monad m) => Hypo -> C m Hypo
updateRating h = rate h >>= \v -> return h { rating = v }

--------------------------------------------------------------------------------
-- | Hypotheses: a plain list of 'Hypo' values (the type enforces no invariants).
--------------------------------------------------------------------------------
type Hypos = [Hypo] 

--------------------------------------------------------------------------------
-- Rating a hypothesis
--------------------------------------------------------------------------------


-- | The rating of a hypothesis is built with 'ratingData' from, in order:
-- the number of partitions, the number of open rules, the total number of
-- rules, the number of free variables, and the result of 'heuristic'.
instance Rateable Hypo where
    -- All metrics are computed in 'C Identity'; 'mapReaderT' then lifts the
    -- result into the caller's monad.
    rate h = mapReaderT (return . runIdentity) $ do
        partitions   <- numberOfPartitions h
        _closedParts <- numberOfClosedPartitions h  -- computed, currently unused
        heu          <- heuristic h
        let openCount    = numberOfOpenRules h
            freeVarCount = numberOfFreeVars h
            totalCount   = numberOfTotalRules h
            -- Alternative metrics kept for experimenting with the rating;
            -- none of them is passed to 'ratingData' at the moment.
            _nonConstBinds = numberOfNonConstBinds h  -- noted as better than totalCount
            _cycleCount    = cycles (callings h)
            _fvarsPerRule
              | openCount == 0 = 0.0
              | otherwise      = ((/) `on` fromIntegral) freeVarCount openCount
        return (ratingData partitions openCount totalCount freeVarCount heu)


-- | Number of distinct rules in 'openRules' of a hypothesis.
numberOfOpenRules :: Hypo -> Int
numberOfOpenRules h = S.size (openRules h)

-- | Number of distinct rules in a hypothesis, open and closed ('allRules').
numberOfTotalRules :: Hypo -> Int
numberOfTotalRules h = S.size (allRules h)

-- | Number of bindings (from 'allBindings') that hold more than one rule.
numberOfNonConstBinds :: Hypo -> Int
numberOfNonConstBinds h = length [ () | (_, rs) <- allBindings h, S.size rs > 1 ]

-- | Total number of free variables ('freeVars') over all rules of a
-- hypothesis, open and closed.
--
-- Uses 'sum' instead of @foldl1 (+)@: a hypothesis without any rules now
-- yields 0 instead of crashing on an empty list. For non-empty rule sets
-- the result is the same as before.
numberOfFreeVars :: Hypo -> Int
numberOfFreeVars h =
    sum $ map (length . freeVars) (S.toList $ allRules h)

-- | 'numberOfLeastPatterns' applied to all rules of a hypothesis.
numberOfPartitions ::  Monad m => Hypo -> C m Int
numberOfPartitions h = numberOfLeastPatterns (allRules h)
        
-- | 'numberOfLeastPatterns' applied to the closed rules of a hypothesis only.
numberOfClosedPartitions :: Monad m => Hypo -> C m Int
numberOfClosedPartitions h = numberOfLeastPatterns (closedRules h)

-- | Count the rules that remain after a pairwise reduction with 'matchLhss'.
--
-- Higher-order rules ('hasHO') are ignored. The remaining rules are folded,
-- in set order, into an accumulator list. For a new rule p2:
--   * p2 is dropped as soon as some kept rule p1 has @matchLhss p1 p2@;
--   * otherwise, if @matchLhss p2 p1@ holds, p2 replaces that p1;
--   * if neither holds for any kept rule, p2 is appended at the end.
-- The result is the length of the final list.
--
-- The original description was "number of patterns that do not subsume any
-- other pattern". NOTE(review): whether the kept rules are the most general
-- or the most specific ones depends on the argument order of 'matchLhss';
-- confirm.
-- NOTE(review): when p2 replaces p1, p2 is not compared against the rest of
-- the list, so the count may depend on the order of the input rules.
numberOfLeastPatterns :: Monad m => Rules -> C m Int
numberOfLeastPatterns rs = 
  liftM length $ (foldM leastPatterns []) ((filter (not . hasHO)) . S.toList $ rs)
      where
        -- leastPatterns keptSoFar newRule -> updated kept list
        leastPatterns [] p       = return [p]
        leastPatterns (p1:ps) p2 = do
          b1 <- matchLhss p1 p2
          if b1 then return (p1:ps)
            else do b2 <- matchLhss p2 p1
                    if b2 then return (p2:ps)
                      else liftM (p1:) (leastPatterns ps p2)

-- | Summary ratios of a hypothesis, computed on its simplified bindings
-- (see 'simplifiedBindings'). There are two ratios:
--
--   * number of functions divided by the number of rules, and
--
--   * number of partitions ('numberOfPartitions') divided by the number of
--     rules.
--
-- They are returned together with their arithmetic mean. The original note
-- says the mean is used to colour search-tree nodes in istviewer. The pair
-- is also passed to 'ratingData' by the 'Rateable' instance.
heuristic :: (Monad m) => Hypo -> C m ([Float], Float)
heuristic h = do
    (_, binds) <- simplifiedBindings h
    partitions <- numberOfPartitions h
    let funCount  = length binds
        ruleCount = sum [ S.size rs | (_, rs) <- binds ]
        -- Further candidates, currently disabled: cycles per function
        -- ('cycles') and loops per function ('loops') of the call graph.
        ratios    = [ divide (funCount, ruleCount)    -- functions per rule
                    , divide (partitions, ruleCount)  -- partitions per rule
                    ]
        mean      = sum ratios / fromIntegral (length ratios)
    return (ratios, mean)
    
-- | Divide the first component of a pair of integral values by the second,
-- after converting both with 'fromIntegral'.
--
-- The explicit signature makes the type independent of the monomorphism
-- restriction; without it, the type had to be inferred from its use sites.
-- With 'Float' or 'Double' results, a zero denominator gives Infinity or NaN
-- rather than an exception.
divide :: (Integral a, Fractional b) => (a, a) -> b
divide = uncurry ((/) `on` fromIntegral)
