撰寫自訂陣列容器#

NumPy 的分派機制,在 numpy 版本 v1.16 中引入,是撰寫自訂 N 維陣列容器的建議方法,這些容器與 numpy API 相容,並提供 numpy 功能的自訂實作。應用範例包括 dask 陣列(一種分佈在多個節點上的 N 維陣列)和 cupy 陣列(一種在 GPU 上的 N 維陣列)。

為了讓您感受一下撰寫自訂陣列容器,我們將從一個簡單的範例開始,這個範例的實用性相當狹窄,但說明了所涉及的概念。

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)

我們的自訂陣列可以像這樣實例化

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

我們可以使用 numpy.arraynumpy.asarray 轉換為 numpy 陣列,這將呼叫其 __array__ 方法以取得標準的 numpy.ndarray

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

如果我們使用 numpy 函數對 arr 進行操作,numpy 將再次使用 __array__ 介面將其轉換為陣列,然後以通常的方式應用該函數。

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

請注意,回傳類型是標準的 numpy.ndarray

>>> type(np.multiply(arr, 2))
<class 'numpy.ndarray'>

我們如何讓我們的自訂陣列類型通過這個函數?Numpy 允許一個類別透過 __array_ufunc____array_function__ 介面表明它希望以自訂方式處理計算。讓我們一次處理一個,從 __array_ufunc__ 開始。此方法涵蓋 通用函數 (ufunc),這是一類函數,例如包括 numpy.multiplynumpy.sin

__array_ufunc__ 接收

  • ufunc,一個像 numpy.multiply 這樣的函數

  • method,一個字串,區分 numpy.multiply(...) 和變體,例如 numpy.multiply.outernumpy.multiply.accumulate 等。對於常見的情況 numpy.multiply(...)method == '__call__'

  • inputs,可能是不同類型的混合

  • kwargs,傳遞給函數的關鍵字參數

在這個範例中,我們只會處理 __call__ 方法

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented

現在我們的自訂陣列類型可以通過 numpy 函數。

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)

此時 arr + 3 無法運作。

>>> arr + 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'

為了支援它,我們需要定義 Python 介面 __add____lt__ 等,以分派到相應的 ufunc。我們可以透過繼承 mixin NDArrayOperatorsMixin 來方便地實現這一點。

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)

現在讓我們處理 __array_function__。我們將建立一個字典,將 numpy 函數對應到我們的自訂變體。

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...     def __array_function__(self, func, types, args, kwargs):
...         if func not in HANDLED_FUNCTIONS:
...             return NotImplemented
...         # Note: this allows subclasses that don't override
...         # __array_function__ to handle DiagonalArray objects.
...         if not all(issubclass(t, self.__class__) for t in types):
...             return NotImplemented
...         return HANDLED_FUNCTIONS[func](*args, **kwargs)
...

一個方便的模式是定義一個裝飾器 implements,可以用於將函數添加到 HANDLED_FUNCTIONS

>>> def implements(np_function):
...    "Register an __array_function__ implementation for DiagonalArray objects."
...    def decorator(func):
...        HANDLED_FUNCTIONS[np_function] = func
...        return func
...    return decorator
...

現在我們為 DiagonalArray 撰寫 numpy 函數的實作。為了完整起見,為了支援用法 arr.sum(),新增一個方法 sum,它會呼叫 numpy.sum(self)mean 也是如此。

>>> @implements(np.sum)
... def sum(arr):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def mean(arr):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2

如果使用者嘗試使用任何未包含在 HANDLED_FUNCTIONS 中的 numpy 函數,numpy 將會引發 TypeError,表示不支援此操作。例如,串聯兩個 DiagonalArrays 不會產生另一個對角陣列,因此不支援。

>>> np.concatenate([arr, arr])
Traceback (most recent call last):
...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

此外,我們對 summean 的實作不接受 numpy 實作所接受的可選參數。

>>> np.sum(arr, axis=0)
Traceback (most recent call last):
...
TypeError: sum() got an unexpected keyword argument 'axis'

使用者始終可以選擇使用 numpy.asarray 轉換為正常的 numpy.ndarray,並從那裡使用標準的 numpy。

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

為了簡潔起見,本範例中 DiagonalArray 的實作僅處理 np.sumnp.mean 函數。Numpy API 中的許多其他函數也可用於包裝,而功能完善的自訂陣列容器可以明確支援 Numpy 提供的所有可包裝函數。

Numpy 提供了一些實用工具,以協助測試在 numpy.testing.overrides 命名空間中實作 __array_ufunc____array_function__ 協定的自訂陣列容器。

若要檢查是否可以透過 __array_ufunc__ 覆寫 Numpy 函數,您可以使用 allows_array_ufunc_override

>>> from numpy.testing.overrides import allows_array_ufunc_override
>>> allows_array_ufunc_override(np.add)
True

同樣地,您可以使用 allows_array_function_override 檢查是否可以透過 __array_function__ 覆寫函數。

Numpy API 中每個可覆寫函數的列表也可透過 get_overridable_numpy_array_functions 取得,適用於支援 __array_function__ 協定的函數,以及透過 get_overridable_numpy_ufuncs 取得,適用於支援 __array_ufunc__ 協定的函數。這兩個函數都會傳回 Numpy 公開 API 中存在的函數集。使用者定義的 ufunc 或在其他依賴 Numpy 的程式庫中定義的 ufunc 不會出現在這些集合中。

請參閱 dask 原始碼cupy 原始碼,以取得更完整的自訂陣列容器範例。

另請參閱 NEP 18。(在 NumPy 增強提案中)