build.py

from model import objectives
from .clip_model import Transformer, QuickGELU, LayerNorm, build_CLIP_from_openai_pretrained, convert_weights
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict

class IRRA(nn.Module):
    def __init__(self, args, num_classes=11003):
        super().__init__()
        self.args = args
        self.num_classes = num_classes
        self._set_task()

        self.base_model, base_cfg = build_CLIP_from_openai_pretrained(args.pretrain_choice, args.img_size, args.stride_size)
        self.embed_dim = base_cfg['embed_dim']

        self.logit_scale = torch.ones([]) * (1 / args.temperature)

        if 'id' in args.loss_names:
            self.classifier = nn.Linear(self.embed_dim, self.num_classes)
            nn.init.normal_(self.classifier.weight.data, std=0.001)
            nn.init.constant_(self.classifier.bias.data, val=0.0)

        if 'mlm' in args.loss_names:
            self.cross_attn = nn.MultiheadAttention(self.embed_dim,
                                                    self.embed_dim // 64,
                                                    batch_first=True)
            self.cross_modal_transformer = Transformer(width=self.embed_dim,
                                                       layers=args.cmt_depth,
                                                       heads=self.embed_dim // 64)
            scale = self.cross_modal_transformer.width**-0.5

            self.ln_pre_t = LayerNorm(self.embed_dim)
            self.ln_pre_i = LayerNorm(self.embed_dim)
            self.ln_post = LayerNorm(self.embed_dim)

            proj_std = scale * ((2 * self.cross_modal_transformer.layers)**-0.5)
            attn_std = scale
            fc_std = (2 * self.cross_modal_transformer.width)**-0.5
            for block in self.cross_modal_transformer.resblocks:
                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

            # init cross attn
            nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std)
            nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std)

            self.mlm_head = nn.Sequential(
                OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)),
                             ('gelu', QuickGELU()),
                             ('ln', LayerNorm(self.embed_dim)),
                             ('fc', nn.Linear(self.embed_dim, args.vocab_size))]))
            # init mlm head
            nn.init.normal_(self.mlm_head.dense.weight, std=fc_std)
            nn.init.normal_(self.mlm_head.fc.weight, std=proj_std)

    def _set_task(self):
        # loss_names is a '+'-separated string (e.g. 'sdm+id+mlm'); each token
        # enables one of the objectives computed in forward().
        loss_names = self.args.loss_names
        self.current_task = [l.strip() for l in loss_names.split('+')]
        print(f'Training Model with {self.current_task} tasks')

    def cross_former(self, q, k, v):
        # Cross attention with text features as queries and image features as
        # keys/values, followed by the cross-modal transformer.
        x = self.cross_attn(
            self.ln_pre_t(q),
            self.ln_pre_i(k),
            self.ln_pre_i(v),
            need_weights=False)[0]
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.cross_modal_transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x)
        return x

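    # Inference-time encoders: encode_image keeps the first ([CLS]) token of the
    # visual transformer output, encode_text keeps the feature at the end-of-text
    # position (argmax over token ids in CLIP's tokenization). These pooled
    # embeddings are what retrieval is performed on.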
    def encode_image(self, image):
        x = self.base_model.encode_image(image)
        return x[:, 0, :].float()
        # return x.float() # for CLIP ResNet visual model

    def encode_text(self, text):
        x = self.base_model.encode_text(text)
        return x[torch.arange(x.shape[0]), text.argmax(dim=-1)].float()

    def forward(self, batch):
        ret = dict()

        images = batch['images']
        caption_ids = batch['caption_ids']
        image_feats, text_feats = self.base_model(images, caption_ids)
        i_feats = image_feats[:, 0, :].float()
        # i_feats = image_feats.float() # for CLIP ResNet visual model
        t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float()

        logit_scale = self.logit_scale
        ret.update({'temperature': 1 / logit_scale})

        if 'itc' in self.current_task:
            ret.update({'itc_loss': objectives.compute_itc(i_feats, t_feats, logit_scale)})

        if 'sdm' in self.current_task:
            ret.update({'sdm_loss': objectives.compute_sdm(i_feats, t_feats, batch['pids'], logit_scale)})

        if 'cmpm' in self.current_task:
            ret.update({'cmpm_loss': objectives.compute_cmpm(i_feats, t_feats, batch['pids'])})

        if 'id' in self.current_task:
            image_logits = self.classifier(i_feats.half()).float()
            text_logits = self.classifier(t_feats.half()).float()
            ret.update({'id_loss': objectives.compute_id(image_logits, text_logits, batch['pids']) * self.args.id_loss_weight})

            image_pred = torch.argmax(image_logits, dim=1)
            text_pred = torch.argmax(text_logits, dim=1)

            image_precision = (image_pred == batch['pids']).float().mean()
            text_precision = (text_pred == batch['pids']).float().mean()
            ret.update({'img_acc': image_precision})
            ret.update({'txt_acc': text_precision})

        if 'mlm' in self.current_task:
            mlm_ids = batch['mlm_ids']

            mlm_feats = self.base_model.encode_text(mlm_ids)

            x = self.cross_former(mlm_feats, image_feats, image_feats)

            x = self.mlm_head(x)  # [batch_size, text_len, vocab_size]

            scores = x.float().reshape(-1, self.args.vocab_size)
            mlm_labels = batch['mlm_labels'].reshape(-1)

            ret.update({'mlm_loss': objectives.compute_mlm(scores, mlm_labels) * self.args.mlm_loss_weight})

            pred = scores.max(1)[1]
            mlm_label_idx = torch.nonzero(mlm_labels)
            acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean()
            ret.update({'mlm_acc': acc})

        return ret


def build_model(args, num_classes=11003):
    model = IRRA(args, num_classes)
    # convert model to fp16
    convert_weights(model)
    return model
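

# Illustrative usage sketch (not part of the original training pipeline): the
# argument values below are assumptions listed only to show which fields this
# module reads from `args`; real values come from the project's config/parser.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        pretrain_choice='ViT-B/16',   # assumed CLIP checkpoint identifier
        img_size=(384, 128),          # assumed (height, width) of person crops
        stride_size=16,
        temperature=0.02,
        loss_names='sdm+id+mlm',
        cmt_depth=4,
        vocab_size=49408,             # CLIP BPE vocabulary size
        id_loss_weight=1.0,
        mlm_loss_weight=1.0,
    )
    model = build_model(args)  # loads the pretrained CLIP weights, then converts to fp16
    print(model.current_task)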