QuickSelect solves the problem of finding the kth lowest element in an unsorted array. QuickPartition is a variation on that problem, where we want the k lowest elements of an unsorted array.
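For reference, both problems can be pinned down with a sort-based specification. This is only a sketch of what the answers should be, not an efficient solution: it runs in O(n log n), while QuickSelect-style algorithms aim for O(n) on average. The names kthLowest and kLowest are hypothetical, and kthLowest assumes a 0-based k.

```haskell
import Data.List (sort)

-- Specification-only versions: sort, then index or take (O(n log n)).

-- | The kth lowest element (0-based k, an assumption), or Nothing if k is
-- out of range.
kthLowest :: Ord a => Int -> [a] -> Maybe a
kthLowest k xs
  | k < 0 = Nothing
  | otherwise = case drop k (sort xs) of
      (y : _) -> Just y
      [] -> Nothing

-- | The k lowest elements, in ascending order.
kLowest :: Ord a => Int -> [a] -> [a]
kLowest k = take k . sort
```

For example, `kthLowest 0 [5, 1, 4]` is `Just 1` and `kLowest 2 [5, 1, 4]` is `[1, 4]`.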
I have written up two different versions in Haskell:
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Control.Monad (when)
import Control.Monad.Primitive
import Control.Monad.ST.Strict
import Data.Foldable
import qualified Data.Sequence as S
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM
This version uses Data.Sequence to partition the input and to accumulate the elements to be returned. If you need a Data.Vector, you can convert from a Data.Sequence (the toVec helper below does this). This version has the advantage of being simple. It splits the array into three parts (less than, equal to, and greater than the pivot) in order to guarantee forward progress.
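To see why the three-way split matters, here is a small list-based comparison. The helpers splitTwo and splitThree are hypothetical and not part of the code below. With a two-way split into `< pivot` and `>= pivot`, an input whose elements all equal the pivot lands entirely in the upper part, so recursing into that part sees the same input again. A three-way split never puts the pivot into the lower or upper part, so whichever of those the algorithm recurses into is strictly smaller than its input.

```haskell
-- | Hypothetical two-way split: (< pivot, >= pivot).
splitTwo :: Ord a => a -> [a] -> ([a], [a])
splitTwo p xs = (filter (< p) xs, filter (>= p) xs)

-- | Hypothetical three-way split: (< pivot, == pivot, > pivot),
-- mirroring the `split` helper inside quickPartition.
splitThree :: Ord a => a -> [a] -> ([a], [a], [a])
splitThree p xs = (filter (< p) xs, filter (== p) xs, filter (> p) xs)
```

`splitTwo 7 [7, 7, 7]` is `([], [7, 7, 7])`, so the upper part is as large as the input, while `splitThree 7 [7, 7, 7]` is `([], [7, 7, 7], [])`.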
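The basic algorithm is:
1. Set keep to empty.
2. Choose a pivot (the middle element).
3. Partition the array into the elements less than, equal to, and greater than the pivot.
4. If the lower part has at most k elements and the lower and equal parts together have at least k, return keep, the lower part, and enough of the equal part to make k. If the lower and equal parts together have fewer than k elements, add both to keep and repeat on the upper part, needing that many fewer elements. Otherwise, repeat on the lower part.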
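The steps above can be sketched on plain lists. This is an illustrative version (quickPartitionList is a hypothetical name) that uses the middle element as the pivot like quickPartition does, but filters the list three times per round, so it is simpler and does more work per round than the Data.Sequence version.

```haskell
-- | Hypothetical list-based sketch of the keep / pivot / partition loop.
-- Returns the k lowest elements in no particular order (all of them if
-- k exceeds the length).
quickPartitionList :: Ord a => Int -> [a] -> [a]
quickPartitionList k0 xs0 = go k0 xs0 []
  where
    go k xs keep
      | k <= 0 || null xs = keep
      -- The answer ends inside the equal part: take just enough of it.
      | count <= k && k <= count + midCount = keep ++ lows ++ take (k - count) mids
      -- Lower and equal parts are all needed: keep them, look at the upper part.
      | count < k = go (k - count - midCount) highs (keep ++ lows ++ mids)
      -- Too many lower elements: the answer lies entirely in the lower part.
      | otherwise = go k lows keep
      where
        pivot = xs !! (length xs `div` 2)
        lows = filter (< pivot) xs
        mids = filter (== pivot) xs
        highs = filter (> pivot) xs
        count = length lows
        midCount = length mids
```

For example, `quickPartitionList 3 [9, 3, 7, 1, 8, 2, 6, 3]` returns `[1, 2, 3]`.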
quickPartition ::
     forall a. (Ord a, V.Unbox a)
  => Int
  -> V.Vector a
  -> S.Seq a
quickPartition k as = part k as S.empty
  where
    toVec :: S.Seq a -> V.Vector a
    toVec s = V.fromListN (S.length s) $ toList s

    -- Three-way split around the pivot: (less than, equal to, greater than).
    split :: a -> V.Vector a -> (S.Seq a, S.Seq a, S.Seq a)
    split pivot as = go as S.empty S.empty S.empty
      where
        go ::
             V.Vector a
          -> S.Seq a
          -> S.Seq a
          -> S.Seq a
          -> (S.Seq a, S.Seq a, S.Seq a)
        go as first mid last
          | V.null as = (first, mid, last)
          | otherwise =
            let !h = V.head as
                !t = V.tail as
             in if | h > pivot -> go t first mid (last S.|> h)
                   | h < pivot -> go t (first S.|> h) mid last
                   | h == pivot -> go t first (mid S.|> h) last

    -- keep holds elements already known to be among the k lowest;
    -- k counts how many more are still needed from `as`.
    part 0 as keep = keep
    part k as keep
      | V.null as = keep
      | otherwise =
        let !midIdx = V.length as `div` 2
            !pivot = as V.! midIdx
            (first, mid, last) = split pivot as
            !count = S.length first
            !midCount = S.length mid
         in if | count <= k && count + midCount >= k ->
                 keep <> first <> S.take (k - count) mid
               | count < k ->
                 part (k - count - midCount) (toVec last) (keep <> first <> mid)
               | count > k -> part k (toVec first) keep
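A simple way to test either implementation is to compare its result, as a multiset, against sorting. The helper below is a hypothetical sketch: it accepts an answer if sorting it gives the first k elements of the sorted input, so the order in which quickPartition returns elements does not matter.

```haskell
import Data.List (sort)

-- | Hypothetical test helper: does `candidate` hold exactly the k lowest
-- elements of `xs`, in any order? If k exceeds the length of xs, the
-- candidate must contain every element.
isKLowest :: Ord a => Int -> [a] -> [a] -> Bool
isKLowest k xs candidate = sort candidate == take k (sort xs)
```

With the code above in scope, `isKLowest 3 (V.toList v) (toList (quickPartition 3 v))` should be `True` for any unboxed vector `v`.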
This version doesn't use Data.Sequence. Instead it uses mutable arrays and does a more standard in-place, quicksort-style partition, recursing only into the part that still contains position k. The Data.Vector library has the nice feature of allowing slices. This makes the partition step easier to write, because we don't have to keep track of low and high indices, or whether they are inclusive. A slice shares storage with the vector it was taken from, so partitioning a slice rearranges the thawed copy of the input in place; at the end we clone the first k elements and freeze the clone.
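The slice behaviour described above can be checked directly. This sketch (writeThroughSlice and writeToClone are hypothetical names; it needs the vector package) writes through a slice of a thawed vector, which changes the parent, and then does the same through a clone, which leaves the parent untouched.

```haskell
import Control.Monad.ST (runST)
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM

-- | Writing through a slice is visible in the parent: VM.slice makes a view
-- into the same storage, not a copy.
writeThroughSlice :: V.Vector Int
writeThroughSlice = runST $ do
  parent <- V.thaw (V.fromList [10, 20, 30, 40, 50])
  let window = VM.slice 1 3 parent -- view of parent indices 1..3
  VM.write window 0 99 -- window index 0 is parent index 1
  V.freeze parent

-- | VM.clone copies the slice, so writes to the clone leave the parent alone.
writeToClone :: V.Vector Int
writeToClone = runST $ do
  parent <- V.thaw (V.fromList [10, 20, 30, 40, 50])
  copy <- VM.clone (VM.slice 1 3 parent)
  VM.write copy 0 99
  V.freeze parent
```

writeThroughSlice evaluates to `[10,99,30,40,50]` and writeToClone to `[10,20,30,40,50]`.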
I will admit that array manipulation and loops are uglier in Haskell.
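quickPartition2 ::
     forall a. (Ord a, V.Unbox a)
  => Int
  -> V.Vector a
  -> V.Vector a
quickPartition2 k as
  | k <= 0 = V.empty -- nothing to select (also avoids slicing with a negative k)
  | V.length as <= k = as -- every element is among the k lowest
  | otherwise = runST $ arrange k as
  where
    -- Copy the input into a mutable vector, rearrange it so the k lowest
    -- elements occupy indices [0, k), then copy those out and freeze them.
    arrange :: (PrimMonad m) => Int -> V.Vector a -> m (V.Vector a)
    arrange k as = do
      asM <- V.thaw as
      arrangeM k asM
      V.unsafeFreeze =<< VM.clone (VM.slice 0 k asM)

    -- Partition around the middle element, then recurse only into the
    -- slice that still contains position k.
    arrangeM :: (PrimMonad m) => Int -> VM.MVector (PrimState m) a -> m ()
    arrangeM 0 _ = return ()
    arrangeM k asM = do
      let !sz = VM.length asM
          !midIdx = sz `div` 2
      pivot <- VM.read asM midIdx
      (minN, maxN) <- partitionStep pivot asM
      if | minN <= k && k <= maxN -> return ()
         | maxN < k -> arrangeM (k - maxN) $ VM.slice maxN (sz - maxN) asM
         | minN > k -> arrangeM k $ VM.slice 0 minN asM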
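    -- Hoare-style partition around `pivot`. Returns (minN, maxN) such that
    -- elements in [0, minN) are <= pivot, elements in [minN, maxN) equal
    -- pivot, and elements in [maxN, length) are >= pivot.
    partitionStep ::
         (PrimMonad m) => a -> VM.MVector (PrimState m) a -> m (Int, Int)
    partitionStep pivot asM
      | VM.length asM == 1 = return (0, 1)
      | otherwise = meetLoop 0 (VM.length asM - 1)
      where
        -- i scans up, j scans down; swap out-of-place pairs until they cross.
        meetLoop i j
          | i > j = return (j + 1, i)
          | otherwise = do
            i' <- findLowLoop i
            j' <- findHighLoop j
            if i' <= j'
              then do
                when (i' < j') $ VM.swap asM i' j'
                meetLoop (i' + 1) (j' - 1)
              else meetLoop i' j'
        -- First index >= i whose element is >= pivot (or the length).
        findLowLoop i
          | i == VM.length asM = return i
          | otherwise = do
            iElement <- VM.read asM i
            if iElement >= pivot
              then return i
              else findLowLoop (i + 1)
        -- Last index <= j whose element is <= pivot. No lower bound check is
        -- needed: the pivot (on the first pass) or a previously swapped-in
        -- element (afterwards) stops the scan.
        findHighLoop j = do
          jElement <- VM.read asM j
          if jElement <= pivot
            then return j
            else findHighLoop (j - 1)

main :: IO ()
main = do
  let xs = V.fromList [9, 3, 7, 1, 8, 2, 6 :: Int]
  -- Both print the three lowest elements, in no particular order.
  print (quickPartition 3 xs)
  print (quickPartition2 3 xs)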
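Reading partitionStep, the pair (minN, maxN) it returns seems to satisfy this contract: elements before minN are at most the pivot, elements in [minN, maxN) equal the pivot, and elements from maxN onwards are at least the pivot. arrangeM relies on exactly that when it decides to stop or which slice to recurse into. The sketch below states the contract and the slice choice as pure list functions (partitionedAround and nextStep are hypothetical names), which may be handy as assertions while debugging the mutable version.

```haskell
-- | Hypothetical checker for the contract partitionStep appears to satisfy:
-- [0, minN) <= p, [minN, maxN) == p, [maxN, length) >= p.
partitionedAround :: Ord a => a -> (Int, Int) -> [a] -> Bool
partitionedAround p (minN, maxN) xs =
  0 <= minN
    && minN <= maxN
    && maxN <= length xs
    && all (<= p) low
    && all (== p) middle
    && all (>= p) high
  where
    (low, rest) = splitAt minN xs
    (middle, high) = splitAt (maxN - minN) rest

-- | Hypothetical model of arrangeM's choice after partitioning a slice of
-- size sz: Nothing means stop (the first k elements are the k lowest);
-- Just (offset, len, k') means recurse into that sub-slice, needing k'.
nextStep :: Int -> Int -> (Int, Int) -> Maybe (Int, Int, Int)
nextStep sz k (minN, maxN)
  | minN <= k && k <= maxN = Nothing
  | maxN < k = Just (maxN, sz - maxN, k - maxN)
  | otherwise = Just (0, minN, k)
```

For example, `[1, 2, 4, 4, 9, 5]` is partitioned around 4 with `(minN, maxN) = (2, 4)`, and `nextStep 6 5 (2, 4)` is `Just (4, 2, 1)`: one more element is needed from the last two.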