Breadcrumb
【PyTorch 筆記】別再被 Dimension Error 搞瘋了,一次搞懂模型吃什麼格式
寫 PyTorch 最讓人崩潰的瞬間是什麼?絕對不是你的模型太難寫,而是當你興高采烈按下 Run,結果 Terminal 直接噴出一大串紅字 RuntimeError: Expected 4-dimensional input...。
當下真的只想摔鍵盤:「我就只有一張圖,為什麼還要我給 4 個維度啦?」
其實 PyTorch 的脾氣很硬,但摸順了就很簡單。這篇簡單整理一下它的「眉角」,讓你下次 Debug 可以少死幾個腦細胞。
1. 它是個有強迫症的傢伙:請給它 (N, C, H, W)
不管你今天是做一般圖片分類(2D CNN),還是要跑比較高大上的 3D 醫學影像,PyTorch 卷積層對於「輸入長怎樣」是有嚴格規定的。
這口訣請刻在腦海裡:NCHW。
-
N (Batch Size):最常被遺忘的傢伙。PyTorch 強制第一維一定要是「批次大小」。 哪怕你現在只是想測試一張貓咪的照片,你也不能只給它
(3, 224, 224)。 你必須騙它說:「嘿,這是一整批照片,雖然裡面只有一張。」變成(1, 3, 224, 224)。 -
C (Channel):通道數(RGB 就是 3,灰階就是 1)。
-
H, W:高跟寬。
(N, C, D, H, W)。
2. 常見的「形狀不對」與解決辦法
如果你是從 OpenCV 或是 PIL 讀圖進來,形狀通常都不會剛好符合 PyTorch 的胃口。這裡有幾個最常見的坑,以及怎麼把它們「捏」成對的形狀:
情況一:黑白圖/灰階圖 (只有高跟寬)
讀進來是 (H, W),什麼都沒有。
解法:包兩層皮。 你需要先幫它加上 Channel 軸,再加 Batch 軸。
# 假設 img 是 (28, 28) 的 Tensor
img = ...
# img shape: (28, 28)
input_tensor = img.unsqueeze(0).unsqueeze(0)
# 變成 (1, 1, 28, 28) -> 搞定!
情況二:一般的彩色圖 (缺 Batch)
如果你已經轉成 Tensor 了,通常是 (C, H, W)。
解法:加一層皮在最前面。
# 假設 img 是 (3, 224, 224) 的 Tensor
img = ...
# img shape: (3, 224, 224)
input_tensor = img.unsqueeze(0)
# 變成 (1, 3, 224, 224) -> 完美。
情況三:大魔王 OpenCV 格式 (H, W, C)
這是最多人踩的雷!OpenCV 讀進來的順序跟 PyTorch 是反的(Channel 在最後面),直接丟進去絕對報錯。 解法:先換位子,再加皮。
# 假設 img 是 (224, 224, 3) 的 Tensor
img = ...
# img shape: (224, 224, 3)
# 1. 先用 permute 把 Channel 搬到前面 (2, 0, 1)
# 2. 再用 unsqueeze 加 Batch
input_tensor = img.permute(2, 0, 1).unsqueeze(0)
3. 真的不知道錯哪? Print 就對了
說真的,我看過太多人(包含我自己)在想為什麼模型跑不動,盯著螢幕發呆。其實最快的方法就是在丟進模型前,直接把形狀印出來看。
# 假設 x 是你的輸入資料
x = ...
# 這是你最好的朋友
# print(x.shape) # 註解掉以避免 IDE 警告
快速判斷指南:
- 看到
[128, 128]這種只有兩個數字的 ❌ -> 絕對掛掉,缺太多東西了。 - 看到
[3, 128, 128]❌ -> 雖然有 Channel,但少了 Batch,訓練一定會出事。 - 看到
[1, 3, 128, 128]✅ -> 舒服,這才是 PyTorch 要的。