Как работает python numpy.where ()?

94

Я играю numpyи копаюсь в документации, и я наткнулся на некоторую магию. А именно я говорю о numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Как они внутренне добиваются того, чтобы вы могли передать что-то вроде x > 5метода? Я думаю, это как-то связано, __gt__но я ищу подробное объяснение.

пайтон
источник

Ответы:

75

Как они внутренне добиваются того, чтобы вы могли передать в метод что-то вроде x> 5?

Короткий ответ - нет.

Любая логическая операция с массивом numpy возвращает логический массив. (то есть __gt__, __lt__и т.д. все возвращают логические массивы , где данное условие является истинным).

Например

x = np.arange(9).reshape(3,3)
print x > 5

дает:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

Это та же самая причина, почему что-то вроде if x > 5:вызывает ValueError, если xэто массив numpy. Это массив значений True / False, а не одно значение.

Кроме того, массивы numpy могут индексироваться логическими массивами. Например , в этом случае x[x>5]уступает [6 7 8].

Честно говоря, это довольно редко, что вам действительно нужно, numpy.whereно он просто возвращает индикаторы, где находится логический массив True. Обычно вы можете делать то, что вам нужно, с помощью простой логической индексации.

Джо Кингтон
источник
10
Просто чтобы указать, что numpy.whereу них есть 2 «рабочих режима», первый возвращает indices, где condition is Trueи если присутствуют необязательные параметры xи y( conditionтакая же форма, как или передаваемая в такую ​​форму!), Он вернет значения, xкогда в condition is Trueпротивном случае из y. Это делает его whereболее универсальным и позволяет использовать его чаще. Спасибо
ешь
1
В некоторых случаях также могут возникать накладные расходы при использовании __getitem__синтаксиса []over или numpy.whereили numpy.take. Поскольку __getitem__он также должен поддерживать нарезку, возникают некоторые накладные расходы. Я видел заметную разницу в скорости при работе со структурами данных Python Pandas и логической индексации очень больших столбцов. В тех случаях, если вам не нужна нарезка, тогда takeи whereлучше.
ely
24

Старый ответ сбивает с толку. Он дает вам МЕСТОПОЛОЖЕНИЯ (все) того, где ваше утверждение истинно.

так:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Я использую его как альтернативу list.index (), но он имеет и множество других применений. Я никогда не использовал его с 2D-массивами.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Новый ответ Похоже, человек спрашивал о чем-то более фундаментальном.

Вопрос заключался в том, как ВЫ могли реализовать что-то, что позволяет функции (например, где) знать, что было запрошено.

Прежде всего обратите внимание, что вызов любого из операторов сравнения делает интересную вещь.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Это делается путем перегрузки метода «__gt__». Например:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Как видите, "a> 4" был допустимым кодом.

Вы можете получить полный список и документацию по всем перегруженным функциям здесь: http://docs.python.org/reference/datamodel.html

Невероятно то, насколько просто это сделать. ВСЕ операции в python выполняются таким образом. Сказать a> b эквивалентно a. gt (b)!

Гарретт Берг
источник
3
Эта перегрузка оператора сравнения, похоже, не очень хорошо работает с более сложными логическими выражениями - например, я не могу этого сделать np.where(a > 30 and a < 50)или np.where(30 < a < 50)потому, что в конечном итоге она пытается вычислить логическое И двух массивов логических значений, что довольно бессмысленно. Есть ли способ записать такое условие np.where?
davidA
@meowsqueaknp.where((a > 30) & (a < 50))
tibalt
Почему np.where () возвращает список в вашем примере?
Андреас Янкополус
0

np.whereвозвращает кортеж длины, равной размеру массива numpy ndarray, на котором он вызывается (другими словами ndim), и каждый элемент кортежа представляет собой массив numpy ndarray индексов всех тех значений в исходном массиве ndarray, для которых условие истинно. (Пожалуйста, не путайте размер с формой)

Например:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y - это кортеж длиной 2, потому что x.ndimравен 2. Первый элемент в кортеже содержит номера строк всех элементов больше 4, а второй элемент содержит номера столбцов всех элементов больше 4. Как вы можете видеть, [1,2,2 , 2] соответствует номерам строк 5,6,7,8, а [2,0,1,2] соответствует номерам столбцов 5,6,7,8. Обратите внимание, что ndarray перемещается по первому измерению (по строкам ).

По аналогии,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


вернет кортеж длиной 3, потому что x имеет 3 измерения.

Но подождите, это еще не все!

когда добавляются два дополнительных аргумента np.where; он выполнит операцию замены для всех тех парных комбинаций строка-столбец, которые получаются указанным выше кортежем.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
Пиюш Сингх
источник