numpy.put_along_axis#
- numpy.put_along_axis(arr, indices, values, axis)[原始碼]#
透過比對 1 維索引和資料切片,將值放入目標陣列。
此函數會迭代索引和資料陣列中,沿著指定軸向的相符 1 維切片,並使用前者將值放入後者。這些切片可以有不同的長度。
沿著軸傳回索引的函數,例如
argsort
和argpartition
,會產生適用於此函數的索引。- 參數:
- arrndarray (Ni…, M, Nk…)
目標陣列。
- indicesndarray (Ni…, J, Nk…)
沿著 arr 的每個 1 維切片變更的索引。這必須符合 arr 的維度,但 Ni 和 Nj 中的維度可以是 1,以便與 arr 廣播。
- valuesarray_like (Ni…, J, Nk…)
要插入到這些索引的值。其形狀和維度會廣播以符合
indices
的形狀和維度。- axisint
沿著哪個軸取 1 維切片。如果 axis 為 None,則目標陣列會被視為已建立其扁平化的 1 維視圖。
另請參閱
take_along_axis
透過比對 1 維索引和資料切片,從輸入陣列中取得值
附註
這等效於(但速度更快)以下使用
ndindex
和s_
的方式,它將每個ii
和kk
設定為索引的元組Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:] J = indices.shape[axis] # Need not equal M for ii in ndindex(Ni): for kk in ndindex(Nk): a_1d = a [ii + s_[:,] + kk] indices_1d = indices[ii + s_[:,] + kk] values_1d = values [ii + s_[:,] + kk] for j in range(J): a_1d[indices_1d[j]] = values_1d[j]
等效地,消除內部迴圈,最後兩行會是
a_1d[indices_1d] = values_1d
範例
>>> import numpy as np
對於此範例陣列
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
我們可以將最大值替換為
>>> ai = np.argmax(a, axis=1, keepdims=True) >>> ai array([[1], [0]]) >>> np.put_along_axis(a, ai, 99, axis=1) >>> a array([[10, 99, 20], [99, 40, 50]])