{-|
Copyright  :  (C) 2016, University of Twente
                  2022-2024, QBayLogic B.V.
License    :  BSD2 (see the file LICENSE)
Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-}

module Clash.Sized.RTree
  ( -- * 'RTree' data type
    RTree (LR, BR, RLeaf, RBranch)
    -- * Construction
  , treplicate
  , trepeat
    -- * Accessors
  , thead
  , tlast
    -- ** Indexing
  , indexTree
  , tindices
    -- * Modifying trees
  , replaceTree
    -- * Element-wise operations
    -- ** Mapping
  , tmap
  , tzipWith
    -- ** Zipping
  , tzip
    -- ** Unzipping
  , tunzip
    -- * Folding
  , tfold
    -- ** Specialised folds
  , tdfold
    -- ** Prefix sums (scans)
    -- $scans
  , scanlPar
  , tscanl
  , scanrPar
  , tscanr
    -- * Conversions
  , v2t
  , t2v
    -- * Misc
  , lazyT
  )
where

#if !MIN_VERSION_base(4,18,0)
import Control.Applicative         (liftA2)
#endif
import Control.DeepSeq             (NFData(..))
import qualified Control.Lens      as Lens
import Data.Default.Class          (Default (..))
import Data.Either                 (isLeft)
import Data.Foldable               (toList)
import Data.Kind                   (Type)
import Data.Singletons             (Apply, TyFun, type (@@))
import Data.Proxy                  (Proxy (..))
import GHC.TypeLits                (KnownNat, Nat, type (+), type (^), type (*))
import Language.Haskell.TH.Syntax  (Lift(..))
#if MIN_VERSION_template_haskell(2,16,0)
import Language.Haskell.TH.Compat
#endif
import Prelude                     hiding ((++), (!!), map)
import Test.QuickCheck             (Arbitrary (..), CoArbitrary (..))

import Clash.Annotations.Primitive (hasBlackBox)
import Clash.Class.BitPack         (BitPack (..), packXWith)
import Clash.Promoted.Nat          (SNat (..), UNat (..),
                                    pow2SNat, snatToNum, subSNat, toUNat)
import Clash.Promoted.Nat.Literals (d1)
import Clash.Sized.Index           (Index)
import Clash.Sized.Vector          (Vec (..), (!!), (++), dtfold, replace)
import Clash.XException
  (ShowX (..), NFDataX (..), isX, showsX, showsPrecXWith)

{- $setup
>>> :set -XDataKinds
>>> :set -XTypeFamilies
>>> :set -XTypeOperators
>>> :set -XTemplateHaskell
>>> :set -XFlexibleContexts
>>> :set -XTypeApplications
>>> :set -fplugin GHC.TypeLits.Normalise
>>> :set -XUndecidableInstances
>>> import Clash.Prelude
>>> import Data.Kind
>>> import Data.Singletons (Apply, TyFun)
>>> import Data.Proxy
>>> data IIndex (f :: TyFun Nat Type) :: Type
>>> type instance Apply IIndex l = Index ((2^l)+1)
>>> :{
let populationCount' :: (KnownNat k, KnownNat (2^k)) => BitVector (2^k) -> Index ((2^k)+1)
    populationCount' bv = tdfold (Proxy @IIndex)
                                 fromIntegral
                                 (\_ x y -> add x y)
                                 (v2t (bv2v bv))
:}
-}

-- | Perfect depth binary tree.
--
-- * Only has elements at the leaves of the tree
-- * A tree of depth /d/ has /2^d/ elements.
data RTree :: Nat -> Type -> Type where
  -- A single element; a leaf is a tree of depth 0.
  RLeaf :: a -> RTree 0 a
  -- Two subtrees of the same depth /d/ form a tree of depth /d+1/, which is
  -- what keeps every tree perfectly balanced.
  RBranch :: RTree d a -> RTree d a -> RTree (d+1) a

-- | Fully evaluate a tree: every leaf is forced to normal form.
instance NFData a => NFData (RTree d a) where
  rnf (RLeaf x)     = rnf x
  -- Force both subtrees; 'seq' only sequences the two '()' results.
  rnf (RBranch l r) = rnf l `seq` rnf r

-- | Extract the single element stored in a tree of depth 0.
--
-- The 'RBranch' equation can never match: a branch always has depth
-- @d+1 >= 1@. It is left out when compiling with GHC 9.2 (see the CPP guard).
textract :: RTree 0 a -> a
textract (RLeaf x)     = x
#if __GLASGOW_HASKELL__ != 902
textract (RBranch _ _) = error "textract: nodes hold no values"
#endif
-- See: https://github.com/clash-lang/clash-compiler/pull/2511
{-# CLASH_OPAQUE textract #-}
{-# ANN textract hasBlackBox #-}

-- | Split a tree of depth @d+1@ into its left and right subtrees of depth @d@.
--
-- The 'RLeaf' equation can never match: a leaf always has depth 0. It is left
-- out when compiling with GHC 9.2 (see the CPP guard).
tsplit :: RTree (d+1) a -> (RTree d a,RTree d a)
tsplit (RBranch l r) = (l,r)
#if __GLASGOW_HASKELL__ != 902
tsplit (RLeaf _)     = error "tsplit: leaf is atomic"
#endif
-- See: https://github.com/clash-lang/clash-compiler/pull/2511
{-# CLASH_OPAQUE tsplit #-}
{-# ANN tsplit hasBlackBox #-}

-- | RLeaf of a perfect depth tree
--
-- >>> LR 1
-- 1
-- >>> let x = LR 1
-- >>> :t x
-- x :: Num a => RTree 0 a
--
-- Can be used as a pattern:
--
-- >>> let f (LR a) (LR b) = a + b
-- >>> :t f
-- f :: Num a => RTree 0 a -> RTree 0 a -> a
-- >>> f (LR 1) (LR 2)
-- 3
pattern LR :: a -> RTree 0 a
-- Matching goes through the opaque 'textract' rather than matching on the
-- 'RLeaf' constructor directly.
pattern LR x <- (textract -> x)
  where
    LR x = RLeaf x

-- | RBranch of a perfect depth tree
--
-- >>> BR (LR 1) (LR 2)
-- <1,2>
-- >>> let x = BR (LR 1) (LR 2)
-- >>> :t x
-- x :: Num a => RTree 1 a
--
-- Can be used as a pattern:
--
-- >>> let f (BR (LR a) (LR b)) = LR (a + b)
-- >>> :t f
-- f :: Num a => RTree 1 a -> RTree 0 a
-- >>> f (BR (LR 1) (LR 2))
-- 3
pattern BR :: RTree d a -> RTree d a -> RTree (d+1) a
-- Matching goes through the opaque 'tsplit' rather than matching on the
-- 'RBranch' constructor directly.
pattern BR l r <- ((\t -> (tsplit t)) -> (l,r))
  where
    BR l r = RBranch l r

-- | Trees are equal when their leaves, read left to right, are equal.
instance (KnownNat d, Eq a) => Eq (RTree d a) where
  t1 == t2 = t2v t1 == t2v t2

-- | Trees are ordered by comparing their leaves, read left to right,
-- using the ordering of 'Vec'.
instance (KnownNat d, Ord a) => Ord (RTree d a) where
  compare t1 t2 = compare (t2v t1) (t2v t2)

-- | A leaf is shown as its element; a branch as @\<left,right\>@.
-- The precedence argument is ignored.
instance Show a => Show (RTree n a) where
  showsPrec _ (RLeaf a)     = shows a
  showsPrec _ (RBranch l r) =
    showChar '<' . shows l . showChar ',' . shows r . showChar '>'

-- | Like the 'Show' instance, but undefined parts are rendered instead of
-- throwing. The X-handling wrapper is 'showsPrecXWith'.
instance ShowX a => ShowX (RTree n a) where
  showsPrecX = showsPrecXWith go
    where
      go :: Int -> RTree d a -> ShowS
      go _ (RLeaf a)     = showsX a
      go _ (RBranch l r) =
        showChar '<' . showsX l . showChar ',' . showsX r . showChar '>'

-- | Maps over every leaf, see 'tmap'.
instance KnownNat d => Functor (RTree d) where
  fmap = tmap

-- | Zip-like applicative: 'pure' fills every leaf with the same value and
-- '<*>' applies functions leaf-by-leaf.
instance KnownNat d => Applicative (RTree d) where
  pure  = trepeat
  (<*>) = tzipWith ($)

-- | Folds the leaves left to right, combining results along the tree
-- structure with 'tfold'.
instance KnownNat d => Foldable (RTree d) where
  -- '(<>)' is the canonical Semigroup operation; 'mappend' is only a
  -- synonym for it.
  foldMap f = tfold f (<>)

-- Defunctionalised motive for 'traverse': at depth @d@ the fold produces an
-- effectful tree, @f (RTree d a)@.
data TraversableTree (g :: Type -> Type) (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (TraversableTree f a) d = f (RTree d a)

-- | Traverses the leaves left to right.
instance KnownNat d => Traversable (RTree d) where
  traverse :: forall f a b . Applicative f => (a -> f b) -> RTree d a -> f (RTree d b)
  traverse f = tdfold (Proxy @(TraversableTree f b))
                      (fmap LR . f)        -- wrap each effectful leaf
                      (const (liftA2 BR))  -- depth is not needed to combine

-- | Packs the leaves in the order given by 't2v' (left-most leaf first).
instance (KnownNat d, BitPack a) =>
  BitPack (RTree d a) where
  type BitSize (RTree d a) = (2^d) * (BitSize a)
  pack   = packXWith (pack . t2v)
  unpack = v2t . unpack

type instance Lens.Index   (RTree d a) = Int
type instance Lens.IxValue (RTree d a) = a
-- | Focus on the leaf at the given index; the bottom-left leaf has index 0.
--
-- NB: the index is not bounds-checked here. The focused element is
-- @'indexTree' t i@, so an out-of-range index yields an error when that
-- element is demanded, rather than an empty traversal.
instance KnownNat d => Lens.Ixed (RTree d a) where
  ix i f t = replaceTree i <$> f (indexTree t i) <*> pure t

-- | A tree with 'def' in every leaf.
instance (KnownNat d, Default a) => Default (RTree d a) where
  def = trepeat def

-- | Lifts a tree into a Template Haskell expression that rebuilds it with the
-- 'RLeaf' and 'RBranch' constructors.
instance Lift a => Lift (RTree d a) where
  lift (RLeaf a)       = [| RLeaf a |]
  lift (RBranch t1 t2) = [| RBranch $(lift t1) $(lift t2) |]
#if MIN_VERSION_template_haskell(2,16,0)
  liftTyped = liftTypedFromUntyped
#endif

-- | Random trees with independently generated leaves.
instance (KnownNat d, Arbitrary a) => Arbitrary (RTree d a) where
  arbitrary = sequenceA (trepeat arbitrary)

  -- Shrink one leaf at a time while keeping all other leaves fixed (the
  -- strategy QuickCheck uses for tuples). Taking the cartesian product of all
  -- leaf shrinks (@traverse shrink@) would produce no candidates at all as
  -- soon as a single leaf cannot be shrunk, and a combinatorial number of
  -- candidates otherwise.
  shrink = go
    where
      go :: RTree n a -> [RTree n a]
      go (RLeaf x)     = RLeaf <$> shrink x
      go (RBranch l r) =
        [RBranch l' r | l' <- go l] ++ [RBranch l r' | r' <- go r]

-- | Perturbs a generator using the leaves, read left to right.
instance (KnownNat d, CoArbitrary a) => CoArbitrary (RTree d a) where
  coarbitrary = coarbitrary . toList

-- | X-aware deep evaluation. An undefined tree spine is treated as a single
-- undefined value; otherwise every leaf is inspected.
instance (KnownNat d, NFDataX a) => NFDataX (RTree d a) where
  -- Every leaf carries the same X error.
  deepErrorX x = pure (deepErrorX x)

  rnfX t = if isLeft (isX t) then () else go t
   where
    go :: RTree d a -> ()
    go (RLeaf x)     = rnfX x
    go (RBranch l r) = rnfX l `seq` rnfX r

  -- An X spine counts as undefined; otherwise look for an X in any leaf.
  hasUndefined t = isLeft (isX t) || go t
   where
    go :: RTree d a -> Bool
    go (RLeaf x)     = hasUndefined x
    go (RBranch l r) = hasUndefined l || hasUndefined r

  ensureSpine = fmap ensureSpine . lazyT


{- | A /dependently/ typed fold over trees.

As an example of when you might want to use 'tdfold' we will build a
population counter: a circuit that counts the number of bits set to '1' in
a 'Clash.Sized.BitVector.BitVector'. Given a vector of /n/ bits, we only
need a data type that can represent the number /n/: 'Index' @(n+1)@. 'Index' @k@
has a range of @[0 .. k-1]@ (using @ceil(log2(k))@ bits), hence we need 'Index' @(n+1)@.
As an initial attempt we will use 'tfold', because it gives a nice (@log2(n)@)
tree-structure of adders:

@
populationCount :: (KnownNat (2^d), KnownNat d, KnownNat (2^d+1))
                => BitVector (2^d) -> Index (2^d+1)
populationCount = tfold (resize . bv2i . pack) (+) . v2t . bv2v
@

The \"problem\" with this description is that all adders have the same
bit-width, i.e. all adders are of the type:

@
(+) :: 'Index' (2^d+1) -> 'Index' (2^d+1) -> 'Index' (2^d+1).
@

This is a \"problem\" because we could have a more efficient structure:
one where each layer of adders is /precisely/ wide enough to count the number
of bits at that layer. That is, at height /d/ we want the adder to be of
type:

@
'Index' ((2^d)+1) -> 'Index' ((2^d)+1) -> 'Index' ((2^(d+1))+1)
@

We have such an adder in the form of the 'Clash.Class.Num.add' function, as
defined in the 'Clash.Class.Num.ExtendingNum' instance of 'Index'.
However, we cannot simply use 'tfold' to create a tree-structure of
'Clash.Class.Num.add's:

#if __GLASGOW_HASKELL__ >= 910
>>> :{
let populationCount' :: (KnownNat (2^d), KnownNat d, KnownNat (2^d+1))
                     => BitVector (2^d) -> Index (2^d+1)
    populationCount' = tfold (resize . bv2i . pack) add . v2t . bv2v
:}
<interactive>:...
    • Couldn't match type: (((2 ^ d) + 1) + ((2 ^ d) + 1)) - 1
                     with: (2 ^ d) + 1
      Expected: Index ((2 ^ d) + 1)
                -> Index ((2 ^ d) + 1) -> Index ((2 ^ d) + 1)
        Actual: Index ((2 ^ d) + 1)
                -> Index ((2 ^ d) + 1)
                -> AResult (Index ((2 ^ d) + 1)) (Index ((2 ^ d) + 1))
    • In the second argument of ‘tfold’, namely ‘add’
      In the first argument of ‘(.)’, namely
        ‘tfold (resize . bv2i . pack) add’
      In the expression: tfold (resize . bv2i . pack) add . v2t . bv2v
    • Relevant bindings include
        populationCount' :: BitVector (2 ^ d) -> Index ((2 ^ d) + 1)
          (bound at ...)
<BLANKLINE>

#elif __GLASGOW_HASKELL__ >= 900
>>> :{
let populationCount' :: (KnownNat (2^d), KnownNat d, KnownNat (2^d+1))
                     => BitVector (2^d) -> Index (2^d+1)
    populationCount' = tfold (resize . bv2i . pack) add . v2t . bv2v
:}
<BLANKLINE>
<interactive>:...
    • Couldn't match type: (((2 ^ d) + 1) + ((2 ^ d) + 1)) - 1
                     with: (2 ^ d) + 1
      Expected: Index ((2 ^ d) + 1)
                -> Index ((2 ^ d) + 1) -> Index ((2 ^ d) + 1)
        Actual: Index ((2 ^ d) + 1)
                -> Index ((2 ^ d) + 1)
                -> AResult (Index ((2 ^ d) + 1)) (Index ((2 ^ d) + 1))
    • In the second argument of ‘tfold’, namely ‘add’
      In the first argument of ‘(.)’, namely
        ‘tfold (resize . bv2i . pack) add’
      In the expression: tfold (resize . bv2i . pack) add . v2t . bv2v
    • Relevant bindings include
        populationCount' :: BitVector (2 ^ d) -> Index ((2 ^ d) + 1)
          (bound at ...)

#else
>>> :{
let populationCount' :: (KnownNat (2^d), KnownNat d, KnownNat (2^d+1))
                     => BitVector (2^d) -> Index (2^d+1)
    populationCount' = tfold (resize . bv2i . pack) add . v2t . bv2v
:}
<BLANKLINE>
<interactive>:...
    • Couldn't match type ‘(((2 ^ d) + 1) + ((2 ^ d) + 1)) - 1’
                     with ‘(2 ^ d) + 1’
      Expected type: Index ((2 ^ d) + 1)
                     -> Index ((2 ^ d) + 1) -> Index ((2 ^ d) + 1)
        Actual type: Index ((2 ^ d) + 1)
                     -> Index ((2 ^ d) + 1)
                     -> AResult (Index ((2 ^ d) + 1)) (Index ((2 ^ d) + 1))
    • In the second argument of ‘tfold’, namely ‘add’
      In the first argument of ‘(.)’, namely
        ‘tfold (resize . bv2i . pack) add’
      In the expression: tfold (resize . bv2i . pack) add . v2t . bv2v
    • Relevant bindings include
        populationCount' :: BitVector (2 ^ d) -> Index ((2 ^ d) + 1)
          (bound at ...)

#endif

because 'tfold' expects a function of type \"@b -> b -> b@\", i.e. a function
where the arguments and result all have exactly the same type.

In order to accommodate the type of our 'Clash.Class.Num.add', where the
result is larger than the arguments, we must use a dependently typed fold in
the form of 'tdfold':

@
{\-\# LANGUAGE UndecidableInstances \#-\}
import Data.Singletons

data IIndex (f :: 'TyFun' Nat Type) :: Type
type instance 'Apply' IIndex l = 'Index' ((2^l)+1)

populationCount' :: (KnownNat k, KnownNat (2^k))
                 => BitVector (2^k) -> Index ((2^k)+1)
populationCount' bv = 'tdfold' (Proxy @IIndex)
                             (resize . bv2i . pack)
                             (\\_ x y -> 'Clash.Class.Num.add' x y)
                             ('v2t' ('Clash.Sized.Vector.bv2v' bv))
@

And we can test that it works:

>>> :t populationCount' (7 :: BitVector 16)
populationCount' (7 :: BitVector 16) :: Index 17
>>> populationCount' (7 :: BitVector 16)
3
-}
tdfold :: forall p k a . KnownNat k
       => Proxy (p :: TyFun Nat Type -> Type) -- ^ The /motive/
       -> (a -> (p @@ 0)) -- ^ Function to apply to the elements on the leaves
       -> (forall l . SNat l -> (p @@ l) -> (p @@ l) -> (p @@ (l+1)))
       -- ^ Function to fold the branches with.
       --
       -- __NB__: @SNat l@ is the depth of the two sub-branches.
       -> RTree k a -- ^ Tree to fold over.
       -> (p @@ k)
tdfold _ f g = go SNat
  where
    -- Walk the tree while carrying its depth as a singleton, so that the
    -- depth of the two sub-branches can be handed to @g@.
    go :: SNat m -> RTree m a -> (p @@ m)
    go _  (RLeaf a)     = f a
    go sn (RBranch l r) =
      let sn' = sn `subSNat` d1  -- depth of both sub-branches: m - 1
      in  g sn' (go sn' l) (go sn' r)
-- See: https://github.com/clash-lang/clash-compiler/pull/2511
{-# CLASH_OPAQUE tdfold #-}
{-# ANN tdfold hasBlackBox #-}

-- Defunctionalised constant motive for 'tfold': the result type @a@ does not
-- depend on the depth.
data TfoldTree (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (TfoldTree a) d = a

-- | Reduce a tree to a single element
tfold :: forall d a b .
         KnownNat d
      => (a -> b) -- ^ Function to apply to the leaves
      -> (b -> b -> b) -- ^ Function to combine the results of the reduction
                       -- of two branches
      -> RTree d a -- ^ Tree to reduce
      -> b
-- The combining function ignores the depth, hence 'const'.
tfold f g = tdfold (Proxy @(TfoldTree b)) f (const g)

-- | \"'treplicate' @d a@\" returns a tree of depth /d/, and has /2^d/ copies
-- of /a/.
--
-- >>> treplicate (SNat :: SNat 3) 6
-- <<<6,6>,<6,6>>,<<6,6>,<6,6>>>
-- >>> treplicate d3 6
-- <<<6,6>,<6,6>>,<<6,6>,<6,6>>>
treplicate :: forall d a . SNat d -> a -> RTree d a
treplicate sn a = go (toUNat sn)
  where
    go :: UNat n -> RTree n a
    go UZero      = LR a
    -- Both halves are identical, so build the subtree once and share it:
    -- O(d) construction work instead of O(2^d).
    go (USucc un) = let t = go un in BR t t
-- See: https://github.com/clash-lang/clash-compiler/pull/2511
{-# CLASH_OPAQUE treplicate #-}
{-# ANN treplicate hasBlackBox #-}

-- | \"'trepeat' @a@\" creates a tree with as many copies of /a/ as demanded by
-- the context.
--
-- >>> trepeat 6 :: RTree 2 Int
-- <<6,6>,<6,6>>
trepeat :: KnownNat d => a -> RTree d a
trepeat = treplicate SNat

-- Defunctionalised motive for depth-preserving tree rebuilds: at depth @d@
-- the fold produces an @RTree d a@.
data MapTree (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (MapTree a) d = RTree d a

-- | \"'tmap' @f t@\" is the tree obtained by applying /f/ to each element of /t/,
-- i.e.,
--
-- > tmap f (BR (LR a) (LR b)) == BR (LR (f a)) (LR (f b))
tmap :: forall d a b . KnownNat d => (a -> b) -> RTree d a -> RTree d b
tmap f = tdfold (Proxy @(MapTree b)) (LR . f) (\_ l r -> BR l r)

-- | Generate a tree of indices, where the depth of the tree is determined by
-- the context.
--
-- >>> tindices :: RTree 3 (Index 8)
-- <<<0,1>,<2,3>>,<<4,5>,<6,7>>>
tindices :: forall d . KnownNat d => RTree d (Index (2^d))
tindices =
  tdfold (Proxy @(MapTree (Index (2^d)))) LR
         -- Start from an all-zero tree; at every branch whose sub-branches
         -- have depth l, offset the right sub-branch by the 2^l leaves of the
         -- left one. Matching on 'SNat' brings @KnownNat l@ into scope, which
         -- 'tmap' requires.
         (\s@SNat l r -> BR l (tmap (+ snatToNum (pow2SNat s)) r))
         (treplicate SNat 0)

-- Defunctionalised motive for 'v2t': at depth @d@ the fold produces an
-- @RTree d a@.
data V2TTree (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (V2TTree a) d = RTree d a

-- | Convert a vector with /2^d/ elements to a tree of depth /d/.
--
-- >>> v2t (1 :> 2 :> 3 :> 4:> Nil)
-- <<1,2>,<3,4>>
v2t :: forall d a . KnownNat d => Vec (2^d) a -> RTree d a
-- Fold the vector as a tree: every element becomes a leaf and every
-- combination step a branch (the depth is not needed, hence 'const').
v2t = dtfold (Proxy @(V2TTree a)) LR (const BR)

-- Defunctionalised motive for 't2v': at depth @d@ the fold produces a vector
-- of the subtree's @2^d@ leaves.
data T2VTree (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (T2VTree a) d = Vec (2^d) a

-- | Convert a tree of depth /d/ to a vector of /2^d/ elements.
--
-- >>> (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)))
-- <<1,2>,<3,4>>
-- >>> t2v (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)))
-- 1 :> 2 :> 3 :> 4 :> Nil
t2v :: forall d a . KnownNat d => RTree d a -> Vec (2^d) a
-- Each leaf becomes a singleton vector; branches append left before right.
t2v = tdfold (Proxy @(T2VTree a)) (:> Nil) (\_ l r -> l ++ r)

-- | \"'indexTree' @t n@\" returns the /n/'th element of /t/.
--
-- The bottom-left leaf has index /0/, and the bottom-right leaf has index
-- /2^d-1/, where /d/ is the depth of the tree
--
-- >>> indexTree (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))) 0
-- 1
-- >>> indexTree (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))) 2
-- 3
-- >>> indexTree (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))) 14
-- *** Exception: Clash.Sized.Vector.(!!): index 14 is larger than maximum index 3
-- ...
indexTree :: (KnownNat d, Enum i) => RTree d a -> i -> a
-- Bounds checking (and its error message) is delegated to 'Vec''s '!!'.
indexTree t i = t2v t !! i

-- | \"'replaceTree' @n a t@\" returns the tree /t/ where the /n/'th element is
-- replaced by /a/.
--
-- The bottom-left leaf has index /0/, and the bottom-right leaf has index
-- /2^d-1/, where /d/ is the depth of the tree
--
-- >>> replaceTree 0 5 (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)))
-- <<5,2>,<3,4>>
-- >>> replaceTree 2 7 (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)))
-- <<1,2>,<7,4>>
-- >>> replaceTree 9 6 (BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)))
-- <<1,2>,<3,*** Exception: Clash.Sized.Vector.replace: index 9 is larger than maximum index 3
-- ...
replaceTree :: (KnownNat d, Enum i) => i -> a -> RTree d a -> RTree d a
-- Round-trip through 'Vec'; bounds checking is done by 'replace'.
replaceTree i a = v2t . replace i a . t2v

-- Defunctionalised motive for 'tzipWith': at depth @d@ the fold produces a
-- function that consumes the second tree and yields the zipped tree.
data ZipWithTree (x :: Type) (y :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (ZipWithTree x y) d = RTree d x -> RTree d y

-- | 'tzipWith' generalizes 'tzip' by zipping with the function given as the
-- first argument, instead of a tupling function. For example, "tzipWith (+)"
-- applied to two trees produces the tree of corresponding sums.
--
-- > tzipWith f (BR (LR a1) (LR b1)) (BR (LR a2) (LR b2)) == BR (LR (f a1 a2)) (LR (f b1 b2))
tzipWith :: forall a b c d . KnownNat d => (a -> b -> c) -> RTree d a -> RTree d b -> RTree d c
tzipWith :: (a -> b -> c) -> RTree d a -> RTree d b -> RTree d c
tzipWith a -> b -> c
f = Proxy (ZipWithTree b c)
-> (a -> ZipWithTree b c @@ 0)
-> (forall (l :: Nat).
    SNat l
    -> (ZipWithTree b c @@ l)
    -> (ZipWithTree b c @@ l)
    -> ZipWithTree b c @@ (l + 1))
-> RTree d a
-> ZipWithTree b c @@ d
forall (p :: TyFun Nat Type -> Type) (k :: Nat) a.
KnownNat k =>
Proxy p
-> (a -> p @@ 0)
-> (forall (l :: Nat).
    SNat l -> (p @@ l) -> (p @@ l) -> p @@ (l + 1))
-> RTree k a
-> p @@ k
tdfold (Proxy (ZipWithTree b c)
forall k (t :: k). Proxy t
Proxy @(ZipWithTree b c)) a -> ZipWithTree b c @@ 0
a -> RTree 0 b -> RTree 0 c
lr forall (l :: Nat).
SNat l
-> (ZipWithTree b c @@ l)
-> (ZipWithTree b c @@ l)
-> ZipWithTree b c @@ (l + 1)
forall (l :: Nat).
SNat l
-> (RTree l b -> RTree l c)
-> (RTree l b -> RTree l c)
-> RTree (l + 1) b
-> RTree (l + 1) c
br
  where
    lr :: a -> RTree 0 b -> RTree 0 c
    lr :: a -> RTree 0 b -> RTree 0 c
lr a
a RTree 0 b
t = c -> RTree 0 c
forall a. a -> RTree 0 a
LR (a -> b -> c
f a
a (RTree 0 b -> b
forall a. RTree 0 a -> a
textract RTree 0 b
t))

    br :: SNat l
       -> (RTree l b -> RTree l c)
       -> (RTree l b -> RTree l c)
       -> RTree (l+1) b
       -> RTree (l+1) c
    br :: SNat l
-> (RTree l b -> RTree l c)
-> (RTree l b -> RTree l c)
-> RTree (l + 1) b
-> RTree (l + 1) c
br SNat l
_ RTree l b -> RTree l c
fl RTree l b -> RTree l c
fr RTree (l + 1) b
t = RTree l c -> RTree l c -> RTree (l + 1) c
forall (d :: Nat) a. RTree d a -> RTree d a -> RTree (d + 1) a
BR (RTree l b -> RTree l c
fl RTree l b
l) (RTree l b -> RTree l c
fr RTree l b
r)
      where
        (RTree l b
l,RTree l b
r) = RTree (l + 1) b -> (RTree l b, RTree l b)
forall (d :: Nat) a. RTree (d + 1) a -> (RTree d a, RTree d a)
tsplit RTree (l + 1) b
t


-- | 'tzip' takes two trees and returns a tree of corresponding pairs.
--
-- > tzip (BR (LR a1) (LR b1)) (BR (LR a2) (LR b2)) == BR (LR (a1,a2)) (LR (b1,b2))
tzip :: KnownNat d => RTree d a -> RTree d b -> RTree d (a,b)
tzip = tzipWith (,)

-- | Defunctionalised type-level function used as the motive of 'tdfold' in
-- 'tunzip': at depth @d@ the fold produces a pair of trees of depth @d@.
data UnzipTree (a :: Type) (b :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (UnzipTree a b) d = (RTree d a, RTree d b)

-- | 'tunzip' transforms a tree of pairs into a tree of first components and a
-- tree of second components.
--
-- > tunzip (BR (LR (a1,b1)) (LR (a2,b2))) == (BR (LR a1) (LR a2), BR (LR b1) (LR b2))
tunzip :: forall d a b . KnownNat d => RTree d (a,b) -> (RTree d a,RTree d b)
tunzip = tdfold (Proxy @(UnzipTree a b)) lr br
  where
    -- A leaf holding a pair becomes a pair of leaves.
    lr (x,y) = (LR x,LR y)

    -- Given the unzipped left and right subtrees, rebuild the tree of first
    -- components from both left halves and the tree of second components
    -- from both right halves.
    br _ (l1,r1) (l2,r2) = (BR l1 l2, BR r1 r2)

-- | Given a function @f@ that is strict in its /n/th 'RTree' argument, make it
-- lazy by applying 'lazyT' to this argument:
--
-- > f x0 x1 .. (lazyT xn) .. xn_plus_k
lazyT :: KnownNat d
      => RTree d a
      -> RTree d a
-- The result's spine is produced by 'tzipWith' folding over @trepeat ()@,
-- whose shape follows from @d@ alone; the argument is only taken apart inside
-- lazy bindings, and each leaf keeps the argument's element (@flip const@).
lazyT = tzipWith (flip const) (trepeat ())

-- | Extract the first element of a tree.
--
-- The first element is defined to be the bottom-left leaf.
--
-- >>> thead $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))
-- 1
thead :: RTree n a -> a
thead (RLeaf x) = x
-- Descend into the left subtree until a leaf is reached.
thead (RBranch x _) = thead x

-- | Extract the last element of a tree.
--
-- The last element is defined to be the bottom-right leaf.
--
-- >>> tlast $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))
-- 4
tlast :: RTree n a -> a
tlast (RLeaf x) = x
-- Descend into the right subtree until a leaf is reached.
tlast (RBranch _ y) = tlast y

{- $scans #scans#

Scans (`Clash.Sized.Vector.scanl`, `Clash.Sized.Vector.scanr`) are similar to
folds (`Clash.Sized.Vector.foldl`, `Clash.Sized.Vector.foldr`), but return all
successive reduced values instead of only the final one. When the binary
reduction operator @f@ is associative, the scan functions in this module can be
characterized as follows:

> tscanl f [x1, x2, x3, ...] == [x1, x1 `f` x2, x1 `f` x2 `f` x3, ...]

> tscanr f [..., xn2, xn1, xn] == [..., xn2 `f` xn1 `f` xn, xn1 `f` xn, xn]

The scan functions in this module provide a different trade-off between circuit
size and logic depth than the default `Clash.Sized.Vector.scanl` and
`Clash.Sized.Vector.scanr` functions. When \(n\) is the number of elements,
circuit size is \(\mathcal{O}(n \cdot \log n)\), but logic depth is \(\mathcal{O}(\log n)\).
This means the resource usage will likely increase, but the maximum clock
frequency also increases due to the reduced logic depth. The exact number of
instantiations of @f@ for a tree of depth /d/ is:

> work 0 = 0
> work d = 2 ^ (d - 1) + 2 * work (d - 1)

-}

-- | `tscanl` applied to `Vec`: a low-depth left scan over a vector whose
-- length is a power of two.
--
-- >>> scanlPar (+) (1 :> 2 :> 3 :> 4 :> Nil)
-- 1 :> 3 :> 6 :> 10 :> Nil
scanlPar ::
  KnownNat n =>
  -- | Must be associative
  (a -> a -> a) ->
  Vec (2^n) a ->
  Vec (2^n) a
-- Convert to a tree, run the tree scan, and convert back.
scanlPar op = t2v . tscanl op . v2t
{-# INLINE scanlPar #-}

-- | `tscanr` applied to `Vec`: a low-depth right scan over a vector whose
-- length is a power of two.
--
-- >>> scanrPar (+) (1 :> 2 :> 3 :> 4 :> Nil)
-- 10 :> 9 :> 7 :> 4 :> Nil
scanrPar ::
  KnownNat n =>
  -- | Must be associative
  (a -> a -> a) ->
  Vec (2^n) a ->
  Vec (2^n) a
-- Convert to a tree, run the tree scan, and convert back.
scanrPar op = t2v . tscanr op . v2t
{-# INLINE scanrPar #-}

-- | Low-depth left scan
--
-- `tscanl` is similar to `Clash.Sized.Vector.foldl`, but returns a tree of
-- successive reduced values from the left:
--
-- > tscanl f [x1, x2, x3, ...] == [x1, x1 `f` x2, x1 `f` x2 `f` x3, ...]
--
-- >>> tscanl (+) (v2t (1 :> 2 :> 3 :> 4 :> Nil))
-- <<1,3>,<6,10>>
--
-- <<doc/scanlPar.svg>>
tscanl ::
  forall a n.
  KnownNat n =>
  -- | Must be associative
  (a -> a -> a) ->
  RTree n a ->
  RTree n a
tscanl op tr =
  case tr of
    RLeaf x -> LR x
    RBranch x y ->
      let
        -- Scan both halves independently of each other.
        x' = tscanl op x
        y' = tscanl op y
        -- Reduction of the whole left half: its last running value.
        l = tlast x'
        -- Every running value of the right half still lacks the left half's
        -- reduction as a prefix, so combine it on the left.
      in BR x' (fmap (l `op`) y')

-- | Low-depth right scan
--
-- `tscanr` is similar to `Clash.Sized.Vector.foldr`, but returns a tree of
-- successive reduced values from the right:
--
-- > tscanr f [..., xn2, xn1, xn] == [..., xn2 `f` xn1 `f` xn, xn1 `f` xn, xn]
--
-- >>> tscanr (+) (v2t (1 :> 2 :> 3 :> 4 :> Nil))
-- <<10,9>,<7,4>>
--
-- The operator only needs to be associative, not commutative. For example,
-- with 'const' every running value equals the element it starts at:
--
-- >>> tscanr const (v2t (1 :> 2 :> 3 :> 4 :> Nil))
-- <<1,2>,<3,4>>
tscanr ::
  forall a n.
  KnownNat n =>
  -- | Must be associative
  (a -> a -> a) ->
  RTree n a ->
  RTree n a
tscanr op tr =
  case tr of
    RLeaf x -> LR x
    RBranch x y ->
      let
        -- Scan both halves independently of each other.
        x' = tscanr op x
        y' = tscanr op y
        -- Reduction of the whole right half: its first running value.
        r = thead y'
        -- Every running value of the left half still lacks the right half's
        -- reduction as a suffix, so it must be combined on the /right/.
        -- Combining it on the left (@(r `op`)@) is only correct for
        -- commutative operators.
      in BR (fmap (`op` r) x') y'