{-# LANGUAGE TemplateHaskellQuotes #-}
module Sasha.Internal.Word8Set (
    memberCode,
) where

import Data.WideWord.Word256 (Word256 (..))
import Data.Word             (Word64, Word8)
import Data.Word8Set         (Word8Set)

import qualified Data.Bits     as Bits
import qualified Data.Word8Set as W8S

import Language.Haskell.TH.Syntax


-- | Optimized routine to check membership when the 'Word8Set' is statically known.
--
-- Semantically this is
--
-- @
-- 'memberCode' c ws = [|| 'W8S.member' $$c $$(liftTyped ws) ||]
-- @
--
-- but the generated code is specialised to the shape of @ws@:
--
--     * empty set: the constant 'False';
--
--     * full set: the constant 'True';
--
--     * one or two elements: direct equality comparisons;
--
--     * a single contiguous range @[l, r]@: two bound comparisons;
--
--     * only the lowest 64-bit word of 'W8S.toWord256' is non-zero: a bound
--       check plus a single 'Bits.testBit' on that word (this relies on
--       member @i@ being stored as bit @i@ of that word);
--
--     * otherwise: a call to 'W8S.member' with the lifted set.
--
-- When the scrutinee @c@ is needed more than once, the generated code binds
-- it with @let@ instead of splicing @c@ repeatedly, so the caller's
-- expression is not duplicated in the output.
memberCode :: Code Q Word8 -> Word8Set -> Code Q Bool
memberCode c ws
    -- simple cases
    | W8S.null ws
    = [|| False ||]

    | W8S.isFull ws
    = [|| True ||]

    | W8S.size ws == 1
    = [|| $$c == $$(liftTyped (W8S.findMin ws)) ||]

    | W8S.size ws == 2
    = [|| let w = $$c
          in w == $$(liftTyped (W8S.findMin ws)) || w == $$(liftTyped (W8S.findMax ws)) ||]

    -- contiguous range
    | Just (l, r) <- W8S.isRange ws
    = [|| let w = $$c
          in $$(liftTyped l) <= w && w <= $$(liftTyped r) ||]

    -- low chars: only the lowest Word64 of the 256-bit bitmap is non-zero,
    -- so membership reduces to a bound check plus one bit test on it
    | Word256 0 0 0 x <- W8S.toWord256 ws
    = [|| let w = $$c
          in w < 64 && Bits.testBit ($$(liftTyped x) :: Word64) (fromIntegral (w :: Word8)) ||]

    -- fallback
    | otherwise
    = [|| W8S.member $$c $$(liftTyped ws) ||]