numpy.take_along_axis#

numpy.take_along_axis(arr, indices, axis)[原始碼]#

透過比對一維索引和資料切片,從輸入陣列中取出值。

此函數會迭代索引和資料陣列中,沿著指定軸向的一維切片,並使用前者在後者中查找值。這些切片可以有不同的長度。

返回沿軸索引的函數,例如 argsortargpartition,會為此函數產生合適的索引。

參數:
arrndarray (Ni…, M, Nk…)

來源陣列

indicesndarray (Ni…, J, Nk…)

沿著 arr 的每個一維切片取值的索引。這必須符合 arr 的維度,但維度 Ni 和 Nj 只需要對 arr 進行廣播。

axisint

沿著哪個軸取一維切片。如果 axis 為 None,則輸入陣列會被視為首先展平成一維,以便與 sortargsort 保持一致。

回傳:
out: ndarray (Ni…, J, Nk…)

索引結果。

另請參閱

take

沿著軸取值,對每個一維切片使用相同的索引

put_along_axis

透過比對一維索引和資料切片,將值放入目標陣列

註解

這等效於(但比以下使用 ndindexs_ 的方法更快),它將 iikk 各自設定為索引元組

Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis]  # Need not equal M
out = np.empty(Ni + (J,) + Nk)

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d       = a      [ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        out_1d     = out    [ii + s_[:,] + kk]
        for j in range(J):
            out_1d[j] = a_1d[indices_1d[j]]

等效地,消除內部迴圈,最後兩行會是

out_1d[:] = a_1d[indices_1d]

範例

>>> import numpy as np

對於這個範例陣列

>>> a = np.array([[10, 30, 20], [60, 40, 50]])

我們可以透過直接使用 sort,或使用 argsort 和此函數來排序

>>> np.sort(a, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
>>> ai = np.argsort(a, axis=1)
>>> ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])

如果您使用 keepdims 維護了微不足道的維度,則最大值和最小值也適用

>>> np.max(a, axis=1, keepdims=True)
array([[30],
       [60]])
>>> ai = np.argmax(a, axis=1, keepdims=True)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])

如果我們想要同時取得最大值和最小值,我們可以先堆疊索引

>>> ai_min = np.argmin(a, axis=1, keepdims=True)
>>> ai_max = np.argmax(a, axis=1, keepdims=True)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
       [1, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 30],
       [40, 60]])