from __future__ import print_function, division
from sympy.core import Mul, Basic, sympify, Add
from sympy.functions import transpose, adjoint
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, typed, debug, flatten, exhaust,
do_one, new)
from sympy.matrices.expressions.matexpr import (MatrixExpr, ShapeError,
Identity, ZeroMatrix)
class MatMul(MatrixExpr):
    """
    A product of matrix expressions

    The arguments may mix matrix expressions with scalar factors; the
    scalars are collected into a single coefficient by
    ``as_coeff_matrices``.  Unless ``check=False`` is passed, the shapes of
    neighbouring matrix arguments are validated on construction.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    def __new__(cls, *args, **kwargs):
        """Build the product from sympified ``args``.

        The keyword ``check`` (default ``True``) controls whether the
        matrix arguments are passed to ``validate``.
        """
        check = kwargs.get('check', True)

        args = list(map(sympify, args))
        obj = Basic.__new__(cls, *args)
        if check:
            # Only split the arguments when the matrices are actually
            # going to be validated; the scalar coefficient is not needed.
            _, matrices = obj.as_coeff_matrices()
            validate(*matrices)
        return obj

    @property
    def shape(self):
        """(rows of the first matrix argument, cols of the last one)."""
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j, expand=True):
        """Return the ``(i, j)`` entry of the product.

        The product is split into its first matrix ``X`` and the product
        ``Y`` of the remaining matrices.  If either contains an
        ``ImmutableMatrix`` the entry is written out as an explicit sum over
        ``range(X.cols)``; otherwise a symbolic ``Sum`` over a dummy index is
        built, and evaluated with ``doit()`` when ``expand`` is true.
        """
        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        # With the single-matrix case handled above, ``tail`` is non-empty
        # whenever ``matrices[0]`` exists.
        head, tail = matrices[0], matrices[1:]

        X = head
        Y = MatMul(*tail)

        # Imported locally, as in the rest of this class.
        from sympy.core.symbol import Dummy
        from sympy.concrete.summations import Sum
        from sympy.matrices import ImmutableMatrix

        if X.has(ImmutableMatrix) or Y.has(ImmutableMatrix):
            return coeff*Add(*[X[i, n]*Y[n, j] for n in range(X.cols)])

        # The dummy summation index is only needed for the symbolic form.
        k = Dummy('k', integer=True)
        result = Sum(coeff*X[i, k]*Y[k, j], (k, 0, X.cols - 1))
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the args into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of all non-matrix args and ``matrices`` is
        the list of matrix args in their original order.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)

        return coeff, matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices`` but with the matrices re-wrapped in a
        ``MatMul``."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def _eval_transpose(self):
        """Transpose every arg, reverse their order, and canonicalize."""
        return MatMul(*[transpose(arg) for arg in self.args[::-1]]).doit()

    def _eval_adjoint(self):
        """Take the adjoint of every arg, reverse their order, and
        canonicalize."""
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        """Pull a non-trivial scalar coefficient out of the trace.

        Raises ``NotImplementedError`` when the coefficient is 1, i.e. when
        there is nothing to pull out.
        """
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import Trace
            return factor * Trace(mmul)
        else:
            raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        """Return ``factor**self.rows`` times the product of the
        ``Determinant`` of each square group found by ``only_squares``."""
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        """Invert each arg and reverse the order of the product.

        Matrix expressions are inverted with ``.inverse()`` and any other
        arg with ``**-1``.  If that raises ``ShapeError`` the unevaluated
        ``Inverse(self)`` is returned instead.
        """
        try:
            return MatMul(*[
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                for arg in self.args[::-1]]).doit()
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def doit(self, **kwargs):
        """Canonicalize the product.

        When the keyword ``deep`` is true (it defaults to ``False``) every
        arg has ``doit(**kwargs)`` called on it first.
        """
        deep = kwargs.get('deep', False)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatMul(*args))
def validate(*matrices):
    """ Checks for valid shapes for args of MatMul

    Every matrix must have as many columns as the matrix following it has
    rows.  Zero or one matrices have no neighbouring pair and therefore
    always pass.

    Raises
    ======

    ShapeError
        If two neighbouring matrices are not aligned.
    """
    # Walk neighbouring pairs directly rather than indexing by position.
    for A, B in zip(matrices, matrices[1:]):
        if A.cols != B.rows:
            raise ShapeError("Matrices %s and %s are not aligned"%(A, B))
# Rules
def newmul(*args):
    """ Build a MatMul from ``args``, dropping a leading factor equal to 1 """
    operands = args[1:] if args[0] == 1 else args
    return new(MatMul, *operands)
def any_zeros(mul):
    """ Replace a product that contains a zero factor with a ZeroMatrix

    A factor counts as zero when its ``is_zero`` is true, or when it is a
    matrix whose ``is_ZeroMatrix`` is true.  The returned ZeroMatrix takes
    its rows from the first matrix arg and its cols from the last one.  If
    no factor is zero, ``mul`` is returned unchanged.
    """
    # A generator lets ``any`` stop at the first zero factor instead of
    # building the full list of flags first.
    if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
           for arg in mul.args):
        matrices = [arg for arg in mul.args if arg.is_Matrix]
        return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
    return mul
def xxinv(mul):
    """ Y * X * X.I -> Y

    Scans neighbouring pairs ``(X, Y)`` of the matrix args.  The first pair
    in which both are square and ``X == Y.inverse()`` is replaced by an
    Identity of size ``X.rows`` and the rebuilt product is returned.  If no
    pair cancels, ``mul`` is returned unchanged.
    """
    factor, matrices = mul.as_coeff_matrices()
    for i, (X, Y) in enumerate(zip(matrices[:-1], matrices[1:])):
        # Only the inverse test is guarded, so a ValueError raised while
        # rebuilding the product is not silently swallowed.
        try:
            cancels = X.is_square and Y.is_square and X == Y.inverse()
        except ValueError:  # Y might not be invertible
            continue
        if cancels:
            I = Identity(X.rows)
            return newmul(factor, *(matrices[:i] + [I] + matrices[i+2:]))

    return mul
def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id.
    The modification is necessary because a MatMul may hold both
    MatrixExprs and plain Exprs as args, so the scalar coefficient is split
    off before the identities are removed and put back afterwards.

    See Also
    --------
        sympy.strategies.rm_id
    """
    # Set the scalar coefficient aside; only the matrix part is filtered.
    coeff, matrix_part = mul.as_coeff_mmul()
    stripped = rm_id(lambda x: x.is_Identity is True)(matrix_part)
    if stripped != matrix_part:
        # Something was removed: rebuild with the coefficient in front.
        return newmul(coeff, *stripped.args)
    return mul
def factor_in_front(mul):
    """ Rebuild the product with its scalar coefficient as the first arg

    A product whose coefficient is 1 is returned unchanged.
    """
    coeff, mats = mul.as_coeff_matrices()
    return newmul(coeff, *mats) if coeff != 1 else mul
# Rewrite rules used to canonicalize a MatMul.  They are passed to do_one in
# exactly this order, so the order is part of the behaviour.
# The inline rm_id(lambda x: x == 1) rule targets args equal to 1.
rules = (any_zeros, remove_ids, xxinv, unpack, rm_id(lambda x: x == 1),
factor_in_front, flatten)
# canonicalize is what MatMul.doit() calls on the product.
# NOTE(review): per sympy.strategies, do_one presumably applies the first
# rule that changes the expression, typed restricts that to MatMul
# instances, and exhaust repeats until nothing changes - confirm against
# the strategies documentation.
canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
def only_squares(*matrices):
    """ factor matrices only if they are square

    Scans the matrices left to right and closes a group whenever the
    current matrix has as many cols as the first matrix of the group has
    rows.  Each group is turned into ``MatMul(*group).doit()`` and the list
    of those products is returned.  The overall product must be square.
    """
    assert matrices[0].rows == matrices[-1].cols
    groups = []
    first = 0  # index of the matrix that opens the current group
    for last, M in enumerate(matrices):
        if M.cols == matrices[first].rows:
            group = matrices[first:last+1]
            groups.append(MatMul(*group).doit())
            first = last + 1
    return groups