|
@@ -107,6 +107,23 @@ def train(cfg):
|
|
|
|
|
|
logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
|
|
|
model = build_model(cfg.model)
|
|
|
+
|
|
|
+ # load_checkpoint(cfg, model, None, None)
|
|
|
+
|
|
|
+ # 冻结所有层
|
|
|
+ for param in model.parameters():
|
|
|
+ param.requires_grad = False
|
|
|
+
|
|
|
+ # 如果你只想冻结特定的层,可以按照以下方式进行
|
|
|
+ # 例如,冻结所有的 img_projector 层
|
|
|
+ for param in model.img_projector.parameters():
|
|
|
+ param.requires_grad = True
|
|
|
+
|
|
|
+ # 如果你只想冻结特定的层,可以按照以下方式进行
|
|
|
+ # 例如,冻结所有的 text_projector 层
|
|
|
+ for param in model.text_projector.parameters():
|
|
|
+ param.requires_grad = True
|
|
|
+
|
|
|
model.cuda()
|
|
|
logger.info(str(model))
|
|
|
|