QuickPartition, a Variation on QuickSelect

QuickSelect solves the problem of finding the kth smallest element in an unsorted array. QuickPartition is a variation on that problem: instead of the kth smallest element, we want the k smallest elements of an unsorted array, in no particular order.
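
For reference, the simplest correct answer is to sort the input and take the first k elements. The sketch below uses plain lists and `sort` from base. `kSmallestNaive` and `sameElements` are hypothetical names and are not part of the code in this post. It costs O(n log n), which partitioning tries to avoid, but it makes a handy test oracle. QuickPartition returns its k elements in no particular order, so compare results as sorted lists.

```haskell
import Data.List (sort)

-- Hypothetical reference implementation: the k smallest elements,
-- obtained by fully sorting the input (O(n log n)).
kSmallestNaive :: Ord a => Int -> [a] -> [a]
kSmallestNaive k = take k . sort

-- Hypothetical helper: QuickPartition makes no promise about the order of
-- the returned elements, so compare results as sorted multisets.
sameElements :: Ord a => [a] -> [a] -> Bool
sameElements xs ys = sort xs == sort ys
```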

I have written up two different versions in Haskell:

Preamble

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE MultiWayIf          #-}
{-# LANGUAGE OverloadedLists     #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Main where

import           Control.Monad               (when)
import           Control.Monad.Primitive
import           Control.Monad.ST.Strict
import           Data.Foldable
import qualified Data.Sequence               as S
import qualified Data.Vector.Unboxed         as V
import qualified Data.Vector.Unboxed.Mutable as VM

First Version

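This version uses Data.Sequence to partition the input and to accumulate the elements to be returned. If you need a Data.Vector instead, you can convert the resulting Seq (for example, the way toVec does below). This version has the advantage of being simple. It splits the array into three parts (less than, equal to, and greater than the pivot) to guarantee forward progress: the middle part always contains the pivot itself, so every recursive call works on a strictly smaller vector.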
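
A minimal list-based sketch of the three-way split, for illustration only. `split3` is a hypothetical name, and it works on lists rather than the Seq accumulators used in the real code. Suppose we split into only two groups, "less than" and "greater than or equal", and the pivot happens to be the minimum. Then the second group is the whole input, and recursing on it makes no progress. Separating the pivot-equal elements avoids that.

```haskell
-- Hypothetical helper: split a list into the elements less than, equal to,
-- and greater than the pivot, preserving relative order within each group.
split3 :: Ord a => a -> [a] -> ([a], [a], [a])
split3 pivot = foldr step ([], [], [])
  where
    step x (lt, eq, gt) = case compare x pivot of
      LT -> (x : lt, eq, gt)
      EQ -> (lt, x : eq, gt)
      GT -> (lt, eq, x : gt)
```

For example, `split3 4 [7, 2, 9, 4, 4, 1, 8]` gives `([2, 1], [4, 4], [7, 9, 8])`.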

The basic algorithm is

  1. Set keep to empty

  2. Choose a pivot

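  3. Partition the array around the pivot into a lower part (less than the pivot), a middle part (equal to the pivot) and an upper part (greater than the pivot), then look at the size of the lower part

    1. Too many -> rerun on just the lower part
    2. Too few -> add the lower and middle parts to keep, then run on the upper part for the remaining count
    3. Just right -> add the lower part, plus as many middle elements as are still needed, to keep, and return
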
quickPartition ::
     forall a. (Ord a, V.Unbox a)
  => Int
  -> V.Vector a
  -> S.Seq a
quickPartition k as = part k as S.empty
  where
    toVec :: S.Seq a -> V.Vector a
    toVec s = V.fromListN (S.length s) $ toList s
    -- Three-way split: elements less than, equal to, and greater than pivot.
    split :: a -> V.Vector a -> (S.Seq a, S.Seq a, S.Seq a)
    split pivot as = go as S.empty S.empty S.empty
      where
        go ::
             V.Vector a
          -> S.Seq a
          -> S.Seq a
          -> S.Seq a
          -> (S.Seq a, S.Seq a, S.Seq a)
        go as first mid last
          | V.null as = (first, mid, last)
          | otherwise =
            let !h = V.head as
                !t = V.tail as
             in if | h > pivot -> go t first mid (last S.|> h)
                   | h < pivot -> go t (first S.|> h) mid last
                   | otherwise -> go t first (mid S.|> h) last -- h == pivot
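    -- part k as keep: append the k smallest elements of `as` to `keep`.
    part 0 as keep = keep
    part k as keep
      | V.null as = keep
      | otherwise =
        let !midIdx = V.length as `div` 2
            !pivot = as V.! midIdx
            (first, mid, last) = split pivot as
            !count = S.length first
            !midCount = S.length mid
         in if | count <= k && count + midCount >= k ->
                 -- Just right: all of `first`, plus enough pivot copies.
                 keep <> first <> S.take (k - count) mid
               | count < k ->
                 -- Too few: keep `first` and `mid`, continue in `last`.
                 part (k - count - midCount) (toVec last) (keep <> first <> mid)
               | otherwise ->
                 -- Too many: the answer lies entirely within `first`.
                 part k (toVec first) keep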
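
The steps listed above can also be sketched over plain lists. This drops the Seq/Vector conversions and makes the three cases easier to see side by side. It is an illustrative sketch, not a replacement for the version above. `quickPartitionList` is a hypothetical name, indexing with `!!` is O(n), and `++` on the accumulator is less efficient than `<>` on a Seq.

```haskell
-- Hypothetical list-based sketch of quickPartition: returns the k smallest
-- elements of the input, in no particular order. If k exceeds the length,
-- every element is returned.
quickPartitionList :: Ord a => Int -> [a] -> [a]
quickPartitionList k0 xs0 = go k0 xs0 []
  where
    go k xs keep
      | k <= 0 || null xs = keep
      | count <= k && k <= count + midCount =
          -- Just right: everything below the pivot plus enough pivot copies.
          keep ++ lower ++ take (k - count) mid
      | count < k =
          -- Too few: keep the lower and middle parts, search the upper one.
          go (k - count - midCount) upper (keep ++ lower ++ mid)
      | otherwise =
          -- Too many: the answer lies entirely within the lower part.
          go k lower keep
      where
        pivot = xs !! (length xs `div` 2)  -- middle element, as above
        lower = filter (< pivot) xs
        mid = filter (== pivot) xs
        upper = filter (> pivot) xs
        count = length lower
        midCount = length mid
```

For example, `quickPartitionList 3 [7, 2, 9, 4, 4, 1, 8]` picks the pivot `4`, finds two smaller elements and two copies of the pivot, and returns `[2, 1, 4]` without recursing.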

Second Version

This version doesn’t use Data.Sequence. Instead, it uses mutable arrays and an in-place, quicksort-style (Hoare) partition, recursing only into the side that still contains the boundary at position k. The Data.Vector library has the nice feature of allowing slices. This makes the partition step easier to write, because we don’t have to keep track of low and high indices, or whether they are inclusive, etc. A slice shares storage with the array it was taken from, so partitioning a slice rearranges the underlying array in place. At the end we clone the relevant prefix, so the result does not keep the whole array alive, and then freeze it.

I will admit that array manipulation and loops are uglier in Haskell.
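
To see what the slices save us, here is a sketch of the same partition step written against `Data.Array.ST` instead of `Data.Vector`. Here the sub-range must be passed around as explicit inclusive bounds `lo` and `hi`. The swap to the `array` package (which ships with GHC) is an assumption made so the sketch needs only GHC's bundled libraries. `partitionRange` and `partitionList` are hypothetical names. The return convention mirrors `partitionStep` below, but with absolute indices instead of slice-relative ones.

```haskell
import Control.Monad (when)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, getElems, newListArray, readArray, writeArray)

-- Hypothetical sketch: Hoare-style partition of the inclusive range [lo, hi]
-- (requires lo <= hi) around the value at its midpoint. Returns (minN, maxN):
--   [lo, minN)   holds elements <= pivot
--   [minN, maxN) holds elements == pivot
--   [maxN, hi]   holds elements >= pivot
partitionRange :: STUArray s Int Int -> Int -> Int -> ST s (Int, Int)
partitionRange arr lo hi = do
  pivot <- readArray arr ((lo + hi) `div` 2)
  let findLow i = do  -- scan right for an element >= pivot
        x <- readArray arr i
        if x >= pivot then pure i else findLow (i + 1)
      findHigh j = do  -- scan left for an element <= pivot
        x <- readArray arr j
        if x <= pivot then pure j else findHigh (j - 1)
      loop i j
        | i > j = pure (j + 1, i)  -- indices have crossed: done
        | otherwise = do
            i' <- findLow i
            j' <- findHigh j
            if i' <= j'
              then do
                when (i' < j') (swap i' j')
                loop (i' + 1) (j' - 1)
              else pure (j' + 1, i')
  loop lo hi
  where
    swap i j = do
      x <- readArray arr i
      y <- readArray arr j
      writeArray arr i y
      writeArray arr j x

-- Hypothetical wrapper: partition a whole list, returning the rearranged
-- elements together with (minN, maxN).
partitionList :: [Int] -> ([Int], (Int, Int))
partitionList [] = ([], (0, 0))
partitionList xs = runST $ do
  let n = length xs
  arr <- newListArray (0, n - 1) xs
  result <- partitionRange arr 0 (n - 1)
  ys <- getElems arr
  pure (ys, result)
```

Compare this with `partitionStep` below. With slices, every call sees indices starting at 0 and uses the slice length as its bound, so `lo`, `hi` and the question of inclusivity disappear from the code.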

quickPartition2 ::
     forall a. (Ord a, V.Unbox a)
  => Int
  -> V.Vector a
  -> V.Vector a
quickPartition2 k as =
  if V.length as <= k
    then as
    else runST $ arrange k as
  where
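    -- Copy the input into a mutable array, move the k smallest elements to
    -- the front, then copy out just that prefix.
    arrange :: (PrimMonad m) => Int -> V.Vector a -> m (V.Vector a)
    arrange k as = do
      asM <- V.thaw as
      arrangeM k asM
      V.unsafeFreeze =<< (VM.clone $ VM.slice 0 k asM)
    -- Rearrange asM in place so that its first k positions hold its k
    -- smallest elements.
    arrangeM :: (PrimMonad m) => Int -> VM.MVector (PrimState m) a -> m ()
    arrangeM 0 asM = return ()
    arrangeM k asM = do
      let !sz = VM.length asM
          !midIdx = sz `div` 2
      pivot <- VM.read asM midIdx
      (minN, maxN) <- partitionStep pivot asM
      if | minN <= k && k <= maxN ->
           -- The first k positions already hold the k smallest: done.
           return ()
         | maxN < k ->
           -- Boundary is in the upper part: skip the first maxN elements.
           arrangeM (k - maxN) $ VM.slice maxN (sz - maxN) asM
         | otherwise ->
           -- Boundary is in the lower part: recurse on the first minN elements.
           arrangeM k $ VM.slice 0 minN asM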
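    -- Hoare-style partition of asM around pivot. Returns (minN, maxN) with
    -- positions [0, minN) holding elements <= pivot, [minN, maxN) holding
    -- elements == pivot, and [maxN, length) holding elements >= pivot.
    partitionStep ::
         (PrimMonad m) => a -> VM.MVector (PrimState m) a -> m (Int, Int)
    partitionStep pivot asM
      | VM.length asM == 1 = return (0, 1)
      | otherwise = meetLoop 0 (VM.length asM - 1)
      where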
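        -- Move i right and j left, swapping out-of-place pairs, until they cross.
        meetLoop i j
          | i > j = return (j + 1, i)
          | otherwise = do
            i' <- findLowLoop i
            j' <- findHighLoop j
            if i' <= j'
              then do
                when (i' < j') $ VM.swap asM i' j'
                meetLoop (i' + 1) (j' - 1)
              else meetLoop i' j'
        -- Scan right for an element >= pivot.
        findLowLoop i
          | i == VM.length asM = return i
          | otherwise = do
            iElement <- VM.read asM i
            if iElement >= pivot
              then return i
              else findLowLoop (i + 1)
        -- Scan left for an element <= pivot. No lower bound check is needed:
        -- the first scan stops at the pivot itself at the latest, and later
        -- scans stop at the element swapped into place in the previous round.
        findHighLoop j = do
          jElement <- VM.read asM j
          if jElement <= pivot
            then return j
            else findHighLoop (j - 1)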
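
The mutable version establishes positional invariants rather than building a separate collection. After `partitionStep`, the array is split into a "<= pivot", an "== pivot" and a ">= pivot" region. After `arrangeM k`, no element in the first k positions is larger than any element after them. Below is a minimal sketch of checkers for both invariants over lists, which can be handy when testing partition code like this. `isArrangedAt` and `isPartitionedAround` are hypothetical names. The sketch assumes you first read the mutable array out, for example with `V.toList` on a frozen copy.

```haskell
-- Hypothetical checker for the postcondition of arrangeM: no element among
-- the first k positions is larger than any element after them.
isArrangedAt :: Ord a => Int -> [a] -> Bool
isArrangedAt k ys = null before || null after || maximum before <= minimum after
  where
    (before, after) = splitAt k ys

-- Hypothetical checker for the (minN, maxN) contract of partitionStep:
-- positions [0, minN) are <= pivot, [minN, maxN) are == pivot,
-- and [maxN, n) are >= pivot.
isPartitionedAround :: Ord a => a -> (Int, Int) -> [a] -> Bool
isPartitionedAround pivot (minN, maxN) ys =
  all (<= pivot) low && all (== pivot) mid && all (>= pivot) high
  where
    (low, rest) = splitAt minN ys
    (mid, high) = splitAt (maxN - minN) rest
```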