yiskw note

勉強したことや日々の生活について気まぐれで書きます

pytorchで全結合層以外の重みを固定して学習し直す

概要

pytorchで転移学習を行う際に,全結合層以外の重みを固定するということをよくやるのですが,毎回やり方を忘れて調べていたので,備忘録としてこちらに残しておきます.

方法

モデルのパラメータのrequires_gradFalseに指定するだけです.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18


model = resnet18()

# モデルの重みを固定
for param in model.parameters():
    param.requires_grad = False

# 新しいmoduleを定義した場合は,デフォルトで`requires_grad = True`
# なので,model.fcのみが学習される
model.fc = nn.Linear(model.fc.in_features, 10)

# optimizerにパラメータを渡すときは,以下のどちらでも大丈夫
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

ちなみに,optimizer = optim.Adam(model.parameters(), lr=0.001)でも問題なく動くかを確認してみました.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18


model = resnet18()

# モデルの重みを固定
for param in model.parameters():
    param.requires_grad = False

# 新しいmoduleを定義した場合は,デフォルトで`requires_grad = True`
# なので,model.fcのみが学習される
model.fc = nn.Linear(model.fc.in_features, 10)

# optimizerに全てのパラメータを渡しても全結合層のみが学習されることを確認する.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 元の重みを保存しておく.
original_conv_weight = model.conv1.weight.clone()
original_fc_weight = model.fc.weight.clone()

# 擬似データの作成
img = torch.randn(2, 3, 56, 56)
label = torch.zeros(2, ).long()
criterion = nn.CrossEntropyLoss()

# 学習
output = model(img)
loss = criterion(output, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# 重みが更新されているかどうかの確認
print(torch.all(original_conv_weight == model.conv1.weight))
print(torch.all(original_fc_weight == model.fc.weight))

結果

tensor(True)
tensor(False)

ちゃんと全結合層だけが更新されていることが確認できました!

参考