{-# LANGUAGE ExtendedDefaultRules #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Specs for the @slice@ and @lslice@ quasi-quoters and for tensor indexing
-- through the @!@ operator, 'maskedFill' and lens access.
module IndexSpec (spec) where

import Control.Arrow ((&&&))
import Lens.Family
import Test.Hspec
import Test.QuickCheck
import Torch.DType
import Torch.Index
import Torch.Lens
import Torch.Tensor
import Torch.TensorFactories

-- | Entry point of this spec module. It contains three groups:
--
-- * @"slice"@: the @slice@ quasi-quoter yields the expected index value for
--   each supported spelling;
-- * @"indexing"@: reading with @!@ and writing with 'maskedFill';
-- * @"indexing with lens"@: the same scenarios expressed with @lslice@ and
--   the lens operators @^.@ and @.~@.
spec :: Spec
spec = do
  describe "slice" $ do
    -- Keywords and scalar indices.
    it "None" $ do
      [slice|None|] `shouldBe` None
    it "Ellipsis" $ do
      [slice|Ellipsis|] `shouldBe` Ellipsis
    it "..." $ do
      [slice|...|] `shouldBe` Ellipsis
    it "123" $ do
      [slice|123|] `shouldBe` 123
    it "-123" $ do
      [slice|-123|] `shouldBe` -123
    it "True" $ do
      [slice|True|] `shouldBe` True
    it "False" $ do
      [slice|False|] `shouldBe` False
    -- Colon forms: omitted leading fields appear as 'None' in the expected
    -- value, and a trailing empty field is dropped (":3:" equals ":3").
    it ":" $ do
      [slice|:|] `shouldBe` Slice ()
    it "::" $ do
      [slice|::|] `shouldBe` Slice ()
    it "1:" $ do
      [slice|1:|] `shouldBe` Slice (1, None)
    it "1::" $ do
      [slice|1::|] `shouldBe` Slice (1, None)
    it ":3" $ do
      [slice|:3|] `shouldBe` Slice (None, 3)
    it ":3:" $ do
      [slice|:3:|] `shouldBe` Slice (None, 3)
    it "::2" $ do
      [slice|::2|] `shouldBe` Slice (None, None, 2)
    it "1:3" $ do
      [slice|1:3|] `shouldBe` Slice (1, 3)
    it "1::2" $ do
      [slice|1::2|] `shouldBe` Slice (1, None, 2)
    it ":3:2" $ do
      [slice|:3:2|] `shouldBe` Slice (None, 3, 2)
    it "1:3:2" $ do
      [slice|1:3:2|] `shouldBe` Slice (1, 3, 2)
    -- Comma-separated indices become a tuple; spaces around the commas are
    -- accepted.
    it "1,2,3" $ do
      [slice|1,2,3|] `shouldBe` (1, 2, 3)
    it "1 , 2, 3" $ do
      [slice|1 , 2, 3|] `shouldBe` (1, 2, 3)
    -- This example previously reused the description "1 , 2, 3", which made
    -- it indistinguishable from the one above in a failure report. It covers
    -- the `{i}` form, where the Haskell binding `i` supplies the index.
    it "{i} , 2, 3" $ do
      let i = 1
      [slice|{i} , 2, 3|] `shouldBe` (1, 2, 3)
  describe "indexing" $ do
    -- Every example checks dtype, shape and value at once via (&&&).
    it "pick up a value" $ do
      let x = asTensor ([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]] :: [[[Int]]])
          r = x ! [slice|1,0,2|]
      (dtype &&& shape &&& asValue) r `shouldBe` (Int64, ([], 8 :: Int))
    it "intercalate" $ do
      -- Fill every second position of a zero tensor, starting at index 0.
      let x = zeros' [6]
          i = [slice|0::2|]
      (dtype &&& shape &&& asValue) (maskedFill x i (arange' 1 4 1)) `shouldBe` (Float, ([6], [1, 0, 2, 0, 3, 0] :: [Float]))
    it "negative index" $ do
      -- Index -1 is expected to select the last element (value 4).
      let x = arange' 1 5 1
          i = [slice|-1|]
      (dtype &&& shape &&& asValue) (x ! i) `shouldBe` (Float, ([], 4 :: Float))
  describe "indexing with lens" $ do
    -- Same scenarios as "indexing", using the lslice quasi-quoter.
    it "pick up a value" $ do
      let x = asTensor ([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]] :: [[[Int]]])
          r = x ^. [lslice|1,0,2|]
      (dtype &&& shape &&& asValue) r `shouldBe` (Int64, ([], 8 :: Int))
    it "intercalate" $ do
      let x = zeros' [6]
          i = [lslice|0::2|] :: Lens' Tensor Tensor
      (dtype &&& shape &&& asValue) (x & i .~ arange' 1 4 1) `shouldBe` (Float, ([6], [1, 0, 2, 0, 3, 0] :: [Float]))
    it "negative index" $ do
      let x = arange' 1 5 1
          i = [lslice|-1|]
      (dtype &&& shape &&& asValue) (x ^. i) `shouldBe` (Float, ([], 4 :: Float))