--------------------------------------------------------------------------------
-- |
-- Module      :  HGeometry.Matrix.Class
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--
-- A class of types representing matrices
--
--------------------------------------------------------------------------------
module HGeometry.Matrix.Class
  ( Matrix_(..)
  , HasElements(..)
  , HasDeterminant(..)
  , Invertible(..)
  ) where

import           Control.Lens hiding (cons,snoc,uncons,unsnoc,elements)
import           Data.Kind
import qualified Data.List as List
import           Data.Maybe (fromMaybe)
import           Data.Proxy
import           GHC.TypeNats
import           HGeometry.Properties
import           HGeometry.Vector
-- import           HGeometry.Vector.List (ListVector(..))
import           Prelude hiding (zipWith)

--------------------------------------------------------------------------------
-- $setup
-- >>> import HGeometry.Point
-- >>> import HGeometry.Matrix
-- >>> let printMatrix = mapMOf_ (rows.traverse) print
-- >>> :{
-- let matrixFromList' :: ( Matrix_ matrix n m r
--                        , KnownNat n, KnownNat m
--                        , Has_ Vector_  m r
--                        , Has_ Vector_  n (Vector m r)
--                        ) => [r] -> matrix
--     matrixFromList' = fromMaybe (error "") . matrixFromList
-- :}


-- | Matrix-like types whose individual entries can be traversed (and
-- possibly changed to a different element type, yielding @matrix'@).
class HasElements matrix matrix' where
  -- | Indexed traversal over the elements of the matrix. The index of
  -- an element is its (row,column) pair.
  elements
    :: IndexedTraversal1 (Int,Int) matrix matrix' (NumType matrix) (NumType matrix')

-- | A matrix of n rows, each of m columns, storing values of type r.
--
-- Individual elements are addressed through 'Ixed' using a
-- (row,column) pair of 'Int's.
type Matrix_ :: Type -> Nat -> Nat -> Type -> Constraint
class ( r ~ NumType matrix
      , Ixed matrix
      , IxValue matrix ~ r
      , Index matrix ~ (Int,Int) -- ^ row, col
      , HasElements matrix matrix
      ) => Matrix_ matrix n m r | matrix -> n
                                , matrix -> m
                                , matrix -> r where
  {-# MINIMAL generateMatrix, matrixFromRows, rows #-}

  -- | Produces the Identity Matrix.
  --
  -- The entry at index (i,j) is 1 when i == j and 0 otherwise.
  --
  -- >>> printMatrix $ identityMatrix @(Matrix 3 3 Int)
  -- Vector3 1 0 0
  -- Vector3 0 1 0
  -- Vector3 0 0 1
  identityMatrix :: Num r => matrix
  identityMatrix = generateMatrix $ \(i,j) -> if i == j then 1 else 0

  -- | Given a function that specifies the value at each (row,column)
  -- index, generate the matrix.
  generateMatrix :: ((Int,Int) -> r) -> matrix

  -- | Given a list of the elements in the matrix, in row by row
  -- order, constructs the matrix.
  -- requires that there are exactly n*m elements; returns 'Nothing'
  -- otherwise.
  --
  -- >>> matrixFromList @(Matrix 2 3 Int) [1,2,3,4,5,6]
  -- Just (Matrix (Vector2 (Vector3 1 2 3) (Vector3 4 5 6)))
  -- >>> matrixFromList @(Matrix 2 3 Int) [1,2,3,4,5,6,7]
  -- Nothing
  matrixFromList    :: ( KnownNat n, KnownNat m
                       , Has_ Vector_  m r
                       , Has_ Vector_  n (Vector m r)
                       ) => [r] -> Maybe matrix
  matrixFromList xs = do rs  <- go n xs
                         rs' <- vectorFromList @(Vector n (Vector m r)) rs
                         pure $ matrixFromRows rs'
    where
      -- the number of columns, and the number of rows, as values.
      m = fromIntegral $ natVal $ Proxy @m
      n = fromIntegral $ natVal $ Proxy @n

      -- go k ys cuts ys into exactly k rows of m elements each. It
      -- fails when elements are left over after the last row, or when
      -- some row cannot be built from its chunk.
      go 0  [] = Just []
      go 0  _  = Nothing
      go n' ys = let (r,rest) = List.splitAt m ys
                 in do r'    <- vectorFromList @(Vector m r) r
                       rest' <- go (n'-1) rest
                       pure (r':rest')

  -- | Given a vector containing the n rows of the matrix (each row
  -- being a vector of m elements), constructs the matrix.
  --
  -- >>> matrixFromRows @(Matrix 2 3 Int) (Vector2 (Vector3 1 2 3) (Vector3 4 5 6))
  -- Matrix (Vector2 (Vector3 1 2 3) (Vector3 4 5 6))
  matrixFromRows :: ( Vector_ rowVector n (Vector m r)
                    -- , OptVector_ m r
                    ) => rowVector -> matrix

  -- | Matrix multiplication
  --
  -- The entry at index (i,j) of the result is the dot product of row
  -- i of the left matrix with column j of the right matrix.
  --
  -- >>> let m1  = matrixFromRows @(Matrix 2 3 Int) (Vector2 (Vector3 1 2 3) (Vector3 4 5 6))
  -- >>> let r i = Vector4 i (i*10) (i*100) (i*1000)
  -- >>> let m2  = matrixFromRows @(Matrix 3 4 Int) (Vector3 (r 1) (r 2) (r 3))
  -- >>> printMatrix $ (m1 !*! m2 :: Matrix 2 4 Int)
  -- Vector4 14 140 1400 14000
  -- Vector4 32 320 3200 32000
  (!*!)     :: ( Matrix_ matrix'  m m' r
               , Matrix_ matrix'' n m' r
               , Num r
               -- , OptVector_ m r, KnownNat m
               -- , OptVector_ n r, KnownNat n
               , Has_ Additive_ m r -- (Vector m r)
               ) => matrix -> matrix' -> matrix''
  ma !*! mb = generateMatrix $ \(i,j) -> row' i ma `dot'` column' j mb
    where
      -- 'row' and 'column' return Nothing for indices that are out of
      -- range; the error branches are reached only if 'generateMatrix'
      -- asks for such an index.
      row' i    = fromMaybe (error "absurd: row i out of range") . row i
      column' j = fromMaybe (error "absurd: column j out of range") . column j

      -- dot product: sum of the component-wise products.
      dot' u v = sumOf components $ liftI2 (*) u v
  {-# INLINE (!*!) #-}

  -- | Multiply a matrix and a vector.
  --
  -- Component i of the result is the dot product of row i of the
  -- matrix with the given vector.
  --
  -- >>> let m = matrixFromRows @(Matrix 2 3 Int) (Vector2 (Vector3 1 2 3) (Vector3 4 5 6))
  -- >>> m !* Vector3 2 3 1
  -- Vector2 11 29
  (!*)  :: ( Has_ Vector_ n (Vector m r)
           , HasComponents (Vector n (Vector m r)) (Vector n r)
           -- , Additive_
           -- , OptVector_ n r, OptVector_ m r
           , Has_ Additive_ m r
           , Num r
           ) => matrix -> Vector m r -> Vector n r
  m !* v = rows'&components %~ dotWithV
    where
      rows'  :: Vector n (Vector m r)
      rows'  = m^.rows
      -- dot product of a row with the argument vector v.
      dotWithV   :: Vector m r -> r
      dotWithV u = sumOf components $ liftI2 (*) u (v^._Vector)
  {-# INLINE (!*) #-}

  -- | Multiply a scalar and a matrix
  --
  -- Every element x of the matrix is replaced by @x * s@.
  (*!!)   :: Num r => r -> matrix -> matrix
  s *!! m = m&elements *~ s
  {-# INLINE (*!!) #-}

  -- | Multiply a matrix and a scalar
  --
  -- Every element x of the matrix is replaced by @x * s@.
  (!!*)   :: Num r => matrix -> r -> matrix
  m !!* s = m&elements *~ s
  {-# INLINE (!!*) #-}

  -- | Lens to access all rows
  rows :: Lens' matrix (Vector n (Vector m r))
  -- rows :: IndexedTraversal' Int matrix (Vector m r)

--  -- | Traversal over all columns
--  columns :: IndexedTraversal' Int matrix (Vector n r)



  -- | Access the i^th row in the matrix.
  --
  -- Returns 'Nothing' if any of the indices (i,j) queried for the row
  -- is not present in the matrix.
  row     :: Has_ Vector_ m r => Int -> matrix -> Maybe (Vector m r)
  row i m = generateA $ \j -> m ^? ix (i,j)
  {-# INLINE row #-}

  -- row'   :: (OptVector_ m r, KnownNat m) => Int -> Traversal' matrix (Vector m r)
  -- row' i = traversal go
  --   where
  --     go focus m =
    -- generate $ \j -> m ^?! ix (i,j)

  -- | Access the j^th column in the matrix.
  --
  -- Returns 'Nothing' if any of the indices (i,j) queried for the
  -- column is not present in the matrix.
  column     :: Has_ Vector_ n r => Int -> matrix -> Maybe (Vector n r)
  column j m = generateA $ \i -> m ^? ix (i,j)
  {-# INLINE column #-}


-- Fixity of the multiplication operators: all four are left-associative
-- with precedence 7.
infixl 7 !*!
infixl 7 !*
infixl 7 *!!
infixl 7 !!*

-- class Matrix_ matrix n m r => ConstructableMatrix_ matrix n m r where
--   fromList


--------------------------------------------------------------------------------
-- * Determinants

-- | Dimensions for which we can compute the determinant of a matrix
class HasDeterminant d where
  -- | Computes the determinant of a square \(d \times d\) matrix.
  det
    :: (Num r, Matrix_ matrix d d r)
    => matrix -> r

-- | The determinant of a 1x1 matrix is its only element.
instance HasDeterminant 1 where
  -- the element at index (0,0); '^?!' raises an error should that
  -- index be missing.
  det m = m^?!ix (0,0)
  {-# INLINE det #-}

-- | Determinant of a 2x2 matrix.
instance HasDeterminant 2 where
  -- lists the elements (taken to be in row by row order) and hands
  -- them to 'det22'. Any other number of elements is reported as an
  -- error.
  det m = case m^..elements of
            [ a, b,
              c, d] -> det22 a b c d -- a*d - b*c
            _       -> error "det: 2x2, absurd"
  {-# INLINE det #-}

-- | Determinant of the 2x2 matrix [[a,b], [c,d]], i.e. @a*d - b*c@.
det22         :: Num r => r -> r -> r -> r -> r
det22 a b c d = a*d - b*c

-- | Determinant of a 3x3 matrix.
instance HasDeterminant 3 where
  -- lists the elements (taken to be in row by row order) and adds the
  -- three positive and three negative triple products. Any other
  -- number of elements is reported as an error.
  det m = case m^..elements of
            [ a, b, c,
              d, e, f,
              g, h, i] -> a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
            _          -> error "det: 3x3, absurd"
  {-# INLINE det #-}

-- | Determinant of a 4x4 matrix.
instance HasDeterminant 4 where
  -- lists the elements (taken to be in row by row order); i<r><c>
  -- names the element in row r and column c. The s-values are the 2x2
  -- minors of the first two rows, the c-values those of the last two
  -- rows. Any other number of elements is reported as an error.
  det m = case m^..elements of
    [ i00, i01, i02, i03,
      i10, i11, i12, i13,
      i20, i21, i22, i23,
      i30, i31, i32, i33 ] -> let s0 = i00 * i11 - i10 * i01
                                  s1 = i00 * i12 - i10 * i02
                                  s2 = i00 * i13 - i10 * i03
                                  s3 = i01 * i12 - i11 * i02
                                  s4 = i01 * i13 - i11 * i03
                                  s5 = i02 * i13 - i12 * i03

                                  c5 = i22 * i33 - i32 * i23
                                  c4 = i21 * i33 - i31 * i23
                                  c3 = i21 * i32 - i31 * i22
                                  c2 = i20 * i33 - i30 * i23
                                  c1 = i20 * i32 - i30 * i22
                                  c0 = i20 * i31 - i30 * i21
                              in s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
     -- adapted from the implementation in the Linear package.
    _ -> error "det: 4x4 absurd"
  {-# INLINE det #-}
  -- TODO: verify that GHC unrolls the list
  -- TODO verify that GHC specializes this for the most relevant types

--------------------------------------------------------------------------------
-- * Invertible matrices


-- | Dimensions n for which square \(n \times n\) matrices can be inverted.
class Invertible n where
  -- | Given an invertible square \(n \times n\) matrix A, computes
  -- the \(n \times n\) matrix B such that A !*! B = identityMatrix
  --
  inverseMatrix
    :: (Fractional r, Matrix_ matrix n n r, Has_ Vector_ n r)
    => matrix -> matrix

-- | Inverse of a 1x1 matrix: the reciprocal of its element.
instance Invertible 1 where
  inverseMatrix m = m&elements %~ (\x -> (1/x))
  -- slightly weird way of writing this, since there is only one element, but whatever
  {-# INLINE inverseMatrix #-}

-- | Inverse of a 2x2 matrix: the adjugate matrix scaled by one over
-- the determinant.
instance Invertible 2 where
  -- >>> printMatrix $ inverseMatrix $ matrixFromList' @(Matrix 2 2 Double) [1,2, 3,4]
  -- Vector2 (-2.0) 1.0
  -- Vector2 1.5 (-0.5)
  --
  -- The elements are taken to be listed in row by row order. The
  -- determinant is not checked for being zero before dividing by it.
  inverseMatrix m = case m^..elements of
                      [a,b,
                       c,d] -> let s = 1 / det m
                               in s *!! (matrixFromRows $ Vector2 -- @(ListVector 2 _)
                                  (Vector2 d          (negate b))
                                  (Vector2 (negate c) a))
                      _     -> error "inverseMatrix 2x2: absurd"
  {-# INLINE inverseMatrix #-}
  -- it is a bit silly we are using the list vectors here.

-- | Inverse of a 3x3 matrix: the adjugate matrix scaled by one over
-- the determinant.
instance Invertible 3 where
  -- >>> printMatrix $ inverseMatrix $ matrixFromList' @(Matrix 3 3 Double) [1,2,4,     4,2,2,    1,1,1]
  -- Vector3 0.0 0.5 (-1.0)
  -- Vector3 (-0.5) (-0.75) 3.5
  -- Vector3 0.5 0.25 (-1.5)
  --
  -- The elements are taken to be listed in row by row order. The
  -- determinant is not checked for being zero before dividing by it.
  inverseMatrix m = case m^..elements of
      [ a, b, c,
        d, e, f,
        g, h, i] -> let lambda  = 1 / det m
                        -- the entries of the adjugate matrix, row by
                        -- row; the order of the arguments to det22
                        -- accounts for the signs.
                        aa = det22 e f h i
                        bb = det22 c b i h
                        cc = det22 b c e f
                        dd = det22 f d i g
                        ee = det22 a c g i
                        ff = det22 c a f d
                        gg = det22 d e g h
                        hh = det22 b a h g
                        ii = det22 a b d e
                    in lambda *!! (matrixFromRows $ Vector3 -- _ @(ListVector 3 _)
                                     (Vector3 aa bb cc)
                                     (Vector3 dd ee ff)
                                     (Vector3 gg hh ii))
      _          -> error "inverseMatrix 3x3: absurd"
  {-# INLINE inverseMatrix #-}

--------------------------------------------------------------------------------