Conditional logic as array operations – where

The numpy.where function is a vectorised version of if and else.

In the following example, we first create a list of Boolean values and two arrays of random values:

[1]:
import numpy as np
[2]:
cond = [False, True, False, True, False, False, False]
default_rng = np.random.default_rng()
data1 = default_rng.random(7)
data2 = default_rng.random(7)

Now we want to take the value from data1 wherever the corresponding value in cond is True, and the value from data2 otherwise. With Python’s if-else expression inside a list comprehension, this could look like this:

[3]:
result = np.array(
    [(x if c else y) for x, y, c in zip(data1, data2, cond, strict=True)]
)

result
[3]:
array([0.88985185, 0.34568582, 0.76929043, 0.67471859, 0.26424192,
       0.6908897 , 0.88940803])

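However, this approach has two problems:

  • it is slow for large arrays, because every element is processed individually by the Python interpreter

  • it does not work with multidimensional arrays, because zip only iterates over the first axis, so each c is a whole row rather than a single Boolean value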

With np.where you can avoid both problems in a single function call:

[4]:
result = np.where(cond, data1, data2)

result
[4]:
array([0.88985185, 0.34568582, 0.76929043, 0.67471859, 0.26424192,
       0.6908897 , 0.88940803])

The second and third arguments of np.where do not have to be arrays; one or both of them can also be scalars. A typical use of where in data analysis is to create a new array whose values depend on another array. Suppose you have a matrix of randomly generated data and you want to replace all values greater than 0.9 with NaN:

[5]:
data = default_rng.random(size=(4, 4))

data
[5]:
array([[0.36584626, 0.91970031, 0.13618506, 0.24056537],
       [0.59235974, 0.40813171, 0.71144251, 0.08058474],
       [0.21988608, 0.50140418, 0.41486371, 0.8820743 ],
       [0.4000475 , 0.85196835, 0.33001728, 0.07373272]])
[6]:
data > 0.9
[6]:
array([[False,  True, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False]])
[7]:
np.where(data > 0.9, np.nan, data)
[7]:
array([[0.36584626,        nan, 0.13618506, 0.24056537],
       [0.59235974, 0.40813171, 0.71144251, 0.08058474],
       [0.21988608, 0.50140418, 0.41486371, 0.8820743 ],
       [0.4000475 , 0.85196835, 0.33001728, 0.07373272]])
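Scalars can also be used for both branches, or for just one of them. The sketch below assumes a hypothetical 4×4 matrix of standard normal samples, so that it contains negative values as well; it labels each element with 1 or -1, sets negative values to 0, and makes negative values positive, which gives the same result as `np.abs`.

```python
import numpy as np

# Hypothetical data: standard normal samples contain negative and positive values
rng = np.random.default_rng(seed=42)
samples = rng.standard_normal((4, 4))

# Both branches are scalars: 1 for positive elements, -1 otherwise
signs = np.where(samples > 0, 1, -1)

# One branch is a scalar, the other an array: replace negative values with 0
clipped = np.where(samples < 0, 0.0, samples)

# Make negative values positive; equivalent to np.abs(samples)
absolute = np.where(samples < 0, -samples, samples)

print(signs)
print(clipped)
```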