module Main where

import Test.Framework (defaultMain)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import qualified Test.QuickCheck as QC
import qualified Test.QuickCheck.Property as QCP

import Data.Function (on)
import Data.Functor ((<$>))
import Data.List (mapAccumL, nubBy)
import Data.Maybe (fromJust)
import Control.Applicative ((<*>))
import Control.Monad (liftM, liftM2)
import Control.Monad.Reader (mapReaderT, runReader, runReaderT)
import qualified Control.Monad.State as St
import qualified Data.Map as M
import qualified Data.Set as S

import Language.Haskell.TH.Syntax (Name, mkName)

import Syntax (AUnify(), C, Position(Root, Dot), Size(), TExp(TVarE, TWildE,
    TAppE, TConE), Term(), Type(AppT, ForallT, VarT, ConT), Unify(),
    (°), apply, applyAtPos, arrowCon, boolCon, ctx_types, defaultContext,
    eitherCon, equal, getVarNames, intCon, isFunT, lgg, listCon, match,
    maybeCon, safeCatchErrorC, size, substitute, subtermAt, subterms, tAppE,
    tupleCon, typeOf, unArrowT, unfoldTAppE)
import Syntax.Ppr (Pretty)

-- | Random names are non-empty strings of lower-case ASCII letters.
-- Shrinking drops the first character of the shown name; a name whose
-- shown form has at most one character does not shrink any further.
instance QC.Arbitrary Name where
  arbitrary = fmap mkName (QC.listOf1 (QC.elements ['a'..'z']))
  shrink n
    | null shorter = []
    | otherwise    = [mkName shorter]
    where shorter = drop 1 (show n)

-- | Kinds of the types produced by the generators in this file: the kind
-- of proper types ('Star') and the kind of type constructors ('Arrow').
data Kind
    = Star
    | Arrow Kind Kind
    deriving (Eq, Ord)

-- 'Arrow' associates to the right, like the function arrow.
infixr 9 `Arrow`

-- | Renders 'Star' as @*@ and every arrow kind fully parenthesised, e.g.
-- @(* -> *)@.  The precedence argument is ignored.
instance Show Kind where
  showsPrec _ kind = case kind of
    Star        -> showChar '*'
    x `Arrow` y -> showChar '(' . shows x . showString " -> " . shows y . showChar ')'

-- | @arrows n@ is the kind of a type constructor that takes @n@ arguments
-- of kind 'Star': @arrows 0@ is 'Star', @arrows 2@ is @* -> * -> *@.
-- The function counts down to zero, so @n@ is expected to be non-negative.
arrows n
  | n == 0    = Star
  | otherwise = Star `Arrow` arrows (n - 1)

-- | Every type constructor the generators may use, paired with its kind.
-- The tuple constructors of arity 2 to 15 follow the fixed entries.
simpleTypeKinds :: [(Type, Kind)]
simpleTypeKinds =
    [ (con, arrows arity) | (con, arity) <- fixed ]
    ++ [ (tupleCon i, arrows i) | i <- [2..15] ]
  where
    -- Constructor together with the number of arguments it expects.
    fixed =
        [ (boolCon, 0)
        , (intCon, 0)
        , (tupleCon 0, 0)
        , (listCon, 1)
        , (maybeCon, 1)
        , (arrowCon, 2)
        , (eitherCon, 2)
        ]

-- | All entries of 'simpleTypeKinds' that have exactly the given kind,
-- in table order.
simpleTypes :: Kind -> [Type]
simpleTypes kind = [ ty | (ty, k) <- simpleTypeKinds, kind == k ]

-- | Picks one of the kinds @arrows 0@, @arrows 1@, ... taken from the
-- longest initial run for which 'simpleTypes' is non-empty.
instance QC.Arbitrary Kind where
 arbitrary = QC.elements inhabitedKinds
   where inhabitedKinds = takeWhile hasSimpleType (map arrows [0..])
         hasSimpleType  = not . null . simpleTypes

-- | Generate a random type of the requested kind.
--
-- The first argument is the set of type variable names that may occur in
-- the result (they are only offered when kind 'Star' is requested).  The
-- last argument is a size budget: it is halved for both sides of every
-- application and no further application is generated once it reaches zero.
genType :: S.Set Name -> Kind -> Int -> QC.Gen Type
genType visible kind size
    | null generators = error "huh?"
    | otherwise       = QC.oneof generators
  where
    -- Leaves of the right kind: matching constructors, plus the visible
    -- type variables when a proper type is wanted.
    atoms
        | kind == Star = simpleTypes kind ++ map VarT (S.toList visible)
        | otherwise    = simpleTypes kind
    atomGens = [ QC.elements atoms | not (null atoms) ]

    hasConstructor = not . null . simpleTypes

    -- An application needs budget left and some constructor that still
    -- accepts one more argument of kind 'Star' before reaching 'kind'.
    canApply = size > 0 && hasConstructor (Star `Arrow` kind)
    appGens  = [ application | canApply ]

    half = size `div` 2
    application = do
        argKind <- QC.arbitrary `QC.suchThat` (hasConstructor . (`Arrow` kind))
        arg <- genType visible argKind half
        con <- genType visible (argKind `Arrow` kind) half
        return (AppT con arg)

    generators = atomGens ++ appGens

-- | A pair of types produced by the 'QC.Arbitrary' instance for this
-- wrapper: a generated type and a copy of it in which some subterms have
-- been replaced by type variables.
newtype MatchingTypes = MatchingTypes
    { getMatchingTypes :: (Type, Type)
    } deriving Show

-- | Generates a random type and pairs it with a copy in which randomly
-- chosen, non-overlapping subterms are replaced by fresh type variables.
-- Equal subterms receive the same variable.
instance QC.Arbitrary MatchingTypes where
    arbitrary = do
        concrete <- QC.arbitrary
        (positions, subtypes) <- unzip <$> randomPossWithSubterms concrete
        let freshVars   = map VarT (getConsistentNames "mtyvar" concrete subtypes)
            generalized = replaceSubterms const concrete freshVars positions
        return (MatchingTypes (concrete, generalized))

-- | A pair of typed expressions produced by the 'QC.Arbitrary' instance
-- for this wrapper: both are derived from the same generated expression,
-- the first with wildcards substituted in, the second with variables.
newtype MatchingTExps = MatchingTExps
    { getMatchingTExps :: (TExp, TExp)
    } deriving Show

-- | Generates one random expression and derives two variants from it:
-- one where randomly chosen subterms become wildcards named @_@ and one
-- where (independently chosen) subterms become fresh variables.  Every
-- replacement is given the type of the subterm it replaces.
instance QC.Arbitrary MatchingTExps where
    arbitrary = do
        concrete <- QC.arbitrary
        (varPoss, varSubterms) <- unzip <$> randomPossWithSubterms concrete
        (wildPoss, _)          <- unzip <$> randomPossWithSubterms concrete
        let freshVars = map TVarE (getConsistentNames "mtevar" concrete varSubterms)
            -- Build the replacement from the type of the replaced subterm.
            retyped mk = mk . typeOf
            wildcard   = TWildE (mkName "_")
            withVars   = replaceSubterms retyped concrete freshVars varPoss
            withWilds  = replaceSubterms retyped concrete (repeat wildcard) wildPoss
        return (MatchingTExps (withWilds, withVars))

-- | Pick a random selection of mutually non-overlapping positions in the
-- term and return each together with the subterm found there.  The number
-- of positions kept is drawn with 'chooseLog'.
randomPossWithSubterms :: (Term t, Size t) => t -> QC.Gen [(Position, t)]
randomPossWithSubterms term = do
    candidates <- nonOverlappingPos term
    count <- chooseLog (succ (length candidates))
    return (map withSubterm (take count candidates))
  where
    -- Positions come from 'getAllPos', so the lookup is expected to succeed.
    withSubterm pos = (pos, fromJust (subtermAt term pos))

-- | Apply a sequence of replacements to a term, left to right.  The i-th
-- element of @substs@ is turned into an update function by the first
-- argument and applied at the i-th position of @poss@ via 'applyAtPos'.
-- Surplus elements of the longer list are ignored.
replaceSubterms mkUpdate term0 substs poss = foldl step term0 (zip substs poss)
  where
    step acc (subst, pos) = applyAtPos (mkUpdate subst) pos acc

-- | Assign a name to every element of @subTerms@ such that equal subterms
-- get equal names.  Names have the form @prefix ++ show i@ for
-- @i = 1, 2, ...@, skipping those already returned by 'getVarNames' for
-- @term@.  The result lists the names in the order of @subTerms@.
getConsistentNames :: (Ord t, Term t) => String -> t -> [t] -> [Name]
getConsistentNames prefix term subTerms =
    St.evalState (mapM assignName subTerms) (M.empty, supply)
  where
    taken  = S.fromList (getVarNames term)
    supply = filter (`S.notMember` taken) [ mkName (prefix ++ show i) | i <- [1..] ]

-- | Return the name associated with a term, allocating the next free name
-- on first sight.
--
-- The state holds the names handed out so far and the supply of unused
-- names.  The supply is only inspected when a new name is needed; an
-- exhausted supply is reported with an explicit error.  The state is taken
-- apart with a plain tuple pattern and a @case@, so no failable pattern
-- appears in the @do@ block and no 'MonadFail' instance is required for
-- the state monad.
assignName :: (Ord t, Term t) => t -> St.State (M.Map t Name, [Name]) Name
assignName term = do
    (usedNames, freeNames) <- St.get
    case (M.lookup term usedNames, freeNames) of
        (Just name, _)           -> return name
        (Nothing, nextName:rest) -> do
            St.put (M.insert term nextName usedNames, rest)
            return nextName
        (Nothing, [])            ->
            error "assignName: supply of fresh names exhausted"

-- | Draw @x@ uniformly from @[0, log (n + 1)]@ and return
-- @round (exp x) - 1@, which favours small results over large ones.
chooseLog :: Integral a => a -> QC.Gen a
chooseLog n = do
    x <- QC.choose (0.0 :: Float, upper)
    return (pred (round (exp x)))
  where
    upper = log (fromIntegral (succ n))

-- | All positions of a term: 'Root' first, followed by the positions of
-- each direct subterm prefixed with that subterm's index (counted from 0).
getAllPos :: Term t => t -> [Position]
getAllPos t = Root : concat childPositions
  where
    childPositions = snd (mapAccumL descend 0 (subterms t))
    -- Thread the child index through the traversal of the subterms.
    descend i sub = (i + 1, map (i °) (getAllPos sub))

-- | Shuffle all positions of the term and then drop every position that
-- overlaps (see 'overlap') with one kept earlier in the shuffled order.
nonOverlappingPos :: Term t => t -> QC.Gen [Position]
nonOverlappingPos term = nubBy overlap <$> shuffle (getAllPos term)

-- | Two positions overlap when one of them is a prefix of the other; in
-- particular 'Root' overlaps with every position.
overlap :: Position -> Position -> Bool
overlap left right = case (left, right) of
    (_, Root)            -> True
    (Root, _)            -> True
    (Dot i p, Dot j p')  -> i == j && overlap p p'

-- | Random permutation of a list: repeatedly pick one element uniformly
-- among the remaining ones and continue with the rest.
shuffle :: [a] -> QC.Gen [a]
shuffle [] = return []
shuffle xs = do
  (chosen, rest) <- QC.elements (picks xs)
  shuffled <- shuffle rest
  return (chosen : shuffled)
  where
    -- Every way of removing one element, paired with what is left over.
    picks [] = []
    picks (z:zs) = (z, zs) : [ (p, z : others) | (p, others) <- picks zs ]

-- | Apply a function to the second component of a pair, leaving the first
-- component untouched.
second :: (b -> b') -> (a, b) -> (a, b')
second f pair = case pair of
    (x, y) -> (x, f y)

-- | Compute the kind of a type built from the constructors listed in
-- 'simpleTypeKinds'.
--
-- Type variables are taken to have kind 'Star'; a quantified type has the
-- kind of its body.  Ill-kinded applications and constructors missing
-- from 'simpleTypeKinds' are reported with an error naming the offending
-- type instead of an anonymous pattern-match or 'fromJust' failure.
kindOf :: Type -> Kind
kindOf (VarT {})     = Star
kindOf (AppT t _)    = case kindOf t of
    _ `Arrow` r -> r
    Star        -> error ("kindOf: type of kind * applied to an argument: " ++ show t)
kindOf (ForallT _ t) = kindOf t
kindOf c@(ConT {})   =
    maybe (error ("kindOf: unknown type constructor: " ++ show c)) id
          (lookup c simpleTypeKinds)

-- | Random closed types of kind 'Star' (no visible type variables).
--
-- An application shrinks to its argument when that has the same kind as
-- the whole application, and to applications with a shrunk function or a
-- shrunk argument.  A quantified type shrinks to its body, a variable
-- shrinks through its name, and constructors do not shrink.
instance QC.Arbitrary Type where
  arbitrary = QC.sized (genType S.empty Star)

  shrink whole@(AppT fun arg) = keepArg ++ smallerFun ++ smallerArg
    where keepArg    = [ arg | kindOf whole == kindOf arg ]
          smallerFun = [ AppT fun' arg | fun' <- QC.shrink fun ]
          smallerArg = [ AppT fun arg' | arg' <- QC.shrink arg ]
  shrink (ForallT _ body)     = [body]
  shrink (VarT n)             = [ VarT n' | n' <- QC.shrink n ]
  shrink (ConT _)             = []

-- | One constructor expression for every entry of the default context's
-- 'ctx_types' map.
simpleTExps :: [TExp]
simpleTExps = [ TConE name ty | (name, ty) <- M.assocs (ctx_types defaultContext) ]

-- | Variable environment used by 'genTExp': maps the names of the
-- variables generated so far to their types.
type Env = M.Map Name Type

-- | Generate a typed expression together with the updated environment.
--
-- When the first argument is @Just t@, the head of the generated
-- expression is chosen among candidates of type @t@; with 'Nothing' any
-- type is allowed.  A head of function type is then applied to freshly
-- generated arguments (see @app@) until the result is no longer a
-- function.
--
-- The head is one of: a constructor from 'simpleTExps', a variable already
-- bound in the environment, or a fresh variable or wildcard.  Variables
-- taken from the environment are filtered by the requested type, exactly
-- like the constructors, so that an argument generated for a function
-- always has the parameter type asked for.
--
-- The size argument is decremented for every generated argument but is
-- not otherwise inspected by this function.
genTExp :: Maybe Type -> Env -> Int -> QC.Gen (TExp, Env)
genTExp t env size = QC.oneof gens >>= uncurry app
  where andenv e = (e, env)
        -- Does a candidate type satisfy the request?
        hasWantedType ty = maybe True (== ty) t
        simple = map andenv $ filter (hasWantedType . typeOf) simpleTExps
        conssimple = if null simple then id else (QC.elements simple:)
        -- Fresh variable or wildcard; retries until the name is unused.
        var = do t' <- maybe QC.arbitrary return t
                 n <- QC.arbitrary
                 if n `M.member` env then var
                                   else QC.elements [(TVarE n t', M.insert n t' env),
                                                     andenv (TWildE n t')]
        -- Previously bound variables of a suitable type.
        bound = [ andenv (TVarE n ty) | (n, ty) <- M.assocs env, hasWantedType ty ]
        consenv = if null bound then id else (QC.elements bound:)
        gens = conssimple $ consenv [var]
        -- Saturate a function-typed expression with generated arguments,
        -- threading the environment through the argument generators.
        app e env' = if not . isFunT $ typeOf e then return (e, env') else (do
            let argt = head . unArrowT $ typeOf e
            (e', env'') <- genTExp (Just argt) env' (pred size)
            app (tAppE e e') env'')

-- | Random typed expressions generated by 'genTExp' without a type
-- request and with an empty environment.
--
-- An application shrinks to everything 'unfoldTAppE' returns except its
-- first element; a wildcard shrinks through its name; nothing else shrinks.
instance QC.Arbitrary TExp where
  arbitrary = fst <$> QC.sized (genTExp Nothing M.empty)

  shrink expr = case expr of
    TAppE {}   -> tail (unfoldTAppE expr)
    TWildE n t -> map (\n' -> TWildE n' t) (QC.shrink n)
    _          -> []

-- | A pair of typed expressions where the second one was generated with
-- the type of the first as the requested type.
newtype UniformTExps = UniformTExps
    { getUniformTExps :: (TExp, TExp)
    } deriving Show

-- | Generates an arbitrary expression and then a second one by asking
-- 'genTExp' for the type of the first, starting from an empty environment.
instance QC.Arbitrary UniformTExps where
  arbitrary = do
      t1 <- QC.arbitrary
      t2 <- fst <$> QC.sized (genTExp (Just (typeOf t1)) M.empty)
      return (UniformTExps (t1, t2))

-- | Turn a binary predicate into a property that, on failure, prints both
-- operands (labelled @Left@ and @Right@) alongside the counterexample.
qcCompare :: (Show a) => (a -> a -> Bool) -> a -> a -> QC.Property
qcCompare eq x y = annotate "Left: " x (annotate "Right: " y (eq x y))
  where
    annotate label value = QC.printTestCase (label ++ show value)

-- | Run a reader computation in the default context.
runContext action = runReader action defaultContext

-- | Type-level witness selecting the 'Type' instantiation of a
-- polymorphic property.  The value is 'undefined' and must never be
-- evaluated; only its type is used.
typeWitness :: Type
typeWitness = undefined
-- | Type-level witness selecting the 'TExp' instantiation of a
-- polymorphic property.  The value is 'undefined' and must never be
-- evaluated; only its type is used.
texpWitness :: TExp
texpWitness = undefined

-- | Property: @lgg [t1, lgg [t2, t3]]@ is 'equal' to @lgg [t1, t2, t3]@.
--
-- The first argument is only a witness fixing the type @t@; it is never
-- evaluated.  If any of the 'lgg' computations fails, the test aborts with
-- an error that includes the reported message, matching the way
-- 'lggMatches' reports a failed match.
lggAssociative :: (AUnify t, Pretty t, Term t) => t -> t -> t -> t -> QC.Property
lggAssociative _witness t1 t2 t3 = runContext $ (do
    t23   <- lgg [t2, t3]
    t123  <- lgg [t1, t23]
    t123' <- lgg [t1, t2, t3]
    return $ qcCompare equal t123 t123'
    ) `safeCatchErrorC` (error . ("lgg failed: " ++))

-- | Match @t1@ against @t2@ and compare @t1@ with the result of applying
-- the obtained substitution to @t2@.  The comparison uses 'equal' and is
-- wrapped by 'qcCompare' so both sides are printed on failure.
--
-- The first argument is only a witness fixing the type @t@.  A failing
-- 'match' is propagated as an error of the underlying monad.
matchApplyC :: (Term t, Unify t) => t -> t -> t -> C (Either String) QCP.Property
matchApplyC _witness t1 t2 = do
    s <- match t1 t2
    return $ qcCompare equal t1 (apply s t2)

-- | Forgiving variant of 'matchApplyC': if matching fails, the test case
-- is discarded instead of counted as a failure.
matchApply :: (Term t, Unify t) => t -> t -> t -> QC.Property
matchApply _witness t1 t2 = runContext (safeCatchErrorC attempt giveUp)
  where
    attempt  = matchApplyC _witness t1 t2
    giveUp _ = QC.discard

-- | Strict variant of 'matchApplyC': a failing match is turned into a
-- call to 'error' carrying the reported message, so it is not silently
-- discarded.
matchApplyUnforgiving :: (Term t, Unify t) => t -> t -> t -> QC.Property
matchApplyUnforgiving _witness t1 t2 =
    runContext (mapReaderT (either error return) (matchApplyC _witness t1 t2))

-- | Compare two 'Either' values: two 'Left's with the first predicate,
-- two 'Right's with the second; values on different sides are never equal.
compareEither :: (a -> a -> Bool) -> (b -> b -> Bool) -> Either a b -> Either a b -> Bool
compareEither onLeft onRight lhs rhs = case (lhs, rhs) of
    (Right x, Right y) -> onRight x y
    (Left x,  Left y)  -> onLeft x y
    _                  -> False

-- | Property: @lgg [t1, t2]@ and @lgg [t2, t1]@ have the same outcome.
-- Two successful results must be 'equal'; two failures count as the same
-- outcome regardless of their messages; a success never equals a failure.
--
-- The first argument is only a witness fixing the type @t@.
lggSymmetric :: (AUnify t, Pretty t, Term t) => t -> t -> t -> QC.Property
lggSymmetric _witness t1 t2 =
    qcCompare sameOutcome (runLgg [t1, t2]) (runLgg [t2, t1])
  where
    runLgg ts = runReaderT (lgg ts) defaultContext
    sameOutcome = compareEither anyMessage equal
    -- Error messages are not compared.
    anyMessage :: String -> String -> Bool
    anyMessage _ _ = True

-- | Property: both inputs are instances of their generalisation.  With
-- @t12 = lgg [t1, t2]@, matching each @ti@ against @t12@ and applying the
-- resulting substitution to @t12@ must give a term 'equal' to @ti@.
--
-- Test cases for which 'lgg' itself fails are discarded; a failing match
-- aborts the test with an error.  The first argument is only a witness
-- fixing the type @t@.
lggMatches :: (AUnify t, Pretty t, Term t, Unify t) => t -> t -> t -> QC.Property
lggMatches _witness t1 t2 = maybe QC.discard bothMatch generalisation
  where
    generalisation =
        runContext (safeCatchErrorC (fmap Just (lgg [t1, t2]))
                                    (const (return Nothing)))
    bothMatch t12 = runContext (safeCatchErrorC (check t12) matchFailed)
    check t12 = do
        s1 <- match t1 t12
        let t1' = apply s1 t12
        s2 <- match t2 t12
        let t2' = apply s2 t12
        return (QC.property (t1' `equal` t1 && t2' `equal` t2))
    matchFailed msg = error ("match of lgg failed: " ++ msg)

-- | All properties of this suite.  Each polymorphic property is
-- instantiated once for 'Type' and once for 'TExp' by passing the
-- corresponding witness; properties on expressions that need equally
-- typed or matching inputs take them from the wrapper newtypes.
tests = [
    testProperty "Types: lgg [a, lgg [b, c]] == lgg [a, b, c]" (lggAssociative typeWitness),
    testProperty "Exprs: lgg [a, lgg [b, c]] == lgg [a, b, c]" (lggAssociative texpWitness),
    testProperty "Types: lgg [a, b] == lgg [b, a]" (lggSymmetric typeWitness),
    testProperty "Exprs: lgg [a, b] == lgg [b, a]" (uncurry (lggSymmetric texpWitness) . getUniformTExps),
    testProperty "Types: apply (match a b) b == a" (matchApply typeWitness),
    testProperty "Exprs: apply (match a b) b == a" (matchApply texpWitness),
    testProperty "MatchingTypes: apply (match a b) b == a" (uncurry (matchApplyUnforgiving typeWitness) . getMatchingTypes),
    testProperty "MatchingTExps: apply (match a b) b == a" (uncurry (matchApplyUnforgiving texpWitness) . getMatchingTExps),
    testProperty "Types: lgg matches" (lggMatches typeWitness),
    -- Uses the named witness like every other entry instead of a bare 'undefined'.
    testProperty "Exprs: lgg matches" (uncurry (lggMatches texpWitness) . getUniformTExps)
    ]

-- | Entry point: run every property in 'tests' with test-framework.
main :: IO ()
main = defaultMain tests
