tf.matmul()函数用法 (Usage of the tf.matmul() function)

@tf_export("linalg.matmul", "matmul")
@dispatch.add_dispatch_support
def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           name=None):
  """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.

  The inputs must, following any transpositions, be tensors of rank >= 2
  where the inner 2 dimensions specify valid matrix multiplication arguments,
  and any further outer dimensions match.

  Both matrices must be of the same type. The supported types are:
  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
  Exception: if one input is `bfloat16` and the other has a different dtype,
  the product is computed with `sparse_matmul`, because the dense kernel does
  not handle mixed-precision inputs.

  Either matrix can be transposed or adjointed (conjugated and transposed) on
  the fly by setting one of the corresponding flags to `True`. These are
  `False` by default.

  If one or both of the matrices contain a lot of zeros, a more efficient
  multiplication algorithm can be used by setting the corresponding
  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
  This optimization is only available for plain matrices (rank-2 tensors) with
  datatypes `bfloat16` or `float32`.

  For example:

  ```python
  # 2-D tensor `a`
  # [[1, 2, 3],
  #  [4, 5, 6]]
  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])

  # 2-D tensor `b`
  # [[ 7,  8],
  #  [ 9, 10],
  #  [11, 12]]
  b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])

  # `a` * `b`
  # [[ 58,  64],
  #  [139, 154]]
  c = tf.matmul(a, b)


  # 3-D tensor `a`
  # [[[ 1,  2,  3],
  #   [ 4,  5,  6]],
  #  [[ 7,  8,  9],
  #   [10, 11, 12]]]
  a = tf.constant(np.arange(1, 13, dtype=np.int32),
                  shape=[2, 2, 3])

  # 3-D tensor `b`
  # [[[13, 14],
  #   [15, 16],
  #   [17, 18]],
  #  [[19, 20],
  #   [21, 22],
  #   [23, 24]]]
  b = tf.constant(np.arange(13, 25, dtype=np.int32),
                  shape=[2, 3, 2])

  # `a` * `b`
  # [[[ 94, 100],
  #   [229, 244]],
  #  [[508, 532],
  #   [697, 730]]]
  c = tf.matmul(a, b)

  # Since python >= 3.5 the @ operator is supported (see PEP 465).
  # In TensorFlow, it simply calls the `tf.matmul()` function, so the
  # following lines are equivalent. The extra operand uses integer literals
  # so that its dtype matches the int32 tensors above.
  d = a @ b @ [[10], [11]]
  d = tf.matmul(tf.matmul(a, b), [[10], [11]])
  ```

  Args:
    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
      `complex128` and rank > 1.
    b: `Tensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated and transposed before
      multiplication.
    adjoint_b: If `True`, `b` is conjugated and transposed before
      multiplication.
    a_is_sparse: If `True`, `a` is treated as a sparse matrix. The sparse
      kernel is only used when both inputs are `bfloat16` or `float32`.
    b_is_sparse: If `True`, `b` is treated as a sparse matrix. The sparse
      kernel is only used when both inputs are `bfloat16` or `float32`.
    name: Name for the operation (optional).

  Returns:
    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
    the product of the corresponding matrices in `a` and `b`, e.g. if all
    transpose or adjoint attributes are `False`:

    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
    for all indices i, j.

    Note: This is matrix product, not element-wise product.

    When the `sparse_matmul` path is taken, the result is `float32`. The one
    exception is when both inputs are `bfloat16`: then the result is cast back
    to `bfloat16`.

  Raises:
    ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
      are both set to True.
  """
  with ops.name_scope(name, "MatMul", [a, b]) as name:
    if transpose_a and adjoint_a:
      raise ValueError("Only one of transpose_a and adjoint_a can be True.")
    if transpose_b and adjoint_b:
      raise ValueError("Only one of transpose_b and adjoint_b can be True.")

    # In eager mode, EagerTensors and resource variables are passed through
    # unchanged and any other input is converted. In graph mode every input
    # is converted to a tensor.
    if context.executing_eagerly():
      if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
        a = ops.convert_to_tensor(a, name="a")
      if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
        b = ops.convert_to_tensor(b, name="b")
    else:
      a = ops.convert_to_tensor(a, name="a")
      b = ops.convert_to_tensor(b, name="b")

    # TODO(apassos) remove _shape_tuple here when it is not needed.
    a_shape = a._shape_tuple()  # pylint: disable=protected-access
    b_shape = b._shape_tuple()  # pylint: disable=protected-access

    # A shape tuple of None (unknown rank) is treated as "may be batched".
    # The branches differ as follows:
    # - forward_compatible(2019, 4, 25): batch_mat_mul_v2, used when EITHER
    #   operand may carry batch dimensions.
    # - otherwise: legacy batch_mat_mul, used only when BOTH may.
    if fwd_compat.forward_compatible(2019, 4, 25):
      output_may_have_non_empty_batch_shape = (
          (a_shape is None or len(a_shape) > 2) or
          (b_shape is None or len(b_shape) > 2))
      batch_mat_mul_fn = gen_math_ops.batch_mat_mul_v2
    else:
      output_may_have_non_empty_batch_shape = (
          (a_shape is None or len(a_shape) > 2) and
          (b_shape is None or len(b_shape) > 2))
      batch_mat_mul_fn = gen_math_ops.batch_mat_mul

    if (not a_is_sparse and
        not b_is_sparse) and output_may_have_non_empty_batch_shape:
      # BatchMatmul does not support transpose, so we conjugate the matrix and
      # use adjoint instead. Conj() is a noop for real matrices.
      if transpose_a:
        a = conj(a)
        adjoint_a = True
      if transpose_b:
        b = conj(b)
        adjoint_b = True
      return batch_mat_mul_fn(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)

    # Neither matmul nor sparse_matmul support adjoint, so we conjugate
    # the matrix and use transpose instead. Conj() is a noop for real
    # matrices.
    if adjoint_a:
      a = conj(a)
      transpose_a = True
    if adjoint_b:
      b = conj(b)
      transpose_b = True

    use_sparse_matmul = False
    if a_is_sparse or b_is_sparse:
      sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
      use_sparse_matmul = (
          a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
    if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and
        a.dtype != b.dtype):
      # matmul currently doesn't handle mixed-precision inputs.
      use_sparse_matmul = True
    if use_sparse_matmul:
      ret = sparse_matmul(
          a,
          b,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          a_is_sparse=a_is_sparse,
          b_is_sparse=b_is_sparse,
          name=name)
      # sparse_matmul always returns float32, even with
      # bfloat16 inputs. This prevents us from configuring bfloat16 training.
      # casting to bfloat16 also matches non-sparse matmul behavior better.
      if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
        ret = cast(ret, dtypes.bfloat16)
      return ret
    else:
      return gen_math_ops.mat_mul(
          a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)