NumPy 中的資料類型提升#

當混合兩種不同的資料類型時,NumPy 必須決定運算結果的適當 dtype。此步驟稱為提升尋找通用 dtype

在一般情況下,使用者不需要擔心提升的細節,因為提升步驟通常確保結果將符合或超過輸入的精度。

例如,當輸入具有相同的 dtype 時,結果的 dtype 會符合輸入的 dtype

>>> np.int8(1) + np.int8(1)
np.int8(2)

混合兩種不同的 dtype 通常會產生具有較高精度輸入 dtype 的結果

>>> np.int8(4) + np.int64(8)  # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3)  # 32 > 16
np.float32(6.0)

在一般情況下,這不會導致意外。但是,如果您使用非預設 dtype(如無號整數和低精度浮點數),或者如果您混合 NumPy 整數、NumPy 浮點數和 Python 純量值,則 NumPy 提升規則的一些細節可能相關。請注意,這些詳細規則並不總是與其他語言的規則相符 [1]

數值 dtype 分為四種「種類」,具有自然的層次結構。

  1. 無號整數 (uint)

  2. 帶號整數 (int)

  3. 浮點數 (float)

  4. 複數 (complex)

除了種類之外,NumPy 數值 dtype 也具有相關的精度,以位元指定。種類和精度共同指定 dtype。例如,uint8 是使用 8 位元儲存的無號整數。

運算的結果將始終與任何輸入的種類相同或更高。此外,結果的精度將始終大於或等於輸入的精度。已經,這可能會導致一些可能出乎意料的範例

  1. 當混合浮點數和整數時,整數的精度可能會強制結果為更高的精度浮點數。例如,涉及 int64float16 的運算結果為 float64

  2. 當混合具有相同精度的無號和帶號整數時,結果將具有比任一輸入更高的精度。此外,如果其中一個已經具有 64 位元精度,則沒有更高的精度整數可用,例如涉及 int64uint64 的運算會產生 float64

請參閱數值提升章節和下方圖片,以了解兩者的詳細資訊。

Python 純量值的詳細行為#

自 NumPy 2.0 [2] 以來,我們提升規則中的一個重點是,雖然涉及兩個 NumPy dtype 的運算永遠不會損失精度,但涉及 NumPy dtype 和 Python 純量值(intfloatcomplex)的運算可能會損失精度。例如,Python 整數和 NumPy 整數之間的運算結果應該是 NumPy 整數,這可能是直觀的。但是,Python 整數具有任意精度,而所有 NumPy dtype 都具有固定精度,因此無法保留 Python 整數的任意精度。

更一般而言,NumPy 會考慮 Python 純量值的「種類」,但在決定結果 dtype 時會忽略其精度。這通常很方便。例如,當使用低精度 dtype 的陣列時,通常希望與 Python 純量值的簡單運算能夠保留 dtype。

>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0  # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10  # undesirable to promote to int64
array([13, 15, 17], dtype=int16)

在這兩種情況下,結果精度都由 NumPy dtype 決定。因此,arr_float32 + 3.0 的行為與 arr_float32 + np.float32(3.0) 相同,而 arr_int16 + 10 的行為與 arr_int16 + np.int16(10.) 相同。

作為另一個範例,當將 NumPy 整數與 Python floatcomplex 混合時,結果始終具有類型 float64complex128

>> np.int16(1) + 1.0 np.float64(2.0)

但是,當使用低精度 dtype 時,這些規則也可能導致令人驚訝的行為。

首先,由於 Python 值在執行運算之前會轉換為 NumPy 值,因此當結果看起來很明顯時,運算可能會因錯誤而失敗。例如,np.int8(1) + 1000 無法繼續,因為 1000 超過 int8 的最大值。當 Python 純量值無法強制轉換為 NumPy dtype 時,會引發錯誤

>>> np.int8(1) + 1000
Traceback (most recent call last):
  ...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast

其次,由於 Python 浮點數或整數精度始終被忽略,因此低精度 NumPy 純量值將繼續使用其較低的精度,除非明確轉換為更高的精度 NumPy dtype 或 Python 純量值(例如,透過 int()float()scalar.item())。這種較低的精度可能不利於某些計算,或導致不正確的結果,尤其是在整數溢位的情況下

>>> np.int8(100) + 100  # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add

請注意,當純量值發生溢位時,NumPy 會發出警告,但不會針對陣列發出警告;例如,np.array(100, dtype="uint8") + 100不會發出警告。

數值提升#

下圖顯示了數值提升規則,種類在垂直軸上,精度在水平軸上。

../_images/nep-0050-promotion-no-fonts.svg

具有較高種類的輸入 dtype 決定了結果 dtype 的種類。結果 dtype 的精度盡可能低,但不會出現在圖表中任一輸入 dtype 的左側。

請注意以下特定規則和觀察

  1. 當 Python floatcomplex 與 NumPy 整數互動時,結果將為 float64complex128(黃色邊框)。NumPy 布林值也將轉換為預設整數 [3]。當也涉及 NumPy 浮點數值時,這並不相關。

  2. 精度的繪製方式使得 float16 < int16 < uint16,因為大的 uint16 不適合 int16,而大的 int16 儲存在 float16 中時會損失精度。然而,此模式被打破,因為 NumPy 始終認為 float64complex128 是任何整數值的可接受提升結果。

  3. 一個特殊情況是,NumPy 將許多帶號和無號整數的組合提升為 float64。此處使用較高的種類,因為沒有帶號整數 dtype 具有足夠的精度來容納 uint64

一般提升規則的例外#

在 NumPy 中,提升指的是特定函式對結果的作用,在某些情況下,這表示 NumPy 可能會偏離 np.result_type 給出的結果。

sumprod 的行為#

np.sumnp.prod 在對整數值(或布林值)求和時,始終會傳回預設整數類型。這通常是 int64。這樣做的原因是,整數求和非常容易溢位並產生混淆的結果。此規則也適用於底層的 np.add.reducenp.multiply.reduce

NumPy 或 Python 整數純量值的顯著行為#

NumPy 提升指的是結果 dtype 和運算精度,但運算有時會決定結果。除法始終傳回浮點數值,而比較始終傳回布林值。

這導致了可能看起來像是規則「例外」的情況

  • NumPy 與 Python 整數或混合精度整數的比較始終傳回正確的結果。輸入永遠不會以損失精度的方式轉換。

  • 無法提升的類型之間的相等比較將被視為全部 False(相等)或全部 True(不相等)。

  • 一元數學函式(如 np.sin)始終傳回浮點數值,它們接受任何 Python 整數輸入,方法是將其轉換為 float64

  • 除法始終傳回浮點數值,因此也允許任何 NumPy 整數與任何 Python 整數值之間的除法,方法是將兩者都轉換為 float64

原則上,其中一些例外情況對於其他函式可能是有意義的。如果您認為是這種情況,請提出 issue。

非數值資料類型的提升#

NumPy 將提升擴展到非數值類型,儘管在許多情況下,提升未明確定義,並且會被直接拒絕。

以下規則適用

  • NumPy 位元組字串 (np.bytes_) 可以提升為 Unicode 字串 (np.str_)。但是,對於非 ASCII 字元,將位元組轉換為 Unicode 將會失敗。

  • 對於某些目的,NumPy 會將幾乎任何其他資料類型提升為字串。這適用於陣列建立或串連。

  • 陣列建構函式(如 np.array())在沒有可行的提升時,將使用 object dtype。

  • 當結構化 dtype 的欄位名稱和順序匹配時,它們可以提升。在這種情況下,所有欄位都會單獨提升。

  • NumPy timedelta 在某些情況下可以與整數提升。

注意

其中一些規則有些令人驚訝,並且正在考慮在未來進行更改。但是,任何向後不相容的更改都必須權衡打破現有程式碼的風險。如果您對提升應如何運作有任何特定想法,請提出 issue。

已提升 dtype 實例的詳細資訊#

以上討論主要處理了混合不同 DType 類別時的行為。附加到陣列的 dtype 實例可以攜帶額外資訊,例如位元組順序、中繼資料、字串長度或精確的結構化 dtype 佈局。

雖然結構化 dtype 的字串長度或欄位名稱很重要,但 NumPy 將位元組順序、中繼資料和結構化 dtype 的精確佈局視為儲存細節。在提升期間,NumPy 考慮這些儲存細節:* 位元組順序會轉換為原生位元組順序。* 附加到 dtype 的中繼資料可能會或可能不會保留。* 結果結構化 dtype 將被打包(但如果輸入已對齊,則會對齊)。

對於大多數程式來說,這種行為是最佳行為,因為儲存細節與最終結果無關,並且使用不正確的位元組順序可能會大大減慢評估速度。