Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions vec/src/Data/Vec/DataFamily/SpineStrict.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
Expand Down Expand Up @@ -91,6 +92,11 @@ module Data.Vec.DataFamily.SpineStrict (
ifoldMap1,
foldr,
ifoldr,
-- * Scans
scanr,
scanl,
scanr1,
scanl1,
-- * Special folds
length,
null,
Expand Down Expand Up @@ -582,10 +588,10 @@ last :: forall n a. N.SNatI n => Vec ('S n) a -> a
last xs = getLast (N.induction1 start step) xs where
start :: Last 'Z a
start = Last $ \(x:::VNil) -> x

step :: Last m a -> Last ('S m) a
step (Last rec) = Last $ \(_ ::: ys) -> rec ys


newtype Last n a = Last { getLast :: Vec ('S n) a -> a }

Expand All @@ -596,7 +602,7 @@ init :: forall n a. N.SNatI n => Vec ('S n) a -> Vec n a
init xs = getInit (N.induction1 start step) xs where
start :: Init 'Z a
start = Init (const VNil)

step :: Init m a -> Init ('S m) a
step (Init rec) = Init $ \(y ::: ys) -> y ::: rec ys

Expand Down Expand Up @@ -845,6 +851,43 @@ ifoldr = getIFoldr $ N.induction1 start step where

newtype IFoldr a n b = IFoldr { getIFoldr :: (Fin n -> a -> b -> b) -> b -> Vec n a -> b }

scanr :: forall a b n. N.SNatI n => (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = getScanr $ N.induction1 start step where
start :: Scanr a 'Z b
start = Scanr $ \_ -> singleton z

step :: Scanr a m b -> Scanr a ('S m) b
step (Scanr go) = Scanr $ \(x ::: xs) -> let ys@(y ::: _) = go xs in f x y ::: ys

newtype Scanr a n b = Scanr { getScanr :: Vec n a -> Vec ('S n) b }

scanl :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \(!acc) (x ::: xs) -> acc ::: go (f acc x) xs

newtype Scanl a n b = Scanl { getScanl :: b -> Vec n a -> Vec ('S n) b }

scanr1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = getScanr1 $ N.induction1 start step where
start :: Scanr1 'Z a
start = Scanr1 $ \_ -> VNil

step :: forall m. N.SNatI m => Scanr1 m a -> Scanr1 ('S m) a
step (Scanr1 go) = Scanr1 $ \(x ::: xs) -> case N.snat :: N.SNat m of
N.SZ -> x ::: VNil
N.SS -> let ys@(y ::: _) = go xs in f x y ::: ys

newtype Scanr1 n a = Scanr1 { getScanr1 :: Vec n a -> Vec n a }

scanl1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1 f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> let (y ::: ys) = xs in scanl f y ys

-- | Yield the length of a 'Vec'. /O(n)/
length :: forall n a. N.SNatI n => Vec n a -> Int
length _ = getLength l where
Expand Down
28 changes: 28 additions & 0 deletions vec/src/Data/Vec/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ module Data.Vec.Lazy (
foldr,
ifoldr,
foldl',
-- * Scans
scanr,
scanl,
scanr1,
scanl1,
-- * Special folds
length,
null,
Expand Down Expand Up @@ -691,6 +696,29 @@ foldl' f z = go z where
go !acc VNil = acc
go !acc (x ::: xs) = go (f acc x) xs

scanr :: forall a b n. (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = go where
go :: Vec m a -> Vec ('S m) b
go VNil = singleton z
go (x ::: xs) = case go xs of ys@(y ::: _) -> f x y ::: ys

scanl :: forall a b n. (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = go where
go :: b -> Vec m a -> Vec ('S m) b
go !acc VNil = acc ::: VNil
go !acc (x ::: xs) = acc ::: go (f acc x) xs

scanr1 :: forall a n. (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = go where
go :: Vec m a -> Vec m a
go VNil = VNil
go (x ::: VNil) = x ::: VNil
go (x ::: xs@(_ ::: _)) = case go xs of ys@(y ::: _) -> f x y ::: ys

scanl1 :: forall a n. (a -> a -> a) -> Vec n a -> Vec n a
scanl1 _ VNil = VNil
scanl1 f (x ::: xs) = scanl f x xs

-- | Yield the length of a 'Vec'. /O(n)/
length :: Vec n a -> Int
length = go 0 where
Expand Down
50 changes: 47 additions & 3 deletions vec/src/Data/Vec/Lazy/Inline.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
Expand Down Expand Up @@ -51,6 +52,11 @@ module Data.Vec.Lazy.Inline (
ifoldMap1,
foldr,
ifoldr,
-- * Scans
scanr,
scanl,
scanr1,
scanl1,
-- * Special folds
length,
null,
Expand Down Expand Up @@ -260,10 +266,10 @@ last :: forall n a. N.SNatI n => Vec ('S n) a -> a
last xs = getLast (N.induction1 start step) xs where
start :: Last 'Z a
start = Last $ \(x:::VNil) -> x

step :: Last m a -> Last ('S m) a
step (Last rec) = Last $ \(_ ::: ys) -> rec ys


newtype Last n a = Last { getLast :: Vec ('S n) a -> a }

Expand All @@ -274,7 +280,7 @@ init :: forall n a. N.SNatI n => Vec ('S n) a -> Vec n a
init xs = getInit (N.induction1 start step) xs where
start :: Init 'Z a
start = Init (const VNil)

step :: Init m a -> Init ('S m) a
step (Init rec) = Init $ \(y ::: ys) -> y ::: rec ys

Expand Down Expand Up @@ -520,6 +526,44 @@ ifoldr = getIFoldr $ N.induction1 start step where

newtype IFoldr a n b = IFoldr { getIFoldr :: (Fin n -> a -> b -> b) -> b -> Vec n a -> b }

scanr :: forall a b n. N.SNatI n => (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = getScanr $ N.induction1 start step where
start :: Scanr a 'Z b
start = Scanr $ \_ -> singleton z

step :: Scanr a m b -> Scanr a ('S m) b
step (Scanr go) = Scanr $ \(x ::: xs) -> case go xs of
ys@(y ::: _) -> f x y ::: ys

newtype Scanr a n b = Scanr { getScanr :: Vec n a -> Vec ('S n) b }

scanl :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \(!acc) (x ::: xs) -> acc ::: go (f acc x) xs

newtype Scanl a n b = Scanl { getScanl :: b -> Vec n a -> Vec ('S n) b }

scanr1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = getScanr1 $ N.induction1 start step where
start :: Scanr1 'Z a
start = Scanr1 $ \_ -> VNil

step :: forall m. N.SNatI m => Scanr1 m a -> Scanr1 ('S m) a
step (Scanr1 go) = Scanr1 $ \(x ::: xs) -> case N.snat :: N.SNat m of
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't feel right. You shouldn't need to check length in the step case. I can take a look myself if you cannot find a way to avoid it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In scanr1, the last element is special (it is the zero element), and so the 0 -> 1 step and the m -> m + 1 step (where m > 0) are different. But yeah please let me know if there's a better way to write this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phadej - let me know if you have any updates on this!

N.SZ -> x ::: VNil
N.SS -> case go xs of ys@(y ::: _) -> f x y ::: ys

newtype Scanr1 n a = Scanr1 { getScanr1 :: Vec n a -> Vec n a }

scanl1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1 f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> case xs of y ::: ys -> scanl f y ys

-- | Yield the length of a 'Vec'. /O(n)/
length :: forall n a. N.SNatI n => Vec n a -> Int
length _ = getLength l where
Expand Down
2 changes: 1 addition & 1 deletion vec/src/Data/Vec/Pull.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
--
-- The module tries to have same API as "Data.Vec.Lazy", missing bits:
-- @withDict@, @toPull@, @fromPull@, @traverse@ (and variants),
-- @(++)@, @concat@ and @split@.
-- @scanr@ (and variants), @(++)@, @concat@ and @split@.
module Data.Vec.Pull (
Vec (..),
-- * Construction
Expand Down
68 changes: 66 additions & 2 deletions vec/test/Inspection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ lhsLast = I.last $ 'a' ::: 'b' ::: 'c' ::: VNil
lhsLast' :: Char
lhsLast' = L.last $ 'a' ::: 'b' ::: 'c' :::VNil

rhsLast :: Char
rhsLast :: Char
rhsLast = 'c'

inspect $ 'lhsLast === 'rhsLast
Expand Down Expand Up @@ -167,4 +167,68 @@ rhsToNonEmpty :: NonEmpty Char
rhsToNonEmpty = 'a' :| ['b', 'c']

inspect $ 'lhsToNonEmpty === 'rhsToNonEmpty
inspect $ 'lhsToNonEmpty' =/= 'rhsToNonEmpty
inspect $ 'lhsToNonEmpty' =/= 'rhsToNonEmpty

-------------------------------------------------------------------------------
-- scanr
-------------------------------------------------------------------------------

lhsScanr :: Vec N.Nat5 Int
lhsScanr = I.scanr (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanr' :: Vec N.Nat5 Int
lhsScanr' = L.scanr (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr :: Vec N.Nat5 Int
rhsScanr = (-2) ::: 3 ::: (-1) ::: 4 ::: 0 ::: VNil

inspect $ 'lhsScanr === 'rhsScanr
inspect $ 'lhsScanr' =/= 'rhsScanr

-------------------------------------------------------------------------------
-- scanl
-------------------------------------------------------------------------------

lhsScanl :: Vec N.Nat5 Int
lhsScanl = I.scanl (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl' :: Vec N.Nat5 Int
lhsScanl' = L.scanl (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl :: Vec N.Nat5 Int
rhsScanl = 0 ::: (-1) ::: (-3) ::: (-6) ::: (-10) ::: VNil

inspect $ 'lhsScanl === 'rhsScanl
inspect $ 'lhsScanl' =/= 'rhsScanl

-------------------------------------------------------------------------------
-- scanr1
-------------------------------------------------------------------------------

lhsScanr1 :: Vec N.Nat4 Int
lhsScanr1 = I.scanr1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanr1' :: Vec N.Nat4 Int
lhsScanr1' = L.scanr1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr1 :: Vec N.Nat4 Int
rhsScanr1 = (-2) ::: 3 ::: (-1) ::: 4 ::: VNil

inspect $ 'lhsScanr1 === 'rhsScanr1
inspect $ 'lhsScanr1' =/= 'rhsScanr1

-------------------------------------------------------------------------------
-- scanl1
-------------------------------------------------------------------------------

lhsScanl1 :: Vec N.Nat4 Int
lhsScanl1 = I.scanl1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl1' :: Vec N.Nat4 Int
lhsScanl1' = L.scanl1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl1 :: Vec N.Nat4 Int
rhsScanl1 = 1 ::: (-1) ::: (-4) ::: (-8) ::: VNil

inspect $ 'lhsScanl1 === 'rhsScanl1
inspect $ 'lhsScanl1' =/= 'rhsScanl1
48 changes: 48 additions & 0 deletions vec/test/Inspection/DataFamily/SpineStrict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,51 @@ rhsReverse :: Vec N.Nat3 Char
rhsReverse = 'a' ::: 'b' ::: 'c' ::: VNil

inspect $ 'lhsReverse === 'rhsReverse

-------------------------------------------------------------------------------
-- scanr
-------------------------------------------------------------------------------

lhsScanr :: Vec N.Nat5 Int
lhsScanr = I.scanr (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr :: Vec N.Nat5 Int
rhsScanr = (-2) ::: 3 ::: (-1) ::: 4 ::: 0 ::: VNil

inspect $ 'lhsScanr === 'rhsScanr

-------------------------------------------------------------------------------
-- scanl
-------------------------------------------------------------------------------

lhsScanl :: Vec N.Nat5 Int
lhsScanl = I.scanl (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl :: Vec N.Nat5 Int
rhsScanl = 0 ::: (-1) ::: (-3) ::: (-6) ::: (-10) ::: VNil

inspect $ 'lhsScanl === 'rhsScanl

-------------------------------------------------------------------------------
-- scanr1
-------------------------------------------------------------------------------

lhsScanr1 :: Vec N.Nat4 Int
lhsScanr1 = I.scanr1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr1 :: Vec N.Nat4 Int
rhsScanr1 = (-2) ::: 3 ::: (-1) ::: 4 ::: VNil

inspect $ 'lhsScanr1 === 'rhsScanr1

-------------------------------------------------------------------------------
-- scanl1
-------------------------------------------------------------------------------

lhsScanl1 :: Vec N.Nat4 Int
lhsScanl1 = I.scanl1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl1 :: Vec N.Nat4 Int
rhsScanl1 = 1 ::: (-1) ::: (-4) ::: (-8) ::: VNil

inspect $ 'lhsScanl1 === 'rhsScanl1