Понимание Einsum NumPy

192

Я изо всех сил пытаюсь понять, как именно einsumработает. Я посмотрел на документацию и несколько примеров, но это не похоже на прилипание.

Вот пример, который мы рассмотрели в классе:

C = np.einsum("ij,jk->ki", A, B)

для двух массивов AиB

Я думаю, что это заняло бы A^T * B, но я не уверен (это принимает транспонирование одного из них правильно?). Может кто-нибудь рассказать мне, что именно здесь происходит (и вообще при использовании einsum)?

Ланс пролив
источник
7
На самом деле это будет (A * B)^T, или эквивалентно B^T * A^T.
Тигран Салуев
23
Я написал небольшой пост в блоге об основах einsum здесь . (Я с радостью перенесу наиболее важные фрагменты в ответ по переполнению стека, если это будет полезно).
Алекс Райли
1
@ajcr - Красивая ссылка. Спасибо. numpyДокументация неадекватна при объяснении деталей.
Rayryeng
Спасибо за вотум доверия! С опозданием я привел ответ ниже .
Алекс Райли
Обратите внимание, что в Python *это не матричное умножение, а поэлементное умножение. Осторожно!
ComputerScientist

Ответы:

373

(Примечание: этот ответ основан на кратком сообщении в блоге о том, что einsumя написал некоторое время назад.)

Что делает einsum?

Представьте, что у нас есть два многомерных массива, Aи B. Теперь давайте предположим, что мы хотим ...

  • умножить A с Bопределенным образом, чтобы создать новый массив продуктов; а потом возможно
  • суммируйте этот новый массив вдоль определенных осей; а потом возможно
  • транспонировать оси нового массива в определенном порядке.

Там хороший шанс , что einsumпоможет нам сделать это быстрее и больше памяти, эффективно , что комбинации функций NumPy нравится multiply, sumи transposeпозволит.

Как einsumработает?

Вот простой (но не совсем тривиальный) пример. Возьмите следующие два массива:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Мы умножим Aи Bпоэлементно, а затем суммируем по строкам нового массива. В «нормальном» NumPy мы написали бы:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Итак, здесь операция индексации Aвыстраивает в линию первые оси двух массивов, чтобы умножение можно было транслировать. Строки массива продуктов затем суммируются, чтобы вернуть ответ.

Теперь, если мы хотим использовать einsumвместо этого, мы могли бы написать:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

Подписи строка 'i,ij->i'является ключом здесь и нуждается в немного объяснить. Вы можете думать об этом в две половины. С левой стороны (слева от ->) мы пометили два входных массива. Справа от ->мы пометили массив, который мы хотим получить.

Вот что происходит дальше:

  • Aимеет одну ось; мы пометили это i. И Bимеет две оси; мы пометили ось 0 как iи ось 1 как j.

  • При повторив метку iв обеих входных массивах, мы говорим , einsumчто эти две оси должны быть умножены вместе. Другими словами, мы умножаем массив Aна каждый столбец массива B, как это A[:, np.newaxis] * Bделает.

  • Обратите внимание, что jв нашем желаемом выводе нет метки; мы только что использовали i(мы хотим получить массив 1D). По опуская этикетку, мы говорим , einsumчтобы подвести вдоль этой оси. Другими словами, мы суммируем ряды продуктов, точно так же, как это .sum(axis=1)делает.

Это в основном все, что вам нужно знать, чтобы использовать einsum. Это помогает немного поиграть; если мы оставим обе метки в выводе, 'i,ij->ij'мы вернемся к двумерному массиву продуктов (так же, как A[:, np.newaxis] * B). Если мы говорим, что нет выходных меток, 'i,ij->мы возвращаем одно число (то же самое, что и делать (A[:, np.newaxis] * B).sum()).

Однако самое интересное в том einsum, что сначала не создается временный массив продуктов; это просто суммирует продукты, как это идет. Это может привести к большой экономии в использовании памяти.

Немного больший пример

Чтобы объяснить скалярное произведение, вот два новых массива:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Мы вычислим скалярное произведение, используя np.einsum('ij,jk->ik', A, B). Вот рисунок, показывающий маркировку Aи Bи выходного массива, который мы получаем из функции:

введите описание изображения здесь

Вы можете видеть, что метка jповторяется - это означает, что мы умножаем строки Aна столбцы B. Кроме того, метка jне включена в выходные данные - мы суммируем эти продукты. Метки iи kсохраняются для вывода, поэтому мы получаем двумерный массив.

Это может быть еще более ясным , чтобы сравнить этот результат с массивом , где метка jнаходится не подводились. Ниже, слева вы можете увидеть трехмерный массив, полученный в результате записи np.einsum('ij,jk->ijk', A, B)(т.е. мы сохранили метку j):

введите описание изображения здесь

Суммирующая ось jдает ожидаемое произведение точек, показанное справа.

Некоторые упражнения

Чтобы получить больше ощущений einsum, может быть полезно реализовать знакомые операции с массивами NumPy, используя нижнюю запись. Все, что включает в себя комбинации умножения и суммирования осей, может быть написано с использованием einsum.

Пусть A и B - два одномерных массива одинаковой длины. Например, A = np.arange(10)и B = np.arange(5, 15).

  • Сумма Aможет быть записана:

    np.einsum('i->', A)
  • Поэлементное умножение A * B, можно записать так:

    np.einsum('i,i->i', A, B)
  • Внутренний продукт или точечный продукт, np.inner(A, B)или np.dot(A, B), может быть записан:

    np.einsum('i,i->', A, B) # or just use 'i,i'
  • На внешнем произведении np.outer(A, B)можно записать:

    np.einsum('i,j->ij', A, B)

Для двумерных массивов Cи Dпри условии, что оси имеют совместимые длины (обе имеют одинаковую длину или одну из них имеет длину 1), вот несколько примеров:

  • След C(сумма главной диагонали) np.trace(C)можно записать так:

    np.einsum('ii', C)
  • Поэлементное умножение Cи транспонированные D, C * D.Tможно записать:

    np.einsum('ij,ji->ij', C, D)
  • Умножая каждый элемент Cна массив D(чтобы сделать массив 4D) C[:, :, None, None] * D, можно записать:

    np.einsum('ij,kl->ijkl', C, D)  
Алекс Райли
источник
1
Очень хорошее объяснение, спасибо. «Обратите внимание, что я не появляюсь в качестве метки в нашем желаемом выводе» - не так ли?
Ян Хинкс
Спасибо @IanHincks! Это похоже на опечатку; Я исправил это сейчас.
Алекс Райли
1
Очень хороший ответ Также стоит отметить, что ij,jkмог бы работать сам (без стрелок) для формирования матрицы умножения. Но, похоже, для наглядности лучше поставить стрелки, а затем и выходные размеры. Это в блоге.
ComputerScientist
1
@Peaceful: это один из тех случаев, когда трудно выбрать правильное слово! Я чувствую, что «столбец» подходит немного лучше, поскольку Aимеет длину 3, такую ​​же, как длина столбцов в B(тогда как строки Bимеют длину 4 и не могут быть умножены на элемент A).
Алекс Райли
1
Обратите внимание, что пропуск ->семантики влияет: «В неявном режиме выбранные индексы важны, так как оси вывода переупорядочены в алфавитном порядке. Это означает, что np.einsum('ij', a)не влияет на двумерный массив, а np.einsum('ji', a)занимает его транспонирование».
BallpointBen
41

Понять идею numpy.einsum()очень легко, если вы понимаете ее интуитивно. В качестве примера давайте начнем с простого описания, включающего умножение матриц .


Чтобы использовать numpy.einsum(), все, что вам нужно сделать, это передать так называемую строку индексов в качестве аргумента, а затем ваши входные массивы .

Скажем , у вас есть два 2D массивов, Aи B, и вы хотите сделать матричное умножение. Итак, вы делаете:

np.einsum("ij, jk -> ik", A, B)

Здесь нижняя строка ij соответствует массиву, Aа нижняя строка jk соответствует массиву B. Кроме того, самое важное, что следует отметить, это то, что количество символов в каждой строке индекса должно соответствовать размерам массива. (т.е. два символа для 2D-массивов, три символа для 3D-массивов и т. д.) И если вы повторяете символы между строками индекса ( jв нашем случае), то это означает, что вы хотите, чтобы einсумма происходила по этим измерениям. Таким образом, они будут уменьшены. (то есть это измерение исчезнет )

Подстрочный строка после этого ->, будет наш результирующий массив. Если вы оставите это поле пустым, все будет суммировано, и в качестве результата будет возвращено скалярное значение. В противном случае результирующий массив будет иметь размеры в соответствии со строкой индекса . В нашем примере это будет ik. Это интуитивно понятно, потому что мы знаем, что для умножения матрицы количество столбцов в массиве Aдолжно соответствовать количеству строк в массиве, Bчто и происходит здесь (т.е. мы кодируем это знание, повторяя символ jв строке индекса )


Вот еще несколько примеров, иллюстрирующих использование / мощь np.einsum()в реализации некоторых общих тензорных или nd-массивных операций, кратко.

входные

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Матричное умножение (аналогично np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Извлечь элементы по главной диагонали (аналогично np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) произведение Адамара (т.е. поэлементное произведение двух массивов) (аналогично arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Поэлементное возведение в квадрат (аналогично np.square(arr)или arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Трассировка (т.е. сумма элементов главной диагонали) (аналогично np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Матрица транспонировать (аналогично np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Наружное произведение (векторов) (аналогично np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Внутренний продукт (векторов) (аналогично np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Сумма по оси 0 (аналогично np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Сумма по оси 1 (аналогично np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Пакетное умножение матриц

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Сумма по оси 2 (аналогично np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Суммируйте все элементы в массиве (аналогично np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Сумма по нескольким осям (т.е. маргинализация)
(аналогично np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Продукты Double Dot (аналогично np.sum (hadamard-product), см. 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) 2D и 3D умножение массива

Такое умножение может быть очень полезно при решении линейной системы уравнений ( Ax = b ), где вы хотите проверить результат.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Напротив, если нужно использовать np.matmul()для этой проверки, мы должны сделать пару reshapeопераций для достижения того же результата, как:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Бонус : Читайте больше математики здесь: Суммирование Эйнштейна и определенно здесь: Тензорная запись

kmario23
источник
7

Давайте сделаем 2 массива с разными, но совместимыми размерами, чтобы подчеркнуть их взаимодействие

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Ваш расчет принимает «точку» (сумму произведений) от (2,3) с (3,4) для получения массива (4,2). i1-й тусклый A, последний из C; kпоследний из B1-го C. j«потребляется» суммированием.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

Это так же, как np.dot(A,B).T- это конечный результат, который транспонирован.

Чтобы узнать больше о том, что происходит j, измените Cподписки на ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Это также может быть произведено с:

A[:,:,None]*B[None,:,:]

То есть добавить kизмерение в конец Aиi перед B, чтобы получить массив (2,3,4).

0 + 4 + 16 = 20и 9 + 28 + 55 = 92т. д .; Суммируйте jи перенесите, чтобы получить более ранний результат:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]
hpaulj
источник
7

Я нашел NumPy: уловки торговли (Часть II) поучительными

Мы используем ->, чтобы указать порядок выходного массива. Так что думайте о «ij, i-> j» как о левой стороне (LHS) и правой стороне (RHS). Любое повторение меток на LHS вычисляет элемент продукта, а затем суммирует. Изменяя метку на стороне RHS (выходной), мы можем определить ось, по которой мы хотим перейти относительно входного массива, то есть суммирование по оси 0, 1 и так далее.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Обратите внимание, что есть три оси, i, j, k, и что j повторяется (с левой стороны). i,jпредставлять строки и столбцы для a. j,kдля b.

Чтобы рассчитать произведение и выровнять jось, нам нужно добавить ось a. ( bбудет транслироваться вдоль (?) первой оси)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jотсутствует в правой части, поэтому мы суммируем по jвторой оси массива 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Наконец, индексы (в алфавитном порядке) обращены справа, поэтому мы транспонируем.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>
WWII
источник
NumPy: уловки торговли (часть II), кажется, требуют приглашения от владельца сайта, а также учетной записи Wordpress
Tejas Shetty
... обновил ссылку, к счастью я нашел ее с помощью поиска. - Спасибо.
Второй
@TejasShetty Здесь много хороших ответов - может быть, мне стоит удалить этот.
В
2
Пожалуйста, не удаляйте свой ответ.
Техас Шетти
5

Читая уравнения Эйнсума, я обнаружил, что наиболее полезно просто уметь мысленно сводить их к их императивным версиям.

Начнем со следующего (внушительного) утверждения:

C = np.einsum('bhwi,bhwj->bij', A, B)

Прорабатывая пунктуацию сначала, мы видим, что у нас есть два 4-буквенных двоеточия, разделенных запятыми - bhwiи bhwj, перед стрелкой, и один 3-буквенный шарикbij после нее. Следовательно, уравнение дает тензорный результат ранга 3 из двух тензорных входов ранга 4.

Теперь пусть каждая буква в каждом двоичном объекте будет именем переменной диапазона. Позиция, в которой буква появляется в BLOB-объекте, является индексом оси, в которой она находится в этом тензоре. Следовательно, императивное суммирование, которое производит каждый элемент C, должно начинаться с трех вложенных циклов for, по одному для каждого индекса C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Итак, по сути, у вас есть forцикл для каждого выходного индекса C. Пока мы оставим диапазоны неопределенными.

Далее мы посмотрим на левую сторону - есть ли какие-то переменные диапазона, которые не появляются с правой стороны? В нашем случае - да, hи w. Добавьте внутренний вложенный forцикл для каждой такой переменной:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

Внутри самого внутреннего цикла у нас теперь определены все индексы, поэтому мы можем записать фактическое суммирование и перевод завершен:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Если вы уже смогли следовать коду, то поздравляю! Это все, что вам нужно, чтобы уметь читать уравнения Эйнсума. В частности, обратите внимание на то, как исходная формула einsum отображается в окончательный оператор суммирования в приведенном выше фрагменте. Циклы for и границы диапазона - просто пух, и это последнее утверждение - все, что вам действительно нужно, чтобы понять, что происходит.

Для полноты картины давайте посмотрим, как определить диапазоны для каждой переменной диапазона. Ну, диапазон каждой переменной - это просто длина измерения (й), которое она индексирует. Очевидно, что если переменная индексирует более одного измерения в одном или нескольких тензорах, то длины каждого из этих измерений должны быть равны. Вот код выше с полными диапазонами:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Стефан Драгнев
источник
0

Я думаю, что самый простой пример в документах tenorflow

Есть четыре шага, чтобы преобразовать ваше уравнение в систему обозначений Einsum. Возьмем это уравнение в качестве примера.C[i,k] = sum_j A[i,j] * B[j,k]

  1. Сначала мы отбрасываем имена переменных. Мы получилиik = sum_j ij * jk
  2. Мы отбрасываем sum_jтермин, поскольку он неявный. Мы получилиik = ij * jk
  3. Мы заменяем *на ,. Мы получилиik = ij, jk
  4. Выход находится на RHS и отделен ->знаком. Мы получилиij, jk -> ik

Интерпретатор einsum просто выполняет эти 4 шага в обратном порядке. Все отсутствующие в результате индексы суммируются.

Вот еще несколько примеров из документации

# Matrix multiplication
einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

# Dot product
einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

# Outer product
einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

# Transpose
einsum('ij->ji', m)  # output[j,i] = m[i,j]

# Trace
einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]

# Batch matrix multiplication
einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
Сурадип Нанда
источник