numpy.take_along_axis#
- numpy.take_along_axis(arr, indices, axis)[原始碼]#
透過比對一維索引和資料切片,從輸入陣列中取出值。
此函數會迭代索引和資料陣列中,沿著指定軸向的一維切片,並使用前者在後者中查找值。這些切片可以有不同的長度。
返回沿軸索引的函數,例如
argsort
和argpartition
,會為此函數產生合適的索引。- 參數:
- 回傳:
- out: ndarray (Ni…, J, Nk…)
索引結果。
另請參閱
take
沿著軸取值,對每個一維切片使用相同的索引
put_along_axis
透過比對一維索引和資料切片,將值放入目標陣列
註解
這等效於(但比以下使用
ndindex
和s_
的方法更快),它將ii
和kk
各自設定為索引元組Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:] J = indices.shape[axis] # Need not equal M out = np.empty(Ni + (J,) + Nk) for ii in ndindex(Ni): for kk in ndindex(Nk): a_1d = a [ii + s_[:,] + kk] indices_1d = indices[ii + s_[:,] + kk] out_1d = out [ii + s_[:,] + kk] for j in range(J): out_1d[j] = a_1d[indices_1d[j]]
等效地,消除內部迴圈,最後兩行會是
out_1d[:] = a_1d[indices_1d]
範例
>>> import numpy as np
對於這個範例陣列
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
我們可以透過直接使用 sort,或使用 argsort 和此函數來排序
>>> np.sort(a, axis=1) array([[10, 20, 30], [40, 50, 60]]) >>> ai = np.argsort(a, axis=1) >>> ai array([[0, 2, 1], [1, 2, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 20, 30], [40, 50, 60]])
如果您使用
keepdims
維護了微不足道的維度,則最大值和最小值也適用>>> np.max(a, axis=1, keepdims=True) array([[30], [60]]) >>> ai = np.argmax(a, axis=1, keepdims=True) >>> ai array([[1], [0]]) >>> np.take_along_axis(a, ai, axis=1) array([[30], [60]])
如果我們想要同時取得最大值和最小值,我們可以先堆疊索引
>>> ai_min = np.argmin(a, axis=1, keepdims=True) >>> ai_max = np.argmax(a, axis=1, keepdims=True) >>> ai = np.concatenate([ai_min, ai_max], axis=1) >>> ai array([[0, 1], [1, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 30], [40, 60]])