numpy.testing.assert_array_equal#

testing.assert_array_equal(actual, desired, err_msg='', verbose=True, *, strict=False)[原始碼]#

若兩個類陣列物件不相等,則引發 AssertionError。

給定兩個類陣列物件,檢查形狀是否相等,以及這些物件的所有元素是否相等 (但請參閱「註釋」以了解純量值的特殊處理方式)。若形狀不符或值衝突,則會引發例外。與 numpy 中的標準用法相反,NaN 會像數字一樣進行比較;若兩個物件在相同位置都有 NaN,則不會引發 assertion。

建議對驗證浮點數的相等性保持通常的謹慎態度。

註釋

actualdesired 其中之一已是 numpy.ndarray 的實例,且 desired 不是 dict 時,assert_equal(actual, desired) 的行為與此函數的行為相同。否則,此函數會在比較之前對輸入執行 np.asanyarray,而 assert_equal 為常見的 Python 類型定義了特殊的比較規則。例如,只有 assert_equal 可以用於比較巢狀 Python 列表。在新程式碼中,請考慮僅使用 assert_equal,並在需要 assert_array_equal 行為時,明確地將 actualdesired 轉換為陣列。

參數:
actual類陣列

要檢查的實際物件。

desired類陣列

期望的物件。

err_msgstr,選用

在失敗時要印出的錯誤訊息。

verbosebool,選用

若為 True,則衝突的值會附加到錯誤訊息中。

strictbool,選用

若為 True,當類陣列物件的形狀或資料類型不符時,引發 AssertionError。「註釋」章節中提及的純量值特殊處理方式會停用。

版本 1.24.0 新增。

引發:
AssertionError

若 actual 和 desired 物件不相等。

另請參閱

assert_allclose

比較兩個類陣列物件的相等性,具有期望的相對和/或絕對精度。

assert_array_almost_equal_nulpassert_array_max_ulpassert_equal

註釋

actualdesired 其中之一是純量值,而另一個是類陣列時,此函數會檢查類陣列物件的每個元素是否等於純量值。此行為可以使用 strict 參數停用。

範例

第一個 assert 不會引發例外

>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
...                               [np.exp(0),2.33333, np.nan])

Assert 在浮點數的數值不精確時失敗

>>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
...                               [1, np.sqrt(np.pi)**2, np.nan])
Traceback (most recent call last):
    ...
AssertionError:
Arrays are not equal

Mismatched elements: 1 / 3 (33.3%)
Max absolute difference among violations: 4.4408921e-16
Max relative difference among violations: 1.41357986e-16
 ACTUAL: array([1.      , 3.141593,      nan])
 DESIRED: array([1.      , 3.141593,      nan])

在這些情況下,請改用 assert_allclose 或其中一個 nulp (浮點數值數量) 函數

>>> np.testing.assert_allclose([1.0,np.pi,np.nan],
...                            [1, np.sqrt(np.pi)**2, np.nan],
...                            rtol=1e-10, atol=0)

如「註釋」章節所述,assert_array_equal 對純量值有特殊處理方式。此處的測試檢查 x 中的每個值是否為 3

>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_array_equal(x, 3)

使用 strict 在比較純量值與陣列時引發 AssertionError

>>> np.testing.assert_array_equal(x, 3, strict=True)
Traceback (most recent call last):
    ...
AssertionError:
Arrays are not equal

(shapes (2, 5), () mismatch)
 ACTUAL: array([[3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3]])
 DESIRED: array(3)

strict 參數也確保陣列資料類型相符

>>> x = np.array([2, 2, 2])
>>> y = np.array([2., 2., 2.], dtype=np.float32)
>>> np.testing.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
    ...
AssertionError:
Arrays are not equal

(dtypes int64, float32 mismatch)
 ACTUAL: array([2, 2, 2])
 DESIRED: array([2., 2., 2.], dtype=float32)