-- |
-- Module:      Math.NumberTheory.Roots.Squares
-- Copyright:   (c) 2011 Daniel Fischer, 2016-2020 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Daniel Fischer <daniel.is.fischer@googlemail.com>
--
-- Functions dealing with squares. Efficient calculation of integer square roots
-- and efficient testing for squareness.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash    #-}

module Math.NumberTheory.Roots.Squares
    ( -- * Square root calculation
      integerSquareRoot
    , integerSquareRoot'
    , integerSquareRootRem
    , integerSquareRootRem'
    , exactSquareRoot
      -- * Tests for squares
    , isSquare
    , isSquare'
    , isPossibleSquare
    ) where

import Data.Bits (finiteBitSize, (.&.))
import GHC.Exts (Ptr(..))
import Numeric.Natural (Natural)

import Math.NumberTheory.Roots.Squares.Internal
import Math.NumberTheory.Utils.BitMask (indexBitSet)

-- | For a non-negative input \( n \)
--   calculate its integer square root \( \lfloor \sqrt{n} \rfloor \).
--   Throw an error on negative input.
--
-- >>> integerSquareRoot 99
-- 9
-- >>> integerSquareRoot 100
-- 10
-- >>> integerSquareRoot 101
-- 10
{-# SPECIALISE integerSquareRoot :: Int -> Int,
                                    Word -> Word,
                                    Integer -> Integer,
                                    Natural -> Natural
  #-}
integerSquareRoot :: Integral a => a -> a
integerSquareRoot :: a -> a
integerSquareRoot a
n
  | a
n a -> a -> Bool
forall a. Ord a => a -> a -> Bool
< a
0       = [Char] -> a
forall a. HasCallStack => [Char] -> a
error [Char]
"integerSquareRoot: negative argument"
  | Bool
otherwise   = a -> a
forall a. Integral a => a -> a
integerSquareRoot' a
n

-- | Calculate the integer square root of a non-negative number @n@,
--   that is, the largest integer @r@ with @r*r <= n@.
--   The precondition @n >= 0@ is not checked.
{-# RULES
"integerSquareRoot'/Int"     integerSquareRoot' = isqrtInt'
"integerSquareRoot'/Word"    integerSquareRoot' = isqrtWord
"integerSquareRoot'/Integer" integerSquareRoot' = isqrtInteger
  #-}
{-# INLINE [1] integerSquareRoot' #-}
integerSquareRoot' :: Integral a => a -> a
integerSquareRoot' :: a -> a
integerSquareRoot' = a -> a
forall a. Integral a => a -> a
isqrtA

-- | For a non-negative input \( n \)
--   calculate its integer square root \( r = \lfloor \sqrt{n} \rfloor \)
--   and remainder \( s = n - r^2 \).
--   Throw an error on negative input.
--
-- >>> integerSquareRootRem 99
-- (9,18)
-- >>> integerSquareRootRem 100
-- (10,0)
-- >>> integerSquareRootRem 101
-- (10,1)
{-# SPECIALISE integerSquareRootRem ::
        Int -> (Int, Int),
        Word -> (Word, Word),
        Integer -> (Integer, Integer),
        Natural -> (Natural, Natural)
  #-}
integerSquareRootRem :: Integral a => a -> (a, a)
integerSquareRootRem :: a -> (a, a)
integerSquareRootRem a
n
  | a
n a -> a -> Bool
forall a. Ord a => a -> a -> Bool
< a
0       = [Char] -> (a, a)
forall a. HasCallStack => [Char] -> a
error [Char]
"integerSquareRootRem: negative argument"
  | Bool
otherwise   = a -> (a, a)
forall a. Integral a => a -> (a, a)
integerSquareRootRem' a
n

-- | Calculate the integer square root of a non-negative number as well as
--   the difference of that number with the square of that root, that is if
--   @(s,r) = integerSquareRootRem' n@ then @s^2 <= n == s^2+r < (s+1)^2@.
--   The precondition @n >= 0@ is not checked.
{-# RULES
"integerSquareRootRem'/Integer" integerSquareRootRem' = karatsubaSqrt
  #-}
{-# INLINE [1] integerSquareRootRem' #-}
integerSquareRootRem' :: Integral a => a -> (a, a)
integerSquareRootRem' :: a -> (a, a)
integerSquareRootRem' a
n = (a
s, a
n a -> a -> a
forall a. Num a => a -> a -> a
- a
s a -> a -> a
forall a. Num a => a -> a -> a
* a
s)
  where
    s :: a
s = a -> a
forall a. Integral a => a -> a
integerSquareRoot' a
n

-- | Calculate the exact integer square root if it exists,
-- otherwise return 'Nothing'.
--
-- >>> map exactSquareRoot [-100, 99, 100, 101]
-- [Nothing,Nothing,Just 10,Nothing]
{-# SPECIALISE exactSquareRoot :: Int -> Maybe Int,
                                  Word -> Maybe Word,
                                  Integer -> Maybe Integer,
                                  Natural -> Maybe Natural
  #-}
exactSquareRoot :: Integral a => a -> Maybe a
exactSquareRoot :: a -> Maybe a
exactSquareRoot a
n
  | a
n a -> a -> Bool
forall a. Ord a => a -> a -> Bool
>= a
0
  , a -> Bool
forall a. Integral a => a -> Bool
isPossibleSquare a
n
  , (a
r, a
0) <- a -> (a, a)
forall a. Integral a => a -> (a, a)
integerSquareRootRem' a
n = a -> Maybe a
forall a. a -> Maybe a
Just a
r
  | Bool
otherwise                         = Maybe a
forall a. Maybe a
Nothing

-- | Test whether the argument is a perfect square.
--
-- >>> map isSquare [-100, 99, 100, 101]
-- [False,False,True,False]
{-# SPECIALISE isSquare :: Int -> Bool,
                           Word -> Bool,
                           Integer -> Bool,
                           Natural -> Bool
  #-}
isSquare :: Integral a => a -> Bool
isSquare :: a -> Bool
isSquare a
n = a
n a -> a -> Bool
forall a. Ord a => a -> a -> Bool
>= a
0 Bool -> Bool -> Bool
&& a -> Bool
forall a. Integral a => a -> Bool
isSquare' a
n

-- | Test whether the input (a non-negative number) @n@ is a square.
--   The same as 'isSquare', but without the negativity test.
--   Faster if many known positive numbers are tested.
--
--   The precondition @n >= 0@ is not tested, passing negative
--   arguments may cause any kind of havoc.
{-# SPECIALISE isSquare' :: Int -> Bool,
                            Word -> Bool,
                            Integer -> Bool,
                            Natural -> Bool
  #-}
isSquare' :: Integral a => a -> Bool
isSquare' :: a -> Bool
isSquare' a
n
    | a -> Bool
forall a. Integral a => a -> Bool
isPossibleSquare a
n
    , (a
_, a
0) <- a -> (a, a)
forall a. Integral a => a -> (a, a)
integerSquareRootRem' a
n = Bool
True
    | Bool
otherwise                         = Bool
False

-- | Test whether a non-negative number may be a square.
--   Non-negativity is not checked, passing negative arguments may
--   cause any kind of havoc.
--
--   First the remainder modulo 256 is checked (that can be calculated
--   easily without division and eliminates about 82% of all numbers).
--   After that, the remainders modulo 9, 25, 7, 11 and 13 are tested
--   to eliminate altogether about 99.436% of all numbers.
{-# SPECIALISE isPossibleSquare :: Int -> Bool,
                                   Word -> Bool,
                                   Integer -> Bool,
                                   Natural -> Bool
  #-}
isPossibleSquare :: Integral a => a -> Bool
isPossibleSquare :: a -> Bool
isPossibleSquare a
n'
  =  Ptr Word -> Int -> Bool
indexBitSet Ptr Word
mask256 (Integer -> Int
forall a. Num a => Integer -> a
fromInteger (Integer
n Integer -> Integer -> Integer
forall a. Bits a => a -> a -> a
.&. Integer
255))
  Bool -> Bool -> Bool
&& Ptr Word -> Int -> Bool
indexBitSet Ptr Word
mask693 (Integer -> Int
forall a. Num a => Integer -> a
fromInteger (Integer
n Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`rem` Integer
693))
  Bool -> Bool -> Bool
&& Ptr Word -> Int -> Bool
indexBitSet Ptr Word
mask325 (Integer -> Int
forall a. Num a => Integer -> a
fromInteger (Integer
n Integer -> Integer -> Integer
forall a. Integral a => a -> a -> a
`rem` Integer
325))
  where
    n :: Integer
n = a -> Integer
forall a. Integral a => a -> Integer
toInteger a
n'

-----------------------------------------------------------------------------
-- Generated by 'Math.NumberTheory.Utils.BitMask.vectorToAddrLiteral'

mask256 :: Ptr Word
mask256 :: Ptr Word
mask256 = Addr# -> Ptr Word
forall a. Addr# -> Ptr a
Ptr Addr#
"\DC3\STX\ETX\STX\DC2\STX\STX\STX\DC3\STX\STX\STX\DC2\STX\STX\STX\DC2\STX\ETX\STX\DC2\STX\STX\STX\DC2\STX\STX\STX\DC2\STX\STX\STX"#

mask693 :: Ptr Word
mask693 :: Ptr Word
mask693 = Addr# -> Ptr Word
forall a. Addr# -> Ptr a
Ptr Addr#
"\DC3\STXA\STX0\NUL\STX\EOTI\NUL\STX\t\CAN\NUL\NULB\164\NUL\DC1\EOT\b\STX\NUL@P\128@\NUL\STX\t\128 \SOH\DLE\NUL\SOH\130$\NUL\128\DC4(\NUL\NUL\SOH\DC2\NUL\f\STX\DC4\SOH\NUL \b\NUL\"\NUL\128\EOT`\144\NUL\b\129\NULE\DC2\DLE@\STX\EOT\NUL\129\NUL\t\b\EOT\SOH\194\128\NUL\DLE\EOT\NUL\DLE\NUL\NUL"#

mask325 :: Ptr Word
mask325 :: Ptr Word
mask325 = Addr# -> Ptr Word
forall a. Addr# -> Ptr a
Ptr Addr#
"\DC3B\SOH&\144\NUL\n!%\140\STXH0\SOH\DC4BJ\b\ENQ\144@\STX(\132\148\DLE\n \131\EOTP\f)!\DC4@\STX\EM\160\DLE\DC2"#

-- -- Make an array indicating whether a remainder is a square remainder.
-- sqRemArray :: Int -> V.Vector Bool
-- sqRemArray md = runST $ do
--   ar <- MV.replicate md False
--   let !stop = (md `quot` 2) + 1
--       fill k
--         | k < stop  = MV.unsafeWrite ar ((k*k) `rem` md) True >> fill (k+1)
--         | otherwise = return ()
--   MV.unsafeWrite ar 0 True
--   MV.unsafeWrite ar 1 True
--   fill 2
--   V.unsafeFreeze ar

-- sr256 :: V.Vector Bool
-- sr256 = sqRemArray 256

-- sr693 :: V.Vector Bool
-- sr693 = sqRemArray 693

-- sr325 :: V.Vector Bool
-- sr325 = sqRemArray 325

-----------------------------------------------------------------------------
-- Specialisations for Int, Word, and Integer

-- For @n <= 2^64@, the result of
--
-- > truncate (sqrt $ fromIntegral n)
--
-- is never too small and never more than one too large.
-- The multiplication doesn't overflow for 32 or 64 bit Ints.
isqrtInt' :: Int -> Int
isqrtInt' :: Int -> Int
isqrtInt' Int
n
    | Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
rInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
r   = Int
rInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1
    | Bool
otherwise = Int
r
      where
        !r :: Int
r = (Double -> Int
forall a b. (RealFrac a, Integral b) => a -> b
truncate :: Double -> Int) (Double -> Int) -> (Double -> Double) -> Double -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double -> Double
forall a. Floating a => a -> a
sqrt (Double -> Int) -> Double -> Int
forall a b. (a -> b) -> a -> b
$ Int -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
n
-- With -O2, that should be translated to the below
{-
isqrtInt' n@(I# i#)
    | r# *# r# ># i#            = I# (r# -# 1#)
    | otherwise                 = I# r#
      where
        !r# = double2Int# (sqrtDouble# (int2Double# i#))
-}

-- Same for Word.
isqrtWord :: Word -> Word
isqrtWord :: Word -> Word
isqrtWord Word
n
    | Word
n Word -> Word -> Bool
forall a. Ord a => a -> a -> Bool
< (Word
rWord -> Word -> Word
forall a. Num a => a -> a -> a
*Word
r)
      -- Double interprets values near maxBound as 2^64, we don't have that problem for 32 bits
      Bool -> Bool -> Bool
|| Word -> Int
forall b. FiniteBits b => b -> Int
finiteBitSize (Word
0 :: Word) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
64 Bool -> Bool -> Bool
&& Word
r Word -> Word -> Bool
forall a. Eq a => a -> a -> Bool
== Word
4294967296
                = Word
rWord -> Word -> Word
forall a. Num a => a -> a -> a
-Word
1
    | Bool
otherwise = Word
r
      where
        !r :: Word
r = (Int -> Word
forall a b. (Integral a, Num b) => a -> b
fromIntegral :: Int -> Word) (Int -> Word) -> (Double -> Int) -> Double -> Word
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Double -> Int
forall a b. (RealFrac a, Integral b) => a -> b
truncate :: Double -> Int) (Double -> Int) -> (Double -> Double) -> Double -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double -> Double
forall a. Floating a => a -> a
sqrt (Double -> Word) -> Double -> Word
forall a b. (a -> b) -> a -> b
$ Word -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral Word
n

{-# INLINE isqrtInteger #-}
isqrtInteger :: Integer -> Integer
isqrtInteger :: Integer -> Integer
isqrtInteger = (Integer, Integer) -> Integer
forall a b. (a, b) -> a
fst ((Integer, Integer) -> Integer)
-> (Integer -> (Integer, Integer)) -> Integer -> Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> (Integer, Integer)
karatsubaSqrt