pytorchで全結合層以外の重みを固定して学習し直す
概要
pytorchで転移学習を行う際に,全結合層以外の重みを固定するということをよくやるのですが,毎回やり方を忘れて調べていたので,備忘録としてこちらに残しておきます.
方法
モデルのパラメータのrequires_grad
をFalse
に指定するだけです.
import torch import torch.nn as nn import torch.optim as optim from torchvision.models import resnet18 model = resnet18() # モデルの重みを固定 for param in model.parameters(): param.requires_grad = False # 新しいmoduleを定義した場合は,デフォルトで`requires_grad = True` # なので,model.fcのみが学習される model.fc = nn.Linear(model.fc.in_features, 10) # optimizerにパラメータを渡すときは,以下のどちらでも大丈夫 optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
ちなみに,optimizer = optim.Adam(model.parameters(), lr=0.001)
でも問題なく動くかを確認してみました.
import torch import torch.nn as nn import torch.optim as optim from torchvision.models import resnet18 model = resnet18() # モデルの重みを固定 for param in model.parameters(): param.requires_grad = False # 新しいmoduleを定義した場合は,デフォルトで`requires_grad = True` # なので,model.fcのみが学習される model.fc = nn.Linear(model.fc.in_features, 10) # optimizerに全てのパラメータを渡しても全結合層のみが学習されることを確認する. optimizer = optim.Adam(model.parameters(), lr=0.001) # 元の重みを保存しておく. original_conv_weight = model.conv1.weight.clone() original_fc_weight = model.fc.weight.clone() # 擬似データの作成 img = torch.randn(2, 3, 56, 56) label = torch.zeros(2, ).long() criterion = nn.CrossEntropyLoss() # 学習 output = model(img) loss = criterion(output, label) loss.backward() optimizer.step() optimizer.zero_grad() # 重みが更新されているかどうかの確認 print(torch.all(original_conv_weight == model.conv1.weight)) print(torch.all(original_fc_weight == model.fc.weight))
結果
tensor(True) tensor(False)
ちゃんと全結合層だけが更新されていることが確認できました!
参考
- How the pytorch freeze network in some layers, only the rest of the training? - #2 by L0SG - PyTorch Forums
- PyTorch (8) Transfer Learning (Ants and Bees) - 人工知能に関する断創録
リンク
リンク