module Geometry.Icosphere
  ( generateIndexed

  , icofaces
  , icopoints
  ) where

import RIO

import Geomancy.Vec3 (Vec3, vec3)
import Geomancy.Vec3 qualified as Vec3
import Data.Vector qualified as Vector
import Data.Vector.Mutable qualified as Mutable
import RIO.Vector.Partial ((!))
import RIO.Map qualified as Map
import RIO.Vector.Storable qualified as Storable
import Control.Monad.State.Strict (get, put, runState)
import Vulkan.NamedType ((:::))

import Geometry.Face (Face(..))

-- | Build an indexed sphere mesh by repeatedly subdividing the base
-- icosahedron ('icopoints' / 'icofaces').
--
-- Every subdivision step splits each triangle into four by inserting one
-- point at the middle of each edge. An edge midpoint is created once and
-- shared by both faces bordering that edge, so no point is duplicated.
--
-- Callbacks:
--
-- * @initial@ computes the attributes of each of the 12 base points from its
--   position.
-- * @midpoint@ computes the attributes of a new edge midpoint from the
--   current @scale@, the midpoint position and the attributes of the two
--   edge endpoints. @scale@ is @remainingLevels / subdivisions@: it is @1@ on
--   the first step and @1 / subdivisions@ on the last.
-- * @vertex@ turns the final point list and face list into per-point
--   @(pos, vertexAttr)@ pairs. Midpoint positions are plain averages of their
--   endpoints (not projected onto a sphere), so any projection has to happen
--   in this callback.
--
-- Returns the positions and attributes produced by @vertex@, plus the
-- triangle index list (each face's corners, taken via its 'Foldable'
-- instance).
generateIndexed
  :: ( Fractional scale
     , Storable pos
     , Storable vertexAttr
     )
  => "subdivisions"  ::: Natural
  -> "initial"       ::: (Vec3 -> pointAttr)
  -> "midpoint"      ::: (scale -> Vec3 -> pointAttr -> pointAttr -> pointAttr)
  -> "vertex"        ::: (Vector (Vec3, pointAttr) -> [Face Int] -> Vector (pos, vertexAttr))
  -> "model vectors" ::: (Storable.Vector pos, Storable.Vector vertexAttr, Storable.Vector Word32)
generateIndexed details mkInitialAttrs mkMidpointAttrs mkVertices =
  ( Storable.convert pv
  , Storable.convert av
  , Storable.fromList iv
  )
  where
    (pv, av) = Vector.unzip $ mkVertices finalPoints faces

    -- Flattened triangle indices, three per face.
    iv :: [Word32]
    iv = do
      face <- faces
      vert <- toList face
      pure $ fromIntegral vert

    (faces, (_midpoints, grownPoints, finalPointsCount)) =
      runState
        (go icofaces details details)
        ( mempty
        , initialPoints
        , Vector.length icopoints
        )

    -- The point buffer is allocated for the worst case up front; only its
    -- first finalPointsCount slots hold real points.
    finalPoints = Vector.take finalPointsCount grownPoints

    -- Upper bound on the number of points: 12 base points plus 3 new points
    -- per face per level, i.e. the count if no midpoint were ever shared.
    -- With sharing the actual count is smaller (10 * 4 ^ details + 2).
    maxPoints :: Int
    maxPoints = length icofaces * (4 ^ details) - 8

    -- Base points with their initial attributes, at the front of a buffer
    -- large enough for every point added later.
    initialPoints = Vector.create do
      v <- Mutable.new maxPoints
      Vector.imapM_
        (Mutable.unsafeWrite v)
        (Vector.map (id &&& mkInitialAttrs) icopoints)
      pure v

    -- Subdivide curFaces curLevel more times. maxLevel stays fixed and is
    -- only used to derive the scale passed to the midpoint callback.
    go curFaces maxLevel curLevel =
      case curLevel of
        0 ->
          pure curFaces
        _ -> do
          let scale = fromIntegral curLevel / fromIntegral maxLevel
          next <- traverse (subdivideFace scale) curFaces
          go (mconcat next) maxLevel (curLevel - 1)

    -- Split one triangle into four, reusing or creating the midpoints of its
    -- three edges and appending any newly created ones to the point buffer.
    subdivideFace scale (Face a b c) = do
      (mids, points, numPoints) <- get

      let
        (midsAB, extrasAB, ab) = midpoint scale mids   mempty   points numPoints (a, b)
        (midsBC, extrasBC, bc) = midpoint scale midsAB extrasAB points numPoints (b, c)
        (midsCA, extrasCA, ca) = midpoint scale midsBC extrasBC points numPoints (c, a)

      put
        ( midsCA
        , runST do
            -- Append in place. Only slots from numPoints onward are written;
            -- every existing point lives below numPoints and is left as-is,
            -- so earlier views of this shared buffer stay valid.
            buf <- Vector.unsafeThaw points
            Vector.imapM_
              (\i point -> Mutable.unsafeWrite buf (numPoints + i) point)
              extrasCA
            Vector.unsafeFreeze buf
        , numPoints + Vector.length extrasCA
        )

      pure
        [ Face ab bc ca
        , Face ca a ab
        , Face ab b bc
        , Face bc c ca
        ]

    -- Find the midpoint of edge (i, j), creating it in extras when the edge
    -- has not been split yet. Returns the updated midpoint map, the updated
    -- extras and the midpoint's index.
    midpoint scale mids extras points numPoints (i, j) =
      case Map.lookup edge mids of
        Just knownIx ->
          (mids, extras, knownIx)
        Nothing ->
          let
            (pos1, attr1) = points ! lo
            (pos2, attr2) = points ! hi
            midPos = Vec3.lerp 0.5 pos1 pos2
            newIx = numPoints + Vector.length extras
            point = (midPos, mkMidpointAttrs scale midPos attr1 attr2)
          in
            ( Map.insert edge newIx mids
            , Vector.snoc extras point
            , newIx
            )
      where
        -- Neighbouring faces walk a shared edge in opposite directions, so
        -- the key must ignore direction or the midpoint is created twice.
        -- Ordering the endpoints also makes the midpoint (and its attributes)
        -- independent of which face reaches the edge first.
        edge@(lo, hi) = (min i j, max i j)
-- | The 20 triangles of the base icosahedron, as indices into 'icopoints'.
--
-- All faces share one winding: each edge appears as @(i, j)@ in one face and
-- as @(j, i)@ in the neighbouring face.
icofaces :: [Face Int]
icofaces =
  [ -- faces around point 0
    Face  5 11  0
  , Face  1  5  0
  , Face  7  1  0
  , Face 10  7  0
  , Face 11 10  0

    -- 5 adjacent faces
  , Face  9  5  1
  , Face  4 11  5
  , Face  2 10 11
  , Face  6  7 10
  , Face  8  1  7

    -- 5 adjacent faces around point 3
  , Face  4  9  3
  , Face  2  4  3
  , Face  6  2  3
  , Face  8  6  3
  , Face  9  8  3

    -- 5 adjacent faces
  , Face  5  9  4
  , Face 11  4  2
  , Face 10  2  6
  , Face  7  6  8
  , Face  1  8  9
  ]

-- | The 12 corner points of the base icosahedron: the cyclic permutations of
-- @(0, ±t, ±1)@, where @t@ is the golden ratio.
--
-- The points are not normalised; each one has length @sqrt (1 + t * t)@.
icopoints :: Vector Vec3
icopoints =
  Vector.fromList
    [ vec3 (-1)   0    t
    , vec3   1    0    t
    , vec3 (-1)   0  (-t)
    , vec3   1    0  (-t)

    , vec3   0  (-t) (-1)
    , vec3   0  (-t)   1
    , vec3   0    t  (-1)
    , vec3   0    t    1

    , vec3   t    1    0
    , vec3   t  (-1)   0
    , vec3 (-t)   1    0
    , vec3 (-t) (-1)   0
    ]
  where
    -- golden ratio
    t :: Float
    t = (1.0 + sqrt 5.0) / 2.0