[PyTorch] nearestとnearest-exact

挙動確認の時にnearestの動きに違和感を覚えました。

というのも、例えば9×9サイズのデータを1/3にスケールダウンした場合、以下の結果となったからです。

テンソルの準備

x = torch.arange(0, 81, dtype=torch.float32).view(-1, 1, 9, 9)

# tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
#           [ 9., 10., 11., 12., 13., 14., 15., 16., 17.],
#           [18., 19., 20., 21., 22., 23., 24., 25., 26.],
#           [27., 28., 29., 30., 31., 32., 33., 34., 35.],
#           [36., 37., 38., 39., 40., 41., 42., 43., 44.],
#           [45., 46., 47., 48., 49., 50., 51., 52., 53.],
#           [54., 55., 56., 57., 58., 59., 60., 61., 62.],
#           [63., 64., 65., 66., 67., 68., 69., 70., 71.],
#           [72., 73., 74., 75., 76., 77., 78., 79., 80.]]]])

mode=’nearest’の結果

scaled_x = F.interpolate(x, size=None, scale_factor=1/3, mode='nearest')

# tensor([[[[ 0.,  3.,  6.],
#           [27., 30., 33.],
#           [54., 57., 60.]]]])

ダウンサンプリング後の座標(0, 0)に格納されるべき値は、10のほうが適切なのではないでしょうか?

気になって調べてみたところ、PyTorch公式に以下の記述がありました。どうやら、mode=’nearest-exact’というものも選択できるようです。

Mode mode=’nearest-exact’ matches Scikit-Image and PIL nearest neighbours interpolation algorithms and fixes known issues with mode=’nearest’. This mode is introduced to keep backward compatibility. Mode mode=’nearest’ matches buggy OpenCV’s INTER_NEAREST interpolation algorithm.

mode=’nearest-exact’の結果

試しに使ってみた結果が以下で、感覚的にリーズナブルな結果になりました。

scaled_x = F.interpolate(x, size=None, scale_factor=1/3, mode='nearest-exact')

# tensor([[[[10., 13., 16.],
#           [37., 40., 43.],
#           [64., 67., 70.]]]])

ちなみに、このmodeはUpsampleでも同様に指定することができます。

upsample = nn.Upsample(scale_factor=1/3, mode='nearest-exact')
scaled_x = upsample(x)

# tensor([[[[10., 13., 16.],
#           [37., 40., 43.],
#           [64., 67., 70.]]]])

OpenCVのバグ?

この記載を見ると、mode=’nearest’はOpenCVのINTER_NEARESTのバグに合わせるための仕様であるとのことです。

ほんとに?とチェックしてみました。

配列の準備

テンソルと同じ値が格納された配列を準備します。

x = np.arange(0, 81).astype(np.float32).reshape(-1, 9)

# [[ 0.  1.  2.  3.  4.  5.  6.  7.  8.]
#  [ 9. 10. 11. 12. 13. 14. 15. 16. 17.]
#  [18. 19. 20. 21. 22. 23. 24. 25. 26.]
#  [27. 28. 29. 30. 31. 32. 33. 34. 35.]
#  [36. 37. 38. 39. 40. 41. 42. 43. 44.]
#  [45. 46. 47. 48. 49. 50. 51. 52. 53.]
#  [54. 55. 56. 57. 58. 59. 60. 61. 62.]
#  [63. 64. 65. 66. 67. 68. 69. 70. 71.]
#  [72. 73. 74. 75. 76. 77. 78. 79. 80.]]

実験

cv2.resize()でinterpolation=cv2.INTER_NEARESTを指定してその結果を確認しました。

x_scaled = cv2.resize(x, dsize=None, fx=1/3, fy=1/3, interpolation=cv2.INTER_NEAREST)

# [[ 0.  3.  6.]
#  [27. 30. 33.]
#  [54. 57. 60.]]

本当だ!!

Scikit-learnとPILはmode=’area-exact’の動作と同じだよ、と書いてありますので、そっちもついでにチェックしました。

  • PIL
x_pil = Image.fromarray(x)
x_scaled_pil = x_pil.resize((3, 3), resample=Image.Resampling.NEAREST)

# [[10. 13. 16.]
#  [37. 40. 43.]
#  [64. 67. 70.]]

ちなみに、resampleの指定は現状、resample=Image.NEARESTでも可能ですが、今後変更されるようなので、上記を使っておいたほうが安心かと思います。以下は、処理を流した時にWarningで出た文面です。

check_opencv.py:12: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead. x_scaled_pil = x_pil.resize((3, 3), resample=Image.NEAREST)

  • Scikit-learn
from skimage.transform import rescale, resize, downscale_local_mean

image_resized = resize(x, (3, 3), order=0, anti_aliasing=False)
# anti_aliasingはデフォルトでTrueになっているので
# 純粋なNearestの結果を取得する場合はFalseに設定する必要あり

# [[10. 13. 16.]
#  [37. 40. 43.]
#  [64. 67. 70.]]

確かに、OpenCVとPIL/Scikit-learnの挙動が違います。

気にしたこともなかった……

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です