transformer.py

# -------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

from .builder import MODELS
from .misc import Result
from .utils import ResidualAttentionBlock

from ipdb import set_trace  # debugging helper; unused in this module
import clip
from transformers import AutoModel
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Transformer(nn.Module):

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

        # CLIP-style scaled initialization for the attention and MLP projections.
        proj_std = (self.width**-0.5) * ((2 * self.layers)**-0.5)
        attn_std = self.width**-0.5
        fc_std = (2 * self.width)**-0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor):
        for i, resblock in enumerate(self.resblocks):
            if self.use_checkpoint:
                # Trade compute for memory: recompute the block during the backward pass.
                x = checkpoint.checkpoint(resblock, x)
            else:
                x = resblock(x)
        return x
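

# Usage sketch (not part of the original file): the Transformer above consumes
# sequences in (seq_len, batch, width) layout, as in CLIP's text encoder. The
# sizes below (width=512, heads=8, layers=12, length=77) are illustrative
# assumptions only.
#
#   blocks = Transformer(width=512, layers=12, heads=8, attn_mask=None)
#   tokens = torch.randn(77, 4, 512)   # (L, N, D)
#   feats = blocks(tokens)             # same shape: (77, 4, 512)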


@MODELS.register_module()
class DistilBert(nn.Module):

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        pretrained=True,
        fixed=True,
    ):
        super().__init__()
        self.transformer = AutoModel.from_pretrained('distilbert-base-uncased', output_hidden_states=True)
        self.transformer.train()
        self.width = width

        if fixed is True:
            # Freeze the pretrained encoder.
            for p in self.transformer.parameters():
                p.requires_grad = False

        if pretrained is False:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, as_dict=True):
        outs = Result(as_dict=as_dict)
        out_x = self.transformer(**x)
        out_hidden = out_x.last_hidden_state[:, 0, :]  # [CLS] token embedding
        last_hidden = out_x.hidden_states[-1]          # full last-layer token sequence
        outs.append(out_hidden, name='x')
        outs.append(last_hidden, name='all_tokens')
        return outs.as_return()
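

# Usage sketch (illustrative, assuming the matching HuggingFace tokenizer):
# `forward` unpacks its argument with **x, so it expects the dict a tokenizer
# returns (input_ids, attention_mask, ...). The same pattern applies to the
# Bert, Roberta and BertMedium wrappers below.
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('distilbert-base-uncased')
#   text_encoder = DistilBert(context_length=77, width=768, layers=6, vocab_size=30522)
#   batch = tok(["a photo of a dog"], padding=True, return_tensors='pt')
#   out = text_encoder(dict(batch))   # out['x']: (1, 768) [CLS] feature
#                                     # out['all_tokens']: (1, L, 768) token features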


@MODELS.register_module()
class Bert(nn.Module):

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        pretrained=True,
        fixed=True,
    ):
        super().__init__()
        self.transformer = AutoModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        self.transformer.train()
        self.width = width

        if fixed is True:
            for p in self.transformer.parameters():
                p.requires_grad = False

        if pretrained is False:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, as_dict=True):
        outs = Result(as_dict=as_dict)
        out_x = self.transformer(**x)
        out_hidden = out_x.last_hidden_state[:, 0, :]
        last_hidden = out_x.hidden_states[-1]
        outs.append(out_hidden, name='x')
        outs.append(last_hidden, name='all_tokens')
        return outs.as_return()


@MODELS.register_module()
class Roberta(nn.Module):

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        pretrained=True,
        fixed=True,
    ):
        super().__init__()
        # Note: cache_dir points to a machine-specific checkpoint directory.
        self.transformer = AutoModel.from_pretrained('roberta-base', output_hidden_states=True, cache_dir='/mnt/petrelfs/xujilan/checkpoints/')
        self.transformer.train()
        self.width = width

        if fixed is True:
            for p in self.transformer.parameters():
                p.requires_grad = False

        if pretrained is False:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, question=None, as_dict=True):
        # `question` is accepted for interface compatibility but not used here.
        outs = Result(as_dict=as_dict)
        out_x = self.transformer(**x)
        out_hidden = out_x.last_hidden_state[:, 0, :]
        last_hidden = out_x.hidden_states[-1]
        outs.append(out_hidden, name='x')
        outs.append(last_hidden, name='all_tokens')
        return outs.as_return()


@MODELS.register_module()
class BertMedium(nn.Module):

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        pretrained=True,
        fixed=True,
    ):
        super().__init__()
        self.transformer = AutoModel.from_pretrained('prajjwal1/bert-medium', output_hidden_states=True)
        self.transformer.train()
        self.width = width

        if fixed is True:
            for p in self.transformer.parameters():
                p.requires_grad = False

        if pretrained is False:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, as_dict=True):
        outs = Result(as_dict=as_dict)
        out_x = self.transformer(**x)
        out_hidden = out_x.last_hidden_state[:, 0, :]
        last_hidden = out_x.hidden_states[-1]
        outs.append(out_hidden, name='x')
        outs.append(last_hidden, name='all_tokens')
        return outs.as_return()


@MODELS.register_module()
class TextTransformer(nn.Module):

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        pretrained=True,
        fixed=True,
    ):
        super().__init__()
        heads = width // 64
        self.context_length = context_length
        self.width = width
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            attn_mask=self.build_attention_mask(),
            use_checkpoint=use_checkpoint)

        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.ln_final = nn.LayerNorm(width)

        self.token_embedding = nn.Embedding(vocab_size, width)
        nn.init.normal_(self.token_embedding.weight, std=0.02)

        # Borrow the text projection shape (and, optionally, the weights) from CLIP ViT-B/16.
        clip_model, _ = clip.load('ViT-B/16', device='cuda', jit=False)
        self.text_projection = nn.Parameter(torch.empty(clip_model.text_projection.shape))
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

        # initialization
        nn.init.normal_(self.positional_embedding, std=0.01)

        if pretrained:
            print('loading clip weights for text encoder')
            self.reload_clip_weights(clip_model)
        if fixed:
            print('freezing text encoder')
            self.freeze_text_encoder()

    def freeze_text_encoder(self):
        for p in self.parameters():
            p.requires_grad = False

    def reload_clip_weights(self, clip_model):
        text_dict = clip_model.state_dict()
        msg = self.load_state_dict(text_dict, strict=False)

    def build_attention_mask(self):
        # lazily create the causal attention mask; pytorch uses an additive
        # attention mask, so fill the upper triangle with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text, *, as_dict=True):
        x = self.token_embedding(text)
        outs = Result(as_dict=as_dict)

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        ### w/o text projection ###
        # all_tokens = x.clone()
        # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        ### w/ text projection ###
        all_tokens = x.clone() @ self.text_projection
        # Pool at the EOT token (the highest token id in the CLIP vocabulary).
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        outs.append(x, name='x')
        outs.append(all_tokens, name='all_tokens')
        return outs.as_return()
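

# Usage sketch (not part of the original file): TextTransformer mirrors CLIP's text
# encoder, so it pairs with clip.tokenize. ViT-B/16 text-encoder sizes
# (context_length=77, width=512, layers=12, vocab_size=49408) are assumed here purely
# for illustration; a CUDA device is required because __init__ loads CLIP onto 'cuda'.
#
#   text_encoder = TextTransformer(context_length=77, width=512, layers=12,
#                                  vocab_size=49408, pretrained=True, fixed=True).cuda()
#   texts = clip.tokenize(["a photo of a dog", "a photo of a cat"]).cuda()   # (2, 77)
#   out = text_encoder(texts)
#   # out['x']: (2, 512) pooled EOT features; out['all_tokens']: (2, 77, 512)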