NumPy配列からサイズ1の次元を削除:np.squeeze()の徹底解説
データ分析や機械学習において、NumPyは欠かせないライブラリです。特にNumPy配列(ndarray)の操作は頻繁に行われます。今回は、配列からサイズ1の次元を効率的に削除できる**np.squeeze()**関数について詳しく解説します。
np.squeeze()とは?
np.squeeze()は、NumPy配列からサイズが1の次元(軸)をすべて削除するために使用される関数です。これにより、配列の形状(shape)をよりコンパクトにしたり、特定の操作に適した形に整形したりできます。特に、モデルの出力が不要な次元を持っている場合などに非常に役立ちます。
例えば、(1, 3, 1, 4)のような形状の配列があった場合、np.squeeze()を適用すると(3, 4)のようにサイズ1の次元が取り除かれます。
なぜnp.squeeze()が必要なのか?
NumPy配列は多次元のデータを効率的に扱うために設計されています。しかし、計算過程で意図せずサイズ1の次元が追加されてしまうことがあります。例えば、単一の画像に対するモデルの予測を行う場合、バッチ次元としてサイズ1の次元が追加され、形状が(1, 高さ, 幅, チャンネル数)のようになることがあります。このような場合、後続の処理でこの不要な次元が邪魔になることがあります。
np.squeeze()を使用することで、これらの不要な次元を自動的に削除し、データの扱いや後続の処理を簡素化できます。
np.squeeze()の基本的な使い方
np.squeeze()の基本的な使い方は非常にシンプルです。引数に次元を削除したいNumPy配列を渡すだけです。
import numpy as np
# 例1:複数のサイズ1次元を持つ配列
arr1 = np.array([[[[1, 2, 3]]]])
print(f"元の配列1の形状: {arr1.shape}")
# 元の配列1の形状: (1, 1, 1, 3)
squeezed_arr1 = np.squeeze(arr1)
print(f"squeeze後の配列1の形状: {squeezed_arr1.shape}")
# squeeze後の配列1の形状: (3,)
# 例2:特定の次元のみがサイズ1の配列
arr2 = np.array([[1, 2, 3]])
print(f"元の配列2の形状: {arr2.shape}")
# 元の配列2の形状: (1, 3)
squeezed_arr2 = np.squeeze(arr2)
print(f"squeeze後の配列2の形状: {squeezed_arr2.shape}")
# squeeze後の配列2の形状: (3,)
# 例3:サイズ1次元を持たない配列
arr3 = np.array([[1, 2], [3, 4]])
print(f"元の配列3の形状: {arr3.shape}")
# 元の配列3の形状: (2, 2)
squeezed_arr3 = np.squeeze(arr3)
print(f"squeeze後の配列3の形状: {squeezed_arr3.shape}")
# squeeze後の配列3の形状: (2, 2)
上記の例からわかるように、np.squeeze()はサイズ1の次元のみを削除し、それ以外の次元には影響を与えません。
axis引数で特定の次元を指定する
np.squeeze()は、デフォルトではすべてのサイズ1の次元を削除しますが、axis引数を使用することで特定の次元のみを対象とすることができます。これは、特定の次元だけを削除したい場合に便利です。
単一の次元を指定
import numpy as np
arr = np.array([[[1, 2, 3]]]) # 形状: (1, 1, 3)
print(f"元の配列の形状: {arr.shape}")
# 元の配列の形状: (1, 1, 3)
# 0番目の次元を削除
squeezed_axis0 = np.squeeze(arr, axis=0)
print(f"axis=0 でsqueeze後の形状: {squeezed_axis0.shape}")
# axis=0 でsqueeze後の形状: (1, 3)
# 1番目の次元を削除
squeezed_axis1 = np.squeeze(arr, axis=1)
print(f"axis=1 でsqueeze後の形状: {squeezed_axis1.shape}")
# axis=1 でsqueeze後の形状: (1, 3)
axisで指定した次元がサイズ1でない場合、または存在しない場合はエラーが発生します。
import numpy as np
arr_no_squeeze = np.array([[1, 2], [3, 4]]) # 形状: (2, 2)
print(f"元の配列の形状: {arr_no_squeeze.shape}")
# 元の配列の形状: (2, 2)
try:
np.squeeze(arr_no_squeeze, axis=0)
except ValueError as e:
print(f"エラー発生: {e}")
# エラー発生: cannot select an axis to squeeze out which has size not equal to one
複数の次元を指定
axis引数には、削除したい次元のインデックスをリストやタプルで渡すこともできます。
import numpy as np
arr = np.array([[[[1, 2, 3]]]]) # 形状: (1, 1, 1, 3)
print(f"元の配列の形状: {arr.shape}")
# 元の配列の形状: (1, 1, 1, 3)
# 0番目と2番目の次元を削除
squeezed_axes = np.squeeze(arr, axis=(0, 2))
print(f"axis=(0, 2) でsqueeze後の形状: {squeezed_axes.shape}")
# axis=(0, 2) でsqueeze後の形状: (1, 3)
np.squeeze()のユースケース
1. ディープラーニングのモデル出力の整形
ディープラーニングモデルは、入力や出力にバッチ次元(通常、最初の次元)を必要とすることがよくあります。単一のサンプルを扱う場合でも、形状は(1, ...) となります。推論結果からこのバッチ次元を取り除く際にnp.squeeze()が役立ちます。
import numpy as np
# モデルの出力例(バッチサイズ1)
model_output = np.random.rand(1, 10) # 形状: (1, 10)
print(f"モデル出力の形状: {model_output.shape}")
# モデル出力の形状: (1, 10)
# バッチ次元を削除
processed_output = np.squeeze(model_output)
print(f"処理後の出力形状: {processed_output.shape}")
# 処理後の出力形状: (10,)
2. データの前処理
特定のデータセットを扱う際、次元が不必要に増えてしまうことがあります。np.squeeze()を使うことで、より扱いやすい形状にデータを変換できます。
import numpy as np
# 10個の単一特徴量データ
data_with_extra_dim = np.array([[x] for x in range(10)]) # 形状: (10, 1)
print(f"元のデータの形状: {data_with_extra_dim.shape}")
# 元のデータの形状: (10, 1)
# 不要な次元を削除
flattened_data = np.squeeze(data_with_extra_dim)
print(f"フラット化されたデータの形状: {flattened_data.shape}")
# フラット化されたデータの形状: (10,)
np.squeeze()と他の形状操作関数の違い
NumPyにはnp.reshape(), np.flatten(), arr.ravel()など、配列の形状を変更する多くの関数があります。np.squeeze()はこれらの関数と何が違うのでしょうか?
-
np.reshape(): 任意の形状に変更できますが、元の要素数を変えることはできません。サイズ1でない次元も変更できます。 -
np.flatten()/arr.ravel(): 配列を1次元に平坦化します。np.squeeze()のようにサイズ1の次元のみを対象にするわけではありません。
np.squeeze()は、「サイズが1の次元」という特定の条件を満たす次元のみを削除するという点で、他の関数とは異なります。これにより、意図しない形状変更を防ぎつつ、不要な次元のみを効率的に取り除くことができます。
まとめ
np.squeeze()は、NumPy配列からサイズ1の次元を削除するための非常に便利な関数です。特にディープラーニングのモデル出力の整形やデータの前処理において、コードを簡潔にし、可読性を向上させることができます。ぜひ、あなたのNumPy操作に活用してみてください。

