{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Core.Data.MemoUtils
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.MemoUtils
  ( -- * Hashtable-based memoization
    htmemo,
    htmemo2,
    htmemo3,
    htmup,
    htmemoFix,
  )
where

import Data.Function (fix)
import Data.HashTable.IO as H
import Data.Hashable
import System.IO.Unsafe

-- Hash table type backing the memoization caches in this module: an alias
-- for the 'H.BasicHashTable' exported by "Data.HashTable.IO".
type HashTable k v = H.BasicHashTable k v

-- | Function memoizer with mutable hash table.
--
-- @htmemo f@ allocates a fresh hash table and returns a function that looks
-- its argument up in that table. On a hit the stored result is returned; on a
-- miss @f@ is applied to the argument, the result is stored under that
-- argument, and then returned.
--
-- Both the allocation of the table and every lookup/insert are hidden behind
-- 'unsafePerformIO' so that the memoized function keeps a pure type. The table
-- is captured by the returned closure, so each evaluation of @htmemo f@ gets
-- its own cache.
--
-- NOTE(review): the table is read and written here without any locking;
-- confirm that a memoized function is never shared between threads, or that
-- the underlying table tolerates concurrent access.
htmemo :: (Eq k, Hashable k) => (k -> a) -> k -> a
htmemo f = unsafePerformIO $ do
  -- One cache per memoized function, created when @htmemo f@ is evaluated.
  cache <- H.new :: IO (HashTable k v)
  return $ \x -> unsafePerformIO $ do
    tryV <- H.lookup cache x
    case tryV of
      Nothing -> do
        -- Cache miss: bind the result of @f x@, remember it, and return it.
        let v = f x
        H.insert cache x v
        return v
      -- Cache hit: reuse the stored result.
      Just v -> return v

-- | Lift a memoizer to work with one more argument.
--
-- @mem@ is a transformation applied to the result of @f@ (typically a
-- memoizer for the remaining arguments). @htmup mem f@ memoizes, with
-- 'htmemo', the function that applies @f@ to the first argument and then
-- applies @mem@ to the result, so @mem (f k)@ is cached per key @k@.
htmup :: (Eq k, Hashable k) => (b -> c) -> (k -> b) -> (k -> c)
htmup mem f = htmemo (mem . f)

-- | Function memoizer with mutable hash table. Works on binary functions.
--
-- Defined as 'htmup' applied to 'htmemo': the function is memoized on its
-- first argument, and each resulting unary function is itself memoized on the
-- second argument with 'htmemo'.
htmemo2 :: (Eq k1, Hashable k1, Eq k2, Hashable k2) => (k1 -> k2 -> a) -> (k1 -> k2 -> a)
htmemo2 = htmup htmemo

-- | Function memoizer with mutable hash table. Works on ternary functions.
--
-- Defined as 'htmup' applied to 'htmemo2': the function is memoized on its
-- first argument, and each resulting binary function is memoized with
-- 'htmemo2'.
htmemo3 ::
  (Eq k1, Hashable k1, Eq k2, Hashable k2, Eq k3, Hashable k3) =>
  (k1 -> k2 -> k3 -> a) ->
  (k1 -> k2 -> k3 -> a)
htmemo3 = htmup htmemo2

-- | Memoizing recursion. Use like 'fix'.
--
-- @h@ is an open-recursive function: its first argument is the function to
-- use for recursive calls. @htmemoFix h@ takes the fixed point of
-- @htmemo . h@, so the function handed back to @h@ for its recursive calls is
-- the 'htmemo'-memoized one.
htmemoFix :: (Eq k, Hashable k) => ((k -> a) -> (k -> a)) -> k -> a
htmemoFix h = fix (htmemo . h)