Numpy where 详解
参考:Exploring numpy.where Function
numpy.where()
函数是 Numpy 库中一个非常有用的函数,它在很多实际的数据处理场景中都能发挥重要的作用。本文将详细介绍 numpy.where()
函数的用法和相关注意事项。
语法
numpy.where(condition, x, y)
condition
:一个逻辑条件,当条件为 True 时,在输出中选择x
,否则选择y
。x
:满足条件的输出值。y
:不满足条件的输出值。
示例1:条件判断
import numpy as np
data = np.array([1, 2, 3, 4, 5])
condition = data > 3
result = np.where(condition, data, 0)
print(result)
运行结果:
在上面的示例中,我们首先创建了一个包含 1 到 5 的数组 data
,然后定义了一个条件 data > 3
,找出数组中大于 3 的元素。最后使用 np.where()
函数根据条件将满足条件的元素保留,不满足条件的元素替换为 0。
示例2:用零替换负值
可以通过对数组元素的负数进行判断,将负数元素赋值为0。示例代码如下:
import numpy as np
arr = np.array([-1, 2, -3, 4, -5])
result = np.where(arr < 0, 0, arr)
print(result)
运行结果:
示例3:多维数组做判断
对多维数组的进行判断,设定一个阈值,将大于或者小于阈值的元素设置为固定值。示例代码如下:
import numpy as np
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
threshold = 5
result = np.where(arr > threshold, arr, 0)
print(result)
运行结果:
示例4:应用在字符串数组
对于字符串类型的数组,也可以使用np.where()
进行操作。示例代码如下:
import numpy as np
arr = np.array(['data', 'numpywhere.com', 'geek-docs.com', 'deepinout.com'])
result = np.where(arr == 'data', 'Web', arr)
print(result)
运行结果:
注意事项
condition
、x
和y
的形状必须相同或者能够广播成相同形状。x
和y
可以是标量,数组或者函数。