--------------------------------------------------------------------------------
-- |
-- Module      :  HGeometry.StringSearch.KMP
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--
-- Implementation of the Knuth-Morris-Pratt string-searching
-- algorithm. The exposition is based on that of Goodrich and
-- Tamassia in "Data Structures and Algorithms in Java", 2nd edition.
--
--------------------------------------------------------------------------------
module HGeometry.StringSearch.KMP
  ( isSubStringOf
  , kmpMatch
  , buildFailureFunction
  ) where

import           Control.Monad.ST
import qualified Data.Vector as V
import           Data.Vector.Generic ((!))
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Unboxed.Mutable as UMV
import qualified VectorBuilder.Builder as Builder
import qualified VectorBuilder.Vector as Builder


--------------------------------------------------------------------------------

-- | Constructs the Knuth-Morris-Pratt failure function of the pattern @p@.
--
-- The result @f@ has the same length as @p@. The entry @f ! i@ is the length
-- of the longest proper prefix of @p[0..i]@ that is also a suffix of
-- @p[0..i]@; in particular @f ! 0 == 0@. The empty pattern yields the empty
-- vector (previously this indexed out of bounds).
--
-- running time: \(O(m)\), where \(m\) is the length of @p@.
buildFailureFunction :: forall a. Eq a => V.Vector a -> UV.Vector Int
buildFailureFunction p = UV.create $ do
    -- Every entry starts at 0. This establishes f ! 0 == 0 explicitly,
    -- since the loop below only fills in positions i >= 1.
    f <- UMV.replicate m 0
    go f 1 0
  where
    m = V.length p

    -- go f i j: i is the pattern position being filled in, and j is the
    -- length of the prefix currently known to match the text ending at i - 1.
    go :: UMV.MVector s Int -> Int -> Int -> ST s (UMV.MVector s Int)
    go f i j
      -- Use '>=' rather than '==': for an empty pattern we start at
      -- i = 1 > m = 0 and must stop immediately.
      | i >= m         = pure f
      | p ! j == p ! i = UMV.write f i (j + 1) >>  go f (i + 1) (j + 1)
      -- Mismatch after a partial match: fall back to the next shorter border.
      | j > 0          = UMV.read f (j - 1)    >>= go f i
      | otherwise      = UMV.write f i 0       >>  go f (i + 1) 0

-- | Searches for the pattern @p@ (first argument) as a contiguous run of
-- elements inside @t@ (second argument), delegating to 'kmpMatch'.
--
-- Both containers are first materialised as boxed vectors, in their
-- 'Foldable' order. The result is @Just k@, where @k@ is the position in @t@
-- at which the first occurrence of @p@ starts, or 'Nothing' if @p@ does not
-- occur in @t@.
--
-- running time: \(O(n+m)\), where p has length \(m\) and t has length \(n\).
isSubStringOf :: (Eq a, Foldable p, Foldable t) => p a -> t a -> Maybe Int
isSubStringOf pat txt = kmpMatch (toVector pat) (toVector txt)
  where
    -- Convert any Foldable into a boxed vector, preserving element order.
    toVector :: Foldable f => f b -> V.Vector b
    toVector = Builder.build . Builder.foldable


-- | Finds the first occurrence of the pattern @p@ (first argument) in the
-- text @t@ (second argument), using the Knuth-Morris-Pratt algorithm.
--
-- The result is @Just k@, where @k@ is the index in @t@ at which the leftmost
-- occurrence of @p@ begins, or 'Nothing' when @p@ does not occur in @t@.
-- The empty pattern matches at index 0.
--
-- running time: \(O(n+m)\), where p has length \(m\) and t has length \(n\).
kmpMatch :: Eq a => V.Vector a -> V.Vector a -> Maybe Int
kmpMatch p t
  | V.null p  = Just 0
  | otherwise = scan 0 0
  where
    patLen, textLen :: Int
    patLen  = V.length p
    textLen = V.length t

    -- Failure table of the pattern (evaluated lazily).
    failure :: UV.Vector Int
    failure = buildFailureFunction p

    -- scan ti pj: ti indexes the text, and pj is the number of pattern
    -- elements matched so far (equivalently, the next pattern index to compare).
    scan :: Int -> Int -> Maybe Int
    scan ti pj
      | ti == textLen    = Nothing
      | p ! pj == t ! ti =
          if pj == patLen - 1
            then Just (ti - patLen + 1)   -- the whole pattern has matched
            else scan (ti + 1) (pj + 1)
      -- pj is never negative, so this is the "no partial match" case.
      | pj == 0          = scan (ti + 1) 0
      -- Keep the text position and resume from the longest matched border.
      | otherwise        = scan ti (failure ! (pj - 1))