{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  HGeometry.Matrix.ByRows
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--
-- Type-indexed matrices, stored as a vector of rows.
--
--------------------------------------------------------------------------------
module HGeometry.Matrix.ByRows(
    Matrix(Matrix)
  , OptMatrix_
  ) where

import           Control.Lens
-- import           GHC.TypeNats
import           HGeometry.Matrix.Class
import           HGeometry.Vector
import           HGeometry.Properties
-- import           Linear.Matrix (M22, M33, M44)
-- import qualified Linear.Matrix as Lin

--------------------------------------------------------------------------------
-- * Matrices

-- | A type-indexed matrix with @n@ rows and @m@ columns whose entries have
-- type @r@. It is represented row by row: the outer vector holds the @n@
-- rows, and each row is a vector of @m@ entries.
newtype Matrix n m r
  = Matrix (Vector n (Vector m r))


-- transpose :: Matrix n m r -> Matrix m n r
-- transpose = undefined


-- Lens-style indexing: an entry is addressed by a (row, column) pair of
-- 'Int's and holds a value of the entry type @r@.
type instance Index   (Matrix n m r) = (Int, Int)
type instance IxValue (Matrix n m r) = r
-- The numeric type of a matrix is the type of its entries.
type instance NumType (Matrix n m r) = r


-- | Isomorphism between a 'Matrix' and the vector of its rows.
--
-- The outer vector has one entry per row (@n@ of them), and each row is a
-- vector of @m@ entries. The entry type may change from @r@ to @s@.
_MatrixVector :: Iso (Matrix n m r)          (Matrix n m s)
                     (Vector n (Vector m r)) (Vector n (Vector m s))
_MatrixVector = iso (\(Matrix v) -> v) Matrix

-- Show is derived structurally: the constructor applied to the vector of
-- rows. Eq and Ord reuse the instances of that underlying vector of rows.
deriving stock instance
  ( Show r
  , Has_ Vector_ m r
  , Has_ Vector_ n (Vector m r)
  ) => Show (Matrix n m r)
deriving newtype instance Eq  (Vector n (Vector m r)) => Eq  (Matrix n m r)
deriving newtype instance Ord (Vector n (Vector m r)) => Ord (Matrix n m r)

-- | Shorthand for the constraints on square @d@-by-@d@ matrices with entries
-- of type @r@. These are the same constraints as the context of the
-- 'Matrix_' instance for @Matrix d d r@.
type OptMatrix_ d r =
  ( Has_ Additive_ d r
  , Has_ Vector_ d (Vector d r)
  , Ixed (Vector d r)
  , Ixed (Vector d (Vector d r))
  )

-- | @ix (i, j)@ focuses on the entry in row @i@, column @j@. It first selects
-- row @i@ of the underlying vector of rows, then entry @j@ within that row.
-- Out-of-range behaviour is inherited from the 'Ixed' instances of the
-- underlying vectors.
instance ( Has_ Vector_ n (Vector m r)
         , Has_ Vector_ m r
         , Ixed (Vector n (Vector m r))
         , Ixed (Vector m r)
         ) => Ixed (Matrix n m r) where
  -- Spaces around '.' keep this plain composition even if
  -- OverloadedRecordDot is enabled.
  ix (i, j) = _MatrixVector . ix i . ix j


-- | Indexed traversal over all entries of the matrix, possibly changing the
-- entry type from @r@ to @s@.
--
-- The outer 'components' traversal runs over the rows and the inner one over
-- the entries of each row. Each entry is indexed by its (row, column) pair.
instance ( Has_ Vector_ n (Vector m r)
         , Has_ Vector_ m r
         , Has_ Vector_ n (Vector m s)
         , Has_ Vector_ m s
         , HasComponents (Vector m r) (Vector m s)
         , HasComponents (Vector n (Vector m r)) (Vector n (Vector m s))
         ) => HasElements (Matrix n m r) (Matrix n m s) where
  -- '.>' keeps only the index of its right-hand side (the isomorphism adds
  -- none); '<.>' pairs the row index with the column index.
  elements = _MatrixVector .> (components <.> components)


-- | Row-major 'Matrix_' instance: the matrix is stored as a vector of rows.
instance ( Has_ Vector_ n (Vector m r)
         , Has_ Additive_ m r
         , Ixed (Vector n (Vector m r))
         , Ixed (Vector m r)
         ) => Matrix_ (Matrix n m r) n m r where

  -- Convert any vector-like container of @n@ rows into the concrete 'Vector'
  -- representation, then wrap it.
  matrixFromRows = Matrix . view _Vector

  -- Entry (i, j) is @f (i, j)@, where @i@ is the row and @j@ the column.
  generateMatrix f = Matrix $ generate mkRow
    where
      -- Row @i@ of the result. The annotation relies on ScopedTypeVariables
      -- to refer to the instance's @m@ and @r@.
      mkRow :: Int -> Vector m r
      mkRow i = generate (\j -> f (i, j))

  -- The rows are exactly the underlying vector of rows.
  rows = _MatrixVector

-- test :: Matrix 2 2 Int
-- test = identityMatrix

--  columns = undefined

-- instance Fractional r => Invertible 2 r where
--   -- >>> inverse' $ Matrix $ Vector2 (Vector2 1 2) (Vector2 3 4.0)
--   -- Matrix Vector2 [Vector2 [-2.0,1.0],Vector2 [1.5,-0.5]]
--   inverse' = withM22 Lin.inv22

-- instance Fractional r => Invertible 3 r where
--   -- >>> inverse' $ Matrix $ Vector3 (Vector3 1 2 4) (Vector3 4 2 2) (Vector3 1 1 1.0)
--   -- Matrix Vector3 [Vector3 [0.0,0.5,-1.0],Vector3 [-0.5,-0.75,3.5],Vector3 [0.5,0.25,-1.5]]
--   inverse' = withM33 Lin.inv33

-- instance Fractional r => Invertible 4 r where
--   inverse' = withM44 Lin.inv44

-- -- instance HasDeterminant 1 where
-- --   det (Matrix (Vector1 (Vector1 x))) = x
-- instance HasDeterminant 2 where
--   det = Lin.det22 . coerce
-- -- instance HasDeterminant 3 where
-- --   det = Lin.det33 . coerce
-- -- instance HasDeterminant 4 where
-- --   det = Lin.det44 . coerce

-- --------------------------------------------------------------------------------
-- -- Boilerplate code for converting between Matrix and M22/M33/M44.

-- withM22   :: (M22 a -> M22 b) -> Matrix 2 2 a -> Matrix 2 2 b
-- withM22 f = coerce . f . coerce

-- withM33 :: (M33 a -> M33 b) -> Matrix 3 3 a -> Matrix 3 3 b
-- withM33 f = coerce . f . coerce

-- withM44 :: (M44 a -> M44 b) -> Matrix 4 4 a -> Matrix 4 4 b
-- withM44 f = coerce . f . coerce