
How do I get indices of N maximum values in a NumPy array?

python
numpy
argpartition
sorting
by Nikita Barsukov · Oct 15, 2024
TLDR

To quickly obtain the indices of the N largest values in a NumPy array, we can employ np.argpartition, which partially sorts the array. Slicing its result with [-N:] gives the indices of the top N elements, though in no particular order. For a precise descending order, a follow-up np.argsort on just those N elements is needed:

import numpy as np
# Here is your array. Just an ordinary list of numbers... or is it?
arr = np.array([1, 3, 2, 4, 5])
# The number of maximum values you're after.
N = 3
# argpartition - the sorting hat of Hogwarts for NumPy arrays.
# The last N positions hold the indices of the N largest values, in no particular order.
idx = np.argpartition(arr, -N)[-N:]
# One last touch with argsort to make sure our highest values are dressed in descending order.
result = idx[np.argsort(arr[idx])[::-1]]
# Drum roll please...
print(result)  # [4 3 1]  Yup! They got the front seats!

Unlocking the potential of argpartition

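np.argpartition is your trusty wizard for partial sorting. Given kth, it returns indices that put the kth element at its final sorted position (a negative kth counts from the end, so -N means the N-th largest), with every smaller element before it and every equal or greater one after it, all in linear time. Remember though, the sorting hat's job ends there: the indices on each side of that position are not ordered, so they need further grooming.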
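A small sketch of that partition guarantee, using an arbitrary example array: the value at position -N is exactly the N-th largest, while the order on either side of it is whatever argpartition happened to produce.

```python
import numpy as np

arr = np.array([9, 1, 7, 3, 8, 2, 6])
N = 3

# kth=-N asks argpartition to put the N-th largest value at position -N
part = np.argpartition(arr, -N)
pivot = arr[part[-N]]

print(pivot)           # 7, the 3rd largest value
print(arr[part[-N:]])  # the 3 largest values, in no guaranteed order

# Everything left of the pivot is <= it, everything right of it is >= it
assert (arr[part[:-N]] <= pivot).all()
assert (arr[part[-N:]] >= pivot).all()
```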

Time complexity of fetching the top k in order

When combined with np.argsort, np.argpartition can recover the top-k elements in sorted order within O(n + k log k) time. This spell works in two steps. First, np.argpartition places the k-th largest element at its final sorted position, shooing all smaller elements to its left and all equal or larger ones to its right, in O(n). Then np.argsort orders only the k elements on the right, which costs O(k log k) instead of sorting all n.
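The two steps can be sketched as follows, assuming a random float array (so ties are practically impossible) and checking the result against a full sort:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
arr = rng.random(10_000)  # random floats; ties are practically impossible
k = 5

# Step 1, O(n): indices of the k largest values, unordered
cand = np.argpartition(arr, -k)[-k:]
# Step 2, O(k log k): sort only those k candidates, largest first
top = cand[np.argsort(arr[cand])[::-1]]

# Sanity check against a full O(n log n) sort
full = np.argsort(arr)[::-1][:k]
print(np.array_equal(top, full))  # True
```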

A peek into other spells

Apart from np.argpartition, there's also (-arr).argsort()[:N]. It's quicker to write and even returns the indices already in descending order, but it sorts the entire array in O(n log n), so it scales worse than np.argpartition on large arrays:

indices = (-arr).argsort()[:N] # More like a rush hour spell when you're running late.
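A quick comparison on the TLDR array. The second variant, np.argsort(arr)[::-1][:N], is an alternative not mentioned above; it avoids negation, which matters for unsigned integer arrays where -arr silently wraps around:

```python
import numpy as np

arr = np.array([1, 3, 2, 4, 5])
N = 3

# Full sort of the negated array: O(n log n), but already largest-first
print((-arr).argsort()[:N])       # [4 3 1]

# Sign-free variant: sort ascending, then reverse
print(np.argsort(arr)[::-1][:N])  # [4 3 1]

# Why the sign-free variant matters: negating unsigned ints wraps around
u = np.array([0, 1, 255], dtype=np.uint8)
print(-u)                         # values wrap to 0, 255, 1, so 1 now looks biggest
print(np.argsort(u)[::-1][:2])    # [2 1], the true top 2
```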

Oh, and there's another spell in Python's grimoire - heapq.nlargest, which can also find the top N values, or their indices when given a key function.
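To get indices rather than values out of heapq.nlargest, one common approach is to rank the positions themselves, using the array lookup as the key. A minimal sketch on the TLDR array:

```python
import heapq

import numpy as np

arr = np.array([1, 3, 2, 4, 5])
N = 3

# Rank the positions 0..len-1 by the value stored there, largest first
idx = heapq.nlargest(N, range(len(arr)), key=arr.__getitem__)
print(idx)  # [4, 3, 1]
```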

Advanced tricks and hacks

Ace in a Multidimensional Universe

When the battlefront shifts to multidimensional arrays, our knight in shining armor is np.unravel_index. Run np.argpartition on the flattened array (pass axis=None), then let np.unravel_index turn the resulting flat positions back into one index array per dimension, such as rows and columns.
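Here is a sketch on a small 2-D array chosen for illustration; the ordering step in the middle is optional and mirrors the TLDR recipe:

```python
import numpy as np

a = np.array([[1, 9, 3],
              [7, 2, 8],
              [4, 6, 5]])
N = 3

# axis=None makes argpartition work on the flattened array
flat = np.argpartition(a, -N, axis=None)[-N:]
# Optional: order the N flat indices by value, largest first
flat = flat[np.argsort(a.ravel()[flat])[::-1]]

# Turn flat positions back into (row, col) coordinates
rows, cols = np.unravel_index(flat, a.shape)
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 2), (1, 0)]
print(a[rows, cols])                            # [9 8 7]
```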

Time Complexity Unveiled

Demystifying time complexity is key. np.argpartition runs in O(n) even in the worst case, and sorting the k survivors with np.argsort adds O(k log k). The sorted top-k retrieval therefore costs O(n + k log k), compared with O(n log n) for sorting the whole array, which is why the magic pays off when k is much smaller than n.
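To see the O(n + k log k) versus O(n log n) difference on your own machine, a rough timeit comparison like the one below works. The array size, k, and repeat count are arbitrary choices, and the absolute timings will vary with hardware and NumPy version:

```python
import timeit

import numpy as np

rng = np.random.default_rng(seed=0)
arr = rng.random(1_000_000)
k = 10

def partial():
    # argpartition + argsort on k items: O(n + k log k)
    idx = np.argpartition(arr, -k)[-k:]
    return idx[np.argsort(arr[idx])[::-1]]

def full():
    # full argsort: O(n log n)
    return np.argsort(arr)[::-1][:k]

# Both approaches must agree before timing them
assert np.array_equal(partial(), full())
print("argpartition + argsort:", timeit.timeit(partial, number=5))
print("full argsort:          ", timeit.timeit(full, number=5))
```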

Cython - The Hidden Realm

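Look out for Cython - the hidden passage to C-level optimizations. It compiles type-annotated Python to C, which mainly pays off when the top-N search runs inside a Python-level loop, for example over many small arrays; for a single large array, np.argpartition already does its work in compiled code.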

The Enigmatic heapq

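The enchanting heapq is most appealing when N is small relative to the data. heapq.nlargest keeps a heap of at most N items while scanning the input, costing O(n log N); the Python documentation suggests sorted() once N grows large, and max() when N is 1. Its other charms are simplicity, built-in support for key functions, and the fact that it accepts any iterable, not only NumPy arrays.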
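The bounded-heap idea behind that O(n log N) cost can be sketched by hand. top_n_indices_stream is a hypothetical helper written for illustration, a simplified version of the idea rather than CPython's actual heapq.nlargest source; the last lines show the built-in doing the same job via enumerate:

```python
import heapq

def top_n_indices_stream(values, k):
    """Hypothetical helper: indices of the k largest items of any iterable, largest first."""
    if k <= 0:
        return []
    heap = []  # min-heap of (value, index) pairs, never more than k of them
    for i, v in enumerate(values):
        if len(heap) < k:
            heapq.heappush(heap, (v, i))     # O(log k) push while the heap fills up
        elif v > heap[0][0]:
            heapq.heapreplace(heap, (v, i))  # evict the current smallest, O(log k)
    return [i for _, i in sorted(heap, reverse=True)]

readings = iter([1, 3, 2, 4, 5])  # any iterable works, even a one-shot iterator
print(top_n_indices_stream(readings, 3))  # [4, 3, 1]

# The built-in equivalent: rank (index, value) pairs by value
pairs = heapq.nlargest(3, enumerate([1, 3, 2, 4, 5]), key=lambda p: p[1])
print([i for i, _ in pairs])  # [4, 3, 1]
```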

Pro-gamer moves

Adjusting N as per your needs

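Choose the tool by what you need from the N results. If only the indices matter and their order doesn't, np.argpartition(arr, -N)[-N:] alone does the job in O(n). If you also want them in descending order, follow it with np.argsort on those N candidates; (-arr).argsort()[:N] gives the same ordered result with less typing, but sorts the whole array.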
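Both choices can be wrapped in one small helper. top_n_indices and its ordered flag are hypothetical names for this sketch. It also guards against an easy-to-miss pitfall: with N = 0, the slice [-0:] is the same as [0:] and returns every index instead of none.

```python
import numpy as np

def top_n_indices(arr, n, ordered=True):
    """Hypothetical helper: indices of the n largest values of a 1-D array."""
    arr = np.asarray(arr)
    n = min(n, arr.size)  # asking for more than exists returns all of them
    if n <= 0:
        # Guard: arr[-0:] is arr[0:], which would return every index
        return np.empty(0, dtype=np.intp)
    idx = np.argpartition(arr, -n)[-n:]  # linear in arr.size, unordered
    if ordered:
        idx = idx[np.argsort(arr[idx])[::-1]]  # only the n winners get sorted
    return idx

data = np.array([1, 3, 2, 4, 5])
print(top_n_indices(data, 3))                                  # [4 3 1]
print(sorted(top_n_indices(data, 3, ordered=False).tolist()))  # [1, 3, 4]
print(top_n_indices(data, 0))                                  # []
print(len(np.argpartition(data, -0)[-0:]))                     # 5, the pitfall
```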

A clever balancing act

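Juggle between readability and execution speed. NumPy spells are usually faster on arrays because the heavy lifting runs in compiled code, but Python's built-in spells like sorted() or heapq.nlargest can read more clearly and work directly on plain lists, generators, or records with a key function, without converting them to an array first.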
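For a plain Python list, a built-ins-only version with sorted() might look like this sketch; it sorts all positions, so it costs O(n log n) like the full argsort:

```python
lst = [1, 3, 2, 4, 5]
N = 3

# Sort positions by the value they point to, largest first, then keep N
idx = sorted(range(len(lst)), key=lst.__getitem__, reverse=True)[:N]
print(idx)  # [4, 3, 1]
```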