# ------------------------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # # MIT License # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE # # Written by Ze Liu, Zhenda Xie # Modified by Jiarui Xu # ------------------------------------------------------------------------- from torch import optim as optim def build_optimizer(config, model): """Build optimizer, set weight decay of normalization to 0 by default.""" parameters = set_weight_decay(model, {}, {}) opt_name = config.optimizer.name optimizer = None if opt_name == 'adamw': optimizer = optim.AdamW( parameters, eps=config.optimizer.eps, betas=config.optimizer.betas, lr=config.base_lr, weight_decay=config.weight_decay) else: raise ValueError(f'Unsupported optimizer: {opt_name}') return optimizer def set_weight_decay(model, skip_list=(), skip_keywords=()): has_decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith('.bias') or (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): no_decay.append(param) # print(f"{name} has no weight decay") else: has_decay.append(param) return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}] def check_keywords_in_name(name, keywords=()): isin = False for keyword in keywords: if keyword in name: isin = True return isin