module ProjectM36.AtomFunctions.Primitive where
import ProjectM36.Base
import ProjectM36.Relation (relFold)
import ProjectM36.Tuple
import ProjectM36.AtomFunctionError
import ProjectM36.AtomFunction
import qualified Data.HashSet as HS
import qualified Data.Vector as V
import qualified Data.ByteString.Base64 as B64
import qualified Data.Text.Encoding as TE

-- The set of atom functions defined directly in this module: integer
-- arithmetic and comparison, boolean negation, folds over a relation-valued
-- atom, and Base64 decoding of text into bytes.
--
-- NOTE(review): most bodies destructure their argument list with a partial
-- lambda pattern, so a call with unexpected atoms fails at runtime with a
-- pattern-match error rather than a 'Left' -- presumably the caller
-- type-checks arguments against 'atomFuncType' first; confirm.
primitiveAtomFunctions :: AtomFunctions
primitiveAtomFunctions = HS.fromList
  [ -- integer addition of the first two arguments
    AtomFunction
      { atomFuncName = "add"
      , atomFuncType = [IntAtomType, IntAtomType, IntAtomType]
      , atomFuncBody = body $ \(IntAtom lhs : IntAtom rhs : _) -> Right (IntAtom (lhs + rhs))
      }
    -- identity: hands back its first argument unchanged
  , AtomFunction
      { atomFuncName = "id"
      , atomFuncType = [TypeVariableType "a", TypeVariableType "a"]
      , atomFuncBody = body $ \(firstArg : _) -> Right firstArg
      }
    -- folds over a relation-valued atom; see relationSum
  , AtomFunction
      { atomFuncName = "sum"
      , atomFuncType = foldAtomFuncType IntAtomType IntAtomType
      , atomFuncBody = body $ \(RelationAtom relArg : _) -> relationSum relArg
      }
    -- tuple count; the fold's input type is a type variable, not IntAtomType
  , AtomFunction
      { atomFuncName = "count"
      , atomFuncType = foldAtomFuncType (TypeVariableType "a") IntAtomType
      , atomFuncBody = body $ \(RelationAtom relArg : _) -> relationCount relArg
      }
    -- see relationMax
  , AtomFunction
      { atomFuncName = "max"
      , atomFuncType = foldAtomFuncType IntAtomType IntAtomType
      , atomFuncBody = body $ \(RelationAtom relArg : _) -> relationMax relArg
      }
    -- see relationMin
  , AtomFunction
      { atomFuncName = "min"
      , atomFuncType = foldAtomFuncType IntAtomType IntAtomType
      , atomFuncBody = body $ \(RelationAtom relArg : _) -> relationMin relArg
      }
    -- strict less-than
  , AtomFunction
      { atomFuncName = "lt"
      , atomFuncType = [IntAtomType, IntAtomType, BoolAtomType]
      , atomFuncBody = body (intAtomFuncLessThan False)
      }
    -- less-than-or-equal
  , AtomFunction
      { atomFuncName = "lte"
      , atomFuncType = [IntAtomType, IntAtomType, BoolAtomType]
      , atomFuncBody = body (intAtomFuncLessThan True)
      }
    -- a >= b is computed as the negation of a < b
  , AtomFunction
      { atomFuncName = "gte"
      , atomFuncType = [IntAtomType, IntAtomType, BoolAtomType]
      , atomFuncBody = body $ \atoms -> boolAtomNot =<< intAtomFuncLessThan False atoms
      }
    -- a > b is computed as the negation of a <= b
  , AtomFunction
      { atomFuncName = "gt"
      , atomFuncType = [IntAtomType, IntAtomType, BoolAtomType]
      , atomFuncBody = body $ \atoms -> boolAtomNot =<< intAtomFuncLessThan True atoms
      }
    -- boolean negation of the first argument
  , AtomFunction
      { atomFuncName = "not"
      , atomFuncType = [BoolAtomType, BoolAtomType]
      , atomFuncBody = body $ \(flag : _) -> boolAtomNot flag
      }
    -- UTF-8 encodes the text, then Base64-decodes it; a decoding failure is
    -- reported as AtomFunctionBytesDecodingError carrying the decoder's message
  , AtomFunction
      { atomFuncName = "makeByteString"
      , atomFuncType = [TextAtomType, ByteStringAtomType]
      , atomFuncBody = body $ \(TextAtom encoded : _) ->
          either
            (Left . AtomFunctionBytesDecodingError)
            (Right . ByteStringAtom)
            (B64.decode (TE.encodeUtf8 encoded))
      }
  ]
  where
    -- every primitive wraps its implementation with 'AtomFunctionBody Nothing'
    body = AtomFunctionBody Nothing
                         
-- Compare the first two atoms of the argument list as Ints.
--
-- When the flag is True the comparison is (<=), otherwise it is (<). The
-- result is wrapped in a BoolAtom.
--
-- NOTE(review): any argument list that does not start with two IntAtoms
-- yields BoolAtom False rather than an error; callers that negate the result
-- (such as "gte" and "gt") would then see True. Confirm this is intended.
intAtomFuncLessThan :: Bool -> [Atom] -> Either AtomFunctionError Atom
intAtomFuncLessThan equality atoms =
  case atoms of
    IntAtom lhs : IntAtom rhs : _ ->
      Right (BoolAtom (if equality then lhs <= rhs else lhs < rhs))
    _ ->
      Right (BoolAtom False)

-- Logical negation of a BoolAtom.
--
-- Any other atom aborts with a call to 'error' instead of returning a 'Left',
-- so callers must only pass boolean atoms.
boolAtomNot :: Atom -> Either AtomFunctionError Atom
boolAtomNot atom =
  case atom of
    BoolAtom flag -> Right (BoolAtom (not flag))
    _             -> error "boolAtomNot called on non-Bool atom"

-- Implementation of the "sum" atom function: adds up the first atom of every
-- tuple in the relation, starting from 0, and returns the total as an IntAtom.
--
-- Each first atom is converted with castInt, which calls 'error' on anything
-- other than an IntAtom; indexing with (V.! 0) likewise fails on a tuple with
-- no atoms.
relationSum :: Relation -> Either AtomFunctionError Atom
relationSum relIn = Right (IntAtom total)
  where
    total :: Int
    total = relFold addFirstAtom 0 relIn

    -- add the Int held in the tuple's first atom to the running total
    addFirstAtom :: RelationTuple -> Int -> Int
    addFirstAtom tup runningTotal = runningTotal + castInt (tupleAtoms tup V.! 0)
    
-- Implementation of the "count" atom function: folds over the relation adding
-- one per tuple, ignoring the tuple contents, and returns the tally as an
-- IntAtom (0 when the fold visits no tuples).
relationCount :: Relation -> Either AtomFunctionError Atom
relationCount relIn = Right (IntAtom tupleTotal)
  where
    tupleTotal :: Int
    tupleTotal = relFold (const (+ 1)) 0 relIn

-- Implementation of the "max" atom function: the largest Int found in the
-- first atom of any tuple, returned as an IntAtom.
--
-- The fold is seeded with minBound, so that is the result when the fold
-- visits no tuples. castInt calls 'error' on a non-IntAtom first atom.
relationMax :: Relation -> Either AtomFunctionError Atom
relationMax relIn = Right (IntAtom largest)
  where
    largest :: Int
    largest = relFold keepLarger minBound relIn

    -- keep whichever is bigger: the best so far or this tuple's first atom
    keepLarger :: RelationTuple -> Int -> Int
    keepLarger tup best = max best (castInt (tupleAtoms tup V.! 0))

-- Implementation of the "min" atom function: the smallest Int found in the
-- first atom of any tuple, returned as an IntAtom.
--
-- The fold is seeded with maxBound, so that is the result when the fold
-- visits no tuples. castInt calls 'error' on a non-IntAtom first atom.
relationMin :: Relation -> Either AtomFunctionError Atom
relationMin relIn = Right (IntAtom smallest)
  where
    smallest :: Int
    smallest = relFold keepSmaller maxBound relIn

    -- keep whichever is smaller: the best so far or this tuple's first atom
    keepSmaller :: RelationTuple -> Int -> Int
    keepSmaller tup best = min best (castInt (tupleAtoms tup V.! 0))

-- Unwrap the Int carried by an IntAtom.
--
-- Partial: any other atom aborts via 'error'.
castInt :: Atom -> Int
castInt atom =
  case atom of
    IntAtom n -> n
    _         -> error "attempted to cast non-IntAtom to Int"