{-# LANGUAGE ScopedTypeVariables #-}
module Igor2.RuleDevelopment.Accumulator where

import Prelude hiding ((<$>))

import Control.Monad.State (State, runState, state)
import Control.Monad (liftM, forM)
import Data.List (foldl1', isPrefixOf, intersperse)
import Data.Set (Set())
import qualified Data.Set as Set

import Syntax.Class.Term (root, subterms, subtermAt, toVar, isVar,
    Position(Root), (°), applyAtPos, getVarNames)
import Syntax.Expressions (TExp(TConE, TVarE), foldTAppE)
import Syntax.Type (arrowT, isFunT, typeOf)
import Syntax (mkName, Name)
import Igor2.Data.CallDependencies (Call)
import Igor2.Data.IOData (CovrRule, CovrRules, covrRules, crul, modifycrul, name)
import Igor2.Data.IgorMonad (IM, breakupM, addIO, coverAll)
import Igor2.Data.Rules (lhs, rhs, rule, rules, Rule(), Rules(), RulePos(Body), mkCallAt)
import Igor2.Logging (logIN, logDE, waypointS)
import Igor2.Ppr ((<+>), (<$>), (<//>), empty, char, pretty, text, cat, punctuate, squotes)

import Debug.Trace (trace)

-- | Apply a function to the first component of a pair, keeping the second
-- component as is. The pair itself is matched (forced) when applied.
mapFst :: (a -> a') -> (a, b) -> (a', b)
mapFst f pair = case pair of
    (first, rest) -> (f first, rest)

-- | Apply a function to the second component of a pair, keeping the first
-- component as is. The pair itself is matched (forced) when applied.
mapSnd :: (b -> b') -> (a, b) -> (a, b')
mapSnd g = \(first, rest) -> (first, g rest)

-- | Apply a function to the first component of every pair in a list.
-- Only the list spine is produced eagerly by 'map'; each pair is matched
-- only when its element is demanded.
mapFsts :: (a -> a') -> [(a, b)] -> [(a', b)]
mapFsts f = map (\(first, rest) -> (f first, rest))

-- | Apply a function to the second component of every pair in a list.
-- Only the list spine is produced eagerly by 'map'; each pair is matched
-- only when its element is demanded.
mapSnds :: (b -> b') -> [(a, b)] -> [(a, b')]
mapSnds g = map (\(first, rest) -> (first, g rest))

-- | Try to introduce an accumulator argument for the covered rule @rf@.
--
-- The rule is split into one rule per IO example ('breakupM'). Every
-- variable-free subterm that occurs in the right-hand side of /all/ of
-- these rules is a candidate initial value for the accumulator. For each
-- candidate, and each choice of one occurrence position per rule, a new
-- function is registered via 'addIO'. The rules of that function take the
-- accumulator as an extra, last argument, which replaces the candidate at
-- the chosen position. @rf@ is then rewritten ('modifycrul' / 'mkCallAt' at
-- @Body Root@) to call the new function with its own arguments followed by
-- the candidate term.
--
-- The result has one entry per (candidate, position combination). Each entry
-- holds the rewritten @rf@ together with the 'coverAll' result for the new
-- function, and the single call @(name rf, newFn, GT)@. If there are no
-- candidates the result is empty. This includes the case where 'breakupM'
-- yields no rules.
accumIntro :: CovrRule -> IM [(CovrRules, [Call])]
accumIntro rf = do
    waypointS $ text "Introducing accumulator for" <+> (squotes $ pretty $ rf)
    -- One rule per IO example of rf.
    cruls :: [Rule] <- liftM (map crul) $ breakupM rf
    let -- For each rule, the (position, term) pairs of all variable-free
        -- subterms of its right-hand side.
        variableFreeTerms :: [[(Position, TExp)]]
        variableFreeTerms = map (snd . getVariableFreeTerms . rhs) cruls

        -- Returns whether @t@ itself is variable-free. It also returns the
        -- (position, term) pairs of all variable-free subterms of @t@, which
        -- includes @t@ itself at 'Root' when it qualifies.
        getVariableFreeTerms :: TExp -> (Bool, [(Position, TExp)])
        getVariableFreeTerms t =
            let childrenVarFree :: [Bool]
                childResults :: [[(Position, TExp)]]
                (childrenVarFree, childResults) = unzip
                    $ map getVariableFreeTerms (subterms t)
                -- Prefix each child's positions with the child's index,
                -- then concatenate the results.
                vfreeSubterms :: [(Position, TExp)]
                vfreeSubterms = concatMap (\(i, st) -> mapFsts (i°) st)
                                $ zip [0..] childResults
            -- A non-variable term is variable-free iff all its subterms are.
            -- A non-variable term without subterms qualifies trivially,
            -- because `and []` is True.
            in if not (isVar t) && and childrenVarFree
               then (True, (Root, t) : vfreeSubterms)
               else (False, vfreeSubterms)

        -- These are just for readability.
        vftPos :: (Position, TExp) -> Position
        vftPos = fst

        vftTerm :: (Position, TExp) -> TExp
        vftTerm = snd

    logIN (text "Found the following variable-free subterms:" <$> pretty variableFreeTerms)
    -- The set of variable-free subterms that occur in every IO example.
    -- With no examples there is nothing to intersect, so there are no
    -- candidates. foldl1' would fail on an empty list, hence the guard.
    let candidateTerms :: Set TExp
        candidateTerms = case map (Set.fromList . map vftTerm) variableFreeTerms of
            []   -> Set.empty
            sets -> foldl1' Set.intersection sets
    -- Select all entries denoting a given term t, keeping only their
    -- positions (one inner list per IO example).
    let filterPosByTerm :: TExp -> [[(Position, TExp)]] -> [[Position]]
        filterPosByTerm t = map (map vftPos . filter ((== t) . vftTerm))
    -- For each candidate term, its occurrence positions in every IO example.
    let positionsByCandidate :: [(TExp, [[Position]])]
        positionsByCandidate = map (\ct -> (ct, filterPosByTerm ct variableFreeTerms)) (Set.toList candidateTerms)
        numPositions = map (map length . snd) positionsByCandidate :: [[Int]]
    logIN (text "Choosing combinations from:" <$> pretty positionsByCandidate)
    logIN (text "Expected total number of combinations:"
            <$> cat (punctuate (text " + ")
                        $ map (foldr (<//>) empty . intersperse (char '*') . map pretty)
                              numPositions
                    )
            <+> text "=" <+> pretty (sum $ map product numPositions)
        )
    -- Replace each [[Position]] by the cartesian product of its inner lists.
    -- In positionsByCandidate, each inner [Position] lists all occurrences
    -- within one IO example. In positionCombinationsByCandidate, each inner
    -- [Position] picks exactly one occurrence per IO example, in example
    -- order.
    let positionCombinationsByCandidate :: [(TExp, [[Position]])]
        positionCombinationsByCandidate = mapSnds sequence positionsByCandidate
    logIN (text "Number of combinations per candidate term:" <$> pretty (mapSnds length positionCombinationsByCandidate))

    -- Search for a free name for the new accumulator variable
    let accVarNamePrefix = "aCC"
        occuringAccVarNames = filter (isPrefixOf accVarNamePrefix) $ map show $ concatMap getVarNames $ concatMap lhs cruls
        accVarNames = [ accVarNamePrefix ++ show i | i <- [1..] ]
        -- The binder is called 'candidate' rather than 'name' so that it
        -- does not shadow the imported 'name' accessor used below.
        freeVarNames = [ candidate | candidate <- accVarNames, candidate `notElem` occuringAccVarNames ]
        -- accVarNames is infinite and occuringAccVarNames is finite, so
        -- freeVarNames is never empty and 'head' cannot fail here.
        newAccVarName = mkName (head freeVarNames)

    -- Build the rules of the accumulator function. Rule r gets the extra
    -- argument appended, and its rhs subterm at the chosen position p is
    -- replaced by that argument.
    let addAccVar initExpr poss =
           rules [ let accVar = toVar initExpr newAccVarName
                   in rule (lhs r ++ [accVar])
                           (applyAtPos (const accVar) p (rhs r))
                 | (p, r) <- zip poss cruls ]

    let crulss :: [(TExp, Rules)]
        crulss = concatMap (\(initExpr, combs) ->
                     map (\poss -> (initExpr, addAccVar initExpr poss)) combs
                 ) positionCombinationsByCandidate

    fnnames <- mapM (addIO . snd) crulss :: IM [Name]

    rfaccs <- mapM coverAll fnnames :: IM [CovrRule]

    -- Rewrite a rule of rf so that its body calls accfn with the rule's own
    -- arguments followed by the initial accumulator value.
    let callfn :: Name -> TExp -> Rule -> Rule
        callfn accfn initExpr r = mkCallAt (Body Root) accfn (lhs r ++ [initExpr]) r
    let rfnews :: [CovrRule]
        rfnews = [ modifycrul rf (callfn accfn initExpr) | ((initExpr, _), accfn) <- zip crulss fnnames ]
    let result :: [(CovrRules, [Call])]
        result = [ (covrRules [rfnew, rfacc], [(name rf, accfnname, GT)]) | (accfnname, rfacc, rfnew) <- zip3 fnnames rfaccs rfnews ]
    logIN (text "Resulting accumulator functions and calls:" <$> pretty result)
    return result
