{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Utils.BranchAndBound
    (bnb, NodeValue(..), Bag(), NodeOps(..), PriorityQueue, empty, insert)
    where

import qualified Data.Map as Map
import Data.Map (Map())
import Control.Monad.State as State
import Data.Functor ((<$>))

{- Nodes -}

-- | How 'bnb' treats a node of the search tree.
data NodeValue v
    = LowerBound v
      -- ^ A partial node. 'bnb' discards it when @v@ exceeds the best solution
      -- value found so far, and otherwise expands its children.
    | Solution v
      -- ^ A complete solution with value @v@. Its children are never expanded.

-- | The numeric payload of a 'NodeValue', regardless of whether it is a
-- bound or a solution value.
nodeValue :: NodeValue v -> v
nodeValue nv = case nv of
    LowerBound bound -> bound
    Solution value   -> value

-- | The operations 'bnb' needs on search-tree nodes of type @n@, with
-- priority keys of type @k@ and values of type @v@.
data NodeOps k v n = NodeOps
    { keyOf    :: n -> k
      -- ^ Key under which the node is stored in the priority queue; it only
      -- controls the order in which nodes are visited.
    , valueOf  :: n -> NodeValue v
      -- ^ Whether the node is a 'Solution' or a partial node with a 'LowerBound'.
    , children :: n -> [n]
      -- ^ Nodes obtained by branching on this node.
    }

{- PriorityQueue -}

-- | A priority queue of elements @n@ ordered by keys @k@, as consumed by 'bnb'.
class Ord k => PriorityQueue q k n where
    -- | Pop the next element together with the remaining queue, or 'Nothing'
    -- when the queue is empty. By convention the element has the smallest key.
    extractMin :: q k n -> Maybe (n, q k n)
    -- | Add one element under the given key.
    insert :: (k, n) -> q k n -> q k n
    -- | Add several elements. The default folds from the right, so the last
    -- list entry is inserted first and the first entry is inserted last.
    insertList :: [(k, n)] -> q k n -> q k n
    insertList entries queue0 = foldr insert queue0 entries

{- Bag, a PriorityQueue implementation -}

-- | A 'PriorityQueue' backed by a 'Map' from keys to buckets of elements.
-- The module's operations keep every bucket non-empty: a key is deleted once
-- its last element has been extracted.
newtype Bag k v = Bag { unBag :: Map k [v] }

-- | A 'Bag' that contains no elements.
empty :: Bag k v
empty = Bag Map.empty

-- | Insert an element under key @k@, placing it in front of any elements
-- already stored under that key. Since 'extractMin' for 'Bag' takes the head
-- of the smallest bucket, elements with equal keys come out last-in first-out.
insertFront :: Ord k => (k, n) -> Bag k n -> Bag k n
-- 'Map.insertWith' calls the combiner as @f new old@, so @(++)@ yields
-- @[x] ++ old == x : old@ without the partial @\\[new] old -> ...@ pattern.
insertFront (k, x) (Bag m) = Bag (Map.insertWith (++) k [x] m)

instance Ord k => PriorityQueue Bag k n where
    -- Pop the front element of the bucket with the smallest key. The bucket is
    -- deleted when its last element is taken, which keeps buckets non-empty.
    extractMin (Bag m) = case Map.minViewWithKey m of
        Nothing -> Nothing
        Just ((k, bucket), rest) -> case bucket of
            -- Should not occur (buckets are kept non-empty); skip it rather
            -- than crash on an incomplete pattern.
            []     -> extractMin (Bag rest)
            [x]    -> Just (x, Bag rest)
            x : xs -> Just (x, Bag (Map.insert k xs rest))

    insert = insertFront

{- Branch and bound implementation -}

-- | The state threaded through the search loop of 'bnb'.
data BnbState q k v n = BnbState
    { queue      :: q k n
      -- ^ Nodes that have been generated but not yet visited.
    , upperBound :: v
      -- ^ Value of the best solution found so far ('maxBound' before any).
    , solutions  :: [n]
      -- ^ Solutions found so far whose value equals 'upperBound'.
    }

-- | Transform the pending-node queue of a search state.
updateQueue :: (q k n -> q k n) -> BnbState q k v n -> BnbState q k v n
updateQueue f st =
    let newQueue = f (queue st)
    in st { queue = newQueue }

-- | Replace the pending-node queue of a search state.
setQueue :: q k n -> BnbState q k v n -> BnbState q k v n
setQueue newQueue st = st { queue = newQueue }

-- | Transform the current upper bound of a search state.
updateUpperBound :: (v -> v) -> BnbState q k v n -> BnbState q k v n
updateUpperBound f st@BnbState { upperBound = bound } = st { upperBound = f bound }

-- | Transform the list of best solutions collected in a search state.
updateSolutions :: ([n] -> [n]) -> BnbState q k v n -> BnbState q k v n
updateSolutions f st = st { solutions = f current }
  where
    current = solutions st

{-|
    'bnb' is a branch-and-bound search that returns all minimal solutions.

    The first argument 'nodeOps' provides the necessary operations on nodes.
    The key ('keyOf') controls the order in which nodes are searched; the order
    of nodes with equal keys depends on the queue. The value ('valueOf') is
    used to prune the search: a node whose value exceeds the best solution
    value found so far is discarded without expanding its children. Only
    solutions with the minimum value are returned, most recently found first.

    The second argument 'initialQueue' is the queue the search starts from.
    Presumably it holds the root node, e.g. @insert (keyOf nodeOps root, root) empty@.

    The search runs until the queue is empty. Every popped node is checked
    against the current bound individually, so the key does not need to be
    consistent with the value. If no solution is found the result is @[]@.
 -}
bnb :: forall q k v n. (Ord k, Ord v, Bounded v, PriorityQueue q k n) => NodeOps k v n -> (q k n) -> [n]
bnb nodeOps initialQueue = solutions (execState loop initialState)
  where
    withKey :: n -> (k, n)
    withKey x = (keyOf nodeOps x, x)

    initialState :: BnbState q k v n
    initialState = BnbState
        { queue = initialQueue
        -- No solution has been seen yet, so nothing may be pruned.
        , upperBound = maxBound
        , solutions = []
        }

    -- Pop the next node from the queue, or Nothing once it is empty.
    popNode :: State (BnbState q k v n) (Maybe n)
    popNode = do
        q <- State.gets queue
        case extractMin q of
            Nothing -> return Nothing
            Just (x, q') -> do
                State.modify (setQueue q')
                return (Just x)

    -- Process one popped node: record it if it is a solution at least as good
    -- as the best so far, expand it if it is a lower bound not above the
    -- best so far, and otherwise drop it.
    visit :: n -> State (BnbState q k v n) ()
    visit x = do
        ub <- State.gets upperBound
        case valueOf nodeOps x of
            Solution v -> case compare v ub of
                -- Strictly better: it becomes the only best solution.
                LT -> State.modify (\s -> s { upperBound = v, solutions = [x] })
                EQ -> State.modify (updateSolutions (x :))
                GT -> return ()
            LowerBound v
                | v > ub    -> return ()
                | otherwise -> State.modify (updateQueue (insertList (map withKey (children nodeOps x))))

    -- Drain the queue. Stopping early based on the first queued node would be
    -- wrong in general: the queue is ordered by key, not by value, so nodes
    -- further back may still have smaller bounds.
    loop :: State (BnbState q k v n) ()
    loop = do
        mx <- popNode
        case mx of
            Nothing -> return ()
            Just x  -> visit x >> loop
