@tf_export("linalg.matmul", "matmul")
  Args:
    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
      `complex128` and rank > 1.
    b: `Tensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated and transposed before
      multiplication.
    adjoint_b: If `True`, `b` is conjugated and transposed before
      multiplication.
    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
    name: Name for the operation (optional).
  Returns:
    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
    the product of the corresponding matrices in `a` and `b`, e.g. if all
    transpose or adjoint attributes are `False`:

    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
    for all indices i, j.

    Note: This is matrix product, not element-wise product.
  Raises:
    ValueError: If `transpose_a` and `adjoint_a`, or `transpose_b` and
      `adjoint_b` are both set to `True`.
  """
with ops.name_scope(name, “MatMul”, [a, b]) as name:
if transpose_a and adjoint_a:
raise ValueError(“Only one of transpose_a and adjoint_a can be True.”)
if transpose_b and adjoint_b:
raise ValueError(“Only one of transpose_b and adjoint_b can be True.”)
if context.executing_eagerly():
if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
a = ops.convert_to_tensor(a, name="a")
if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
b = ops.convert_to_tensor(b, name="b")
else:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
# TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access
b_shape = b._shape_tuple() # pylint: disable=protected-access
if fwd_compat.forward_compatible(2019, 4, 25):
output_may_have_non_empty_batch_shape = (
(a_shape is None or len(a_shape) > 2) or
(b_shape is None or len(b_shape) > 2))
batch_mat_mul_fn = gen_math_ops.batch_mat_mul_v2
else:
output_may_have_non_empty_batch_shape = (
(a_shape is None or len(a_shape) > 2) and
(b_shape is None or len(b_shape) > 2))
batch_mat_mul_fn = gen_math_ops.batch_mat_mul
if (not a_is_sparse and
not b_is_sparse) and output_may_have_non_empty_batch_shape:
# BatchMatmul does not support transpose, so we conjugate the matrix and
# use adjoint instead. Conj() is a noop for real matrices.
if transpose_a:
a = conj(a)
adjoint_a = True
if transpose_b:
b = conj(b)
adjoint_b = True
return batch_mat_mul_fn(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
# Neither matmul nor sparse_matmul support adjoint, so we conjugate
# the matrix and use transpose instead. Conj() is a noop for real
# matrices.
if adjoint_a:
a = conj(a)
transpose_a = True
if adjoint_b:
b = conj(b)
transpose_b = True
use_sparse_matmul = False
if a_is_sparse or b_is_sparse:
sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
use_sparse_matmul = (
a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and
a.dtype != b.dtype):
# matmul currently doesn't handle mixed-precision inputs.
use_sparse_matmul = True
if use_sparse_matmul:
ret = sparse_matmul(
a,
b,
transpose_a=transpose_a,
transpose_b=transpose_b,
a_is_sparse=a_is_sparse,
b_is_sparse=b_is_sparse,
name=name)
# sparse_matmul always returns float32, even with
# bfloat16 inputs. This prevents us from configuring bfloat16 training.
# casting to bfloat16 also matches non-sparse matmul behavior better.
if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
ret = cast(ret, dtypes.bfloat16)
return ret
else:
return gen_math_ops.mat_mul(
a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)