Como usar a função NumPy argmax() em Python
Publicados: 2022-09-14Neste tutorial, você aprenderá como usar a função NumPy argmax() para encontrar o índice do elemento máximo em arrays.
NumPy é uma biblioteca poderosa para computação científica em Python; ele fornece arrays N-dimensionais com melhor desempenho do que as listas do Python. Uma das operações comuns que você realizará ao trabalhar com matrizes NumPy é encontrar o valor máximo na matriz. No entanto, às vezes você pode querer localizar o índice no qual ocorre o valor máximo.
A função argmax() ajuda a encontrar o índice do máximo em arrays unidimensionais e multidimensionais. Vamos continuar a aprender como funciona.
Como encontrar o índice do elemento máximo em uma matriz NumPy
Para acompanhar este tutorial, você precisa ter o Python e o NumPy instalados. Você pode codificar iniciando um Python REPL ou iniciando um notebook Jupyter.
Primeiro, vamos importar NumPy sob o alias usual np .
import numpy as np Você pode usar a função NumPy max() para obter o valor máximo em uma matriz (opcionalmente ao longo de um eixo específico).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10 Nesse caso, np.max(array_1) retorna 10, o que está correto.
Suponha que você queira encontrar o índice no qual o valor máximo ocorre na matriz. Você pode adotar a seguinte abordagem em duas etapas:
- Encontre o elemento máximo.
- Encontre o índice do elemento máximo.
Em array_1 , o valor máximo de 10 ocorre no índice 4, seguindo a indexação zero. O primeiro elemento está no índice 0; o segundo elemento está no índice 1 e assim por diante.

Para encontrar o índice no qual ocorre o máximo, você pode usar a função NumPy where(). np.where(condition) retorna um array de todos os índices onde a condition é True .
Você terá que tocar no array e acessar o item no primeiro índice. Para descobrir onde ocorre o valor máximo, definimos a condition para array_1==10 ; lembre-se de que 10 é o valor máximo em array_1 .
print(int(np.where(array_1==10)[0])) # Output 4 Usamos np.where() apenas com a condição, mas este não é o método recomendado para usar esta função.
Nota: NumPy where() Função :
np.where(condition,x,y)retorna:– Elementos de
xquando a condição forTrue, e
– Elementos deyquando a condição forFalse.
Portanto, encadeando as np.max() e np.where() , podemos encontrar o elemento máximo, seguido do índice em que ele ocorre.
Em vez do processo de duas etapas acima, você pode usar a função NumPy argmax() para obter o índice do elemento máximo na matriz.
Sintaxe da função NumPy argmax()
A sintaxe geral para usar a função NumPy argmax() é a seguinte:
np.argmax(array,axis,out) # we've imported numpy under the alias npNa sintaxe acima:
- array é qualquer array NumPy válido.
- axis é um parâmetro opcional. Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro axis para encontrar o índice de máximo ao longo de um eixo específico.
- out é outro parâmetro opcional. Você pode definir o parâmetro
outpara um array NumPy para armazenar a saída da funçãoargmax().
Nota : A partir da versão 1.22.0 do NumPy, há um parâmetro
keepdimsadicional. Quando especificamos o parâmetroaxisna chamada da funçãoargmax(), a matriz é reduzida ao longo desse eixo. Mas definir o parâmetrokeepdimscomoTruegarante que a saída retornada tenha a mesma forma que a matriz de entrada.
Usando NumPy argmax() para encontrar o índice do elemento máximo
#1 . Vamos usar a função NumPy argmax() para encontrar o índice do elemento máximo em array_1 .
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4 A função argmax() retorna 4, o que está correto!
#2 . Se redefinirmos array_1 de forma que 10 ocorra duas vezes, a função argmax() retornará apenas o índice da primeira ocorrência.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4 Para o restante dos exemplos, usaremos os elementos de array_1 que definimos no exemplo #1.
Usando NumPy argmax() para encontrar o índice do elemento máximo em uma matriz 2D
Vamos remodelar o array NumPy array_1 em um array bidimensional com duas linhas e quatro colunas.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]] Para uma matriz bidimensional, o eixo 0 denota as linhas e o eixo 1 denota as colunas. As matrizes NumPy seguem a indexação zero . Portanto, os índices das linhas e colunas da matriz NumPy array_2 são os seguintes:

Agora, vamos chamar a função argmax() no array bidimensional, array_2 .
print(np.argmax(array_2)) # Output 4 Embora tenhamos chamado argmax() no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional, array_1 da seção anterior.
Por que isso acontece?
Isso ocorre porque não especificamos nenhum valor para o parâmetro axis. Quando este parâmetro de eixo não está definido, por padrão, a função argmax() retorna o índice do elemento máximo ao longo do array nivelado.
O que é uma matriz achatada? Se houver uma matriz N-dimensional de forma d1 x d2 x … x dN , onde d1, d2, até dN são os tamanhos da matriz ao longo das N dimensões, então a matriz achatada é uma longa matriz unidimensional de tamanho d1 * d2 * … * dN.
Para verificar a aparência do array achatado para array_2 , você pode chamar o método flatten() , conforme mostrado abaixo:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])Índice do elemento máximo ao longo das linhas (eixo = 0)
Vamos prosseguir para encontrar o índice do elemento máximo ao longo das linhas (eixo = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])Essa saída pode ser um pouco difícil de compreender, mas vamos entender como ela funciona.
Definimos o parâmetro axis como zero ( axis = 0 ), pois gostaríamos de encontrar o índice do elemento máximo ao longo das linhas. Portanto, a função argmax() retorna o número da linha em que o elemento máximo ocorre — para cada uma das três colunas.
Vamos visualizar isso para melhor compreensão.

Do diagrama acima e da saída argmax() , temos o seguinte:
- Para a primeira coluna no índice 0, o valor máximo 10 ocorre na segunda linha, no índice = 1.
- Para a segunda coluna no índice 1, o valor máximo 9 ocorre na segunda linha, no índice = 1.
- Para a terceira e quarta colunas no índice 2 e 3, os valores máximos 8 e 4 ocorrem na segunda linha, no índice = 1.
É exatamente por isso que temos o array([1, 1, 1, 1]) porque o elemento máximo ao longo das linhas ocorre na segunda linha (para todas as colunas).
Índice do Elemento Máximo ao Longo das Colunas (eixo = 1)
Em seguida, vamos usar a função argmax() para encontrar o índice do elemento máximo ao longo das colunas.
Execute o trecho de código a seguir e observe a saída.
np.argmax(array_2,axis=1) array([2, 0])Você pode analisar a saída?
Definimos axis = 1 para calcular o índice do elemento máximo ao longo das colunas.
A função argmax() retorna, para cada linha, o número da coluna em que ocorre o valor máximo.
Aqui está uma explicação visual:

Do diagrama acima e da saída argmax() , temos o seguinte:
- Para a primeira linha no índice 0, o valor máximo 7 ocorre na terceira coluna, no índice = 2.
- Para a segunda linha no índice 1, o valor máximo 10 ocorre na primeira coluna, no índice = 0.
Espero que agora você entenda o que a saída, array([2, 0]) significa.
Usando o parâmetro opcional de saída no NumPy argmax ()
Você pode usar o parâmetro opcional out the na função NumPy argmax() para armazenar a saída em uma matriz NumPy.
Vamos inicializar um array de zeros para armazenar a saída da chamada de função argmax() anterior – para encontrar o índice do máximo ao longo das colunas ( axis= 1 ).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.] Agora, vamos revisitar o exemplo de encontrar o índice do elemento máximo ao longo das colunas ( axis = 1 ) e definir out como out_arr que definimos acima.
np.argmax(array_2,axis=1,out=out_arr) Vemos que o interpretador Python lança um TypeError , pois o out_arr foi inicializado para um array de floats por padrão.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe' Portanto, ao definir o parâmetro out para a matriz de saída, é importante garantir que a matriz de saída tenha a forma e o tipo de dados corretos. Como os índices de array são sempre inteiros, devemos definir o parâmetro dtype como int ao definir o array de saída.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0] Agora podemos ir em frente e chamar a função argmax() com os parâmetros axis e out e, desta vez, ela é executada sem erros.
np.argmax(array_2,axis=1,out=out_arr) A saída da função argmax() agora pode ser acessada no array out_arr .
print(out_arr) # Output [2 0]Conclusão
Espero que este tutorial tenha ajudado você a entender como usar a função NumPy argmax(). Você pode executar os exemplos de código em um notebook Jupyter.
Vamos rever o que aprendemos.
- A função NumPy argmax() retorna o índice do elemento máximo em uma matriz. Se o elemento máximo ocorrer mais de uma vez em uma matriz a , então np.argmax(a) retornará o índice da primeira ocorrência do elemento.
- Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro opcional axis para obter o índice do elemento máximo ao longo de um eixo específico. Por exemplo, em uma matriz bidimensional: definindo axis = 0 e axis = 1 , você pode obter o índice do elemento máximo ao longo das linhas e colunas, respectivamente.
- Se você quiser armazenar o valor retornado em outro array, você pode definir o parâmetro opcional out para o array de saída. No entanto, a matriz de saída deve ter formato e tipo de dados compatíveis.
Em seguida, confira o guia detalhado sobre conjuntos de Python.

