utils.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # -------------------------------------------------------------------------
  2. # MIT License
  3. #
  4. # Copyright (c) 2021 OpenAI
  5. #
  6. # Permission is hereby granted, free of charge, to any person obtaining a copy
  7. # of this software and associated documentation files (the "Software"), to deal
  8. # in the Software without restriction, including without limitation the rights
  9. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  10. # copies of the Software, and to permit persons to whom the Software is
  11. # furnished to do so, subject to the following conditions:
  12. #
  13. # The above copyright notice and this permission notice shall be included in all
  14. # copies or substantial portions of the Software.
  15. #
  16. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  17. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  18. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  19. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  20. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  21. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. # SOFTWARE.
  23. #
  24. # Modified by Jiarui Xu
  25. # -------------------------------------------------------------------------
  26. from collections import OrderedDict
  27. import torch
  28. from torch import nn
  29. class QuickGELU(nn.Module):
  30. def forward(self, x: torch.Tensor):
  31. return x * torch.sigmoid(1.702 * x)
  32. class ResidualAttentionBlock(nn.Module):
  33. def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
  34. super().__init__()
  35. self.attn = nn.MultiheadAttention(d_model, n_head)
  36. self.ln_1 = nn.LayerNorm(d_model)
  37. self.mlp = nn.Sequential(OrderedDict([
  38. ('c_fc', nn.Linear(d_model, d_model * 4)),
  39. ('gelu', QuickGELU()),
  40. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  41. self.ln_2 = nn.LayerNorm(d_model)
  42. self.attn_mask = attn_mask
  43. def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor):
  44. self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
  45. return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
  46. def forward(self, x: torch.Tensor, key_padding_mask=None):
  47. x = x + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
  48. x = x + self.mlp(self.ln_2(x))
  49. return x