# clip_model.py

  1. """ CLIP Model
  2. Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
  3. """
  4. from collections import OrderedDict
  5. import logging
  6. import math
  7. import os
  8. from typing import List, Tuple, Union
  9. import hashlib
  10. import urllib
  11. from tqdm import tqdm
  12. import warnings
  13. import numpy as np
  14. import torch
  15. import torch.nn.functional as F
  16. from torch import nn
  17. logger = logging.getLogger("IRRA.model")
  18. _MODELS = {
  19. "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
  20. "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
  21. "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
  22. "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
  23. "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
  24. "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
  25. "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
  26. "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
  27. }
  28. def available_models() -> List[str]:
  29. """Returns the names of available CLIP models"""
  30. return list(_MODELS.keys())
def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: Tuple[int, int], embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.positional_embedding = nn.Parameter(torch.randn((spacial_dim[0] * spacial_dim[1]) + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]

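# Shape sketch (commented out, illustrative only; the RN50-style numbers below are
# assumptions, not values fixed by this class). For a 384x128 Re-ID input the layer4
# feature map is [N, 2048, 12, 4]; the pool flattens it to (HW+1) x N x 2048 with a
# prepended mean token and returns that token, attended and projected to output_dim:
#   pool = AttentionPool2d(spacial_dim=(12, 4), embed_dim=2048, num_heads=32, output_dim=1024)
#   feat = pool(torch.randn(2, 2048, 12, 4))  # -> [2, 1024]
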
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=(224, 224), width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        spacial_dim = (
            input_resolution[0] // 32,
            input_resolution[1] // 32,
        )
        self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x

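# Usage sketch (commented out; assumes an RN50-style configuration and a 384x128
# Re-ID input, neither of which is mandated by this file):
#   visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
#                           input_resolution=(384, 128), width=64)
#   feat = visual(torch.randn(2, 3, 384, 128))  # -> [2, 1024] via the attention pool
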
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: Tuple[int, int], patch_size: int, stride_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution  # e.g. (384, 128)
        self.num_x = (input_resolution[1] - patch_size) // stride_size + 1
        self.num_y = (input_resolution[0] - patch_size) // stride_size + 1
        num_patches = self.num_x * self.num_y

        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)

        scale = width ** -0.5  # e.g. 1 / sqrt(768)
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid_h, grid_w]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid_h * grid_w]
        x = x.permute(0, 2, 1)  # shape = [*, grid_h * grid_w, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid_h * grid_w + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        return x

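# Note: unlike the original CLIP ViT (which returns only the CLS token), ln_post and
# proj are applied to every token here, so the sequence dimension is kept.
# Sketch (commented out; assumes the ViT-B/16 config with a 384x128 input and stride 16):
#   vit = VisionTransformer(input_resolution=(384, 128), patch_size=16, stride_size=16,
#                           width=768, layers=12, heads=12, output_dim=512)
#   feats = vit(torch.randn(2, 3, 384, 128))  # -> [2, 193, 512]  (1 CLS + 24*8 patches)
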
class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: Union[int, Tuple[int, int]],
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 stride_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                stride_size=stride_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

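    # Example (for context_length = 4) of the additive causal mask built above; each
    # text token can attend only to itself and to earlier positions:
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]
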
    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        x = x @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # # normalized features
        # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # # cosine similarity as logits
        # logit_scale = self.logit_scale.exp()
        # logits_per_image = logit_scale * image_features @ text_features.t()
        # logits_per_text = logits_per_image.t()
        # # shape = [global_batch_size, global_batch_size]
        # return logits_per_image, logits_per_text

        return image_features, text_features

    def load_param(self, state_dict):
        # filter out keys in the pretrained state_dict that are not present in this model
        param_dict = {k: v for k, v in state_dict.items() if k in self.state_dict()}

        if 'model' in param_dict:
            param_dict = param_dict['model']
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']

        for k, v in param_dict.items():
            if k == 'visual.positional_embedding' and v.shape != self.visual.positional_embedding.shape:
                v = resize_pos_embed(v, self.visual.positional_embedding, self.visual.num_y, self.visual.num_x)
            elif k == 'positional_embedding' and v.shape != self.positional_embedding.shape:
                # note: resize_text_pos_embed is expected to be provided elsewhere; it is not defined in this file
                v = resize_text_pos_embed(v, self.context_length)
            try:
                self.state_dict()[k].copy_(v)
            except Exception:
                print(f'===========================ERROR occurred while copying {k}, {v.shape}=========================')
                print('shapes do not match for key {}: param_dict {} vs self.state_dict() {}'.format(k, v.shape, self.state_dict()[k].shape))


def resize_pos_embed(posemb, posemb_new, height, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb = posemb.unsqueeze(0)
    posemb_new = posemb_new.unsqueeze(0)

    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]

    gs_old = int(math.sqrt(len(posemb_grid)))
    print('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(posemb.shape, posemb_new.shape, height, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
    return posemb.squeeze(0)

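# Worked example (commented out): a ViT-B/16 checkpoint trained at 224x224 stores a
# [197, 768] visual positional embedding (1 CLS + 14*14 patches). For a 384x128 input
# with stride 16 the model needs [193, 768] (1 CLS + 24*8 patches), so the 14x14 grid
# is bilinearly interpolated to 24x8 while the CLS entry is kept unchanged:
#   new = resize_pos_embed(torch.randn(197, 768), torch.randn(193, 768), height=24, width=8)
#   new.shape  # -> torch.Size([193, 768])
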
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj", "mcq_proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)

def build_CLIP_from_openai_pretrained(name: str, image_size: Union[int, Tuple[int, int]], stride_size: int, jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    image_size : Union[int, Tuple[int, int]]
        Input image size; for the Re-ID task this is commonly set to 384x128 instead of 224x224
    stride_size : int
        Stride of the patch-embedding convolution (a stride smaller than the patch size yields overlapping patches)
    jit : bool
        Whether to load the optimized JIT model or the more hackable non-JIT model (default)
    download_root : str
        Path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu")
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    state_dict = state_dict or model.state_dict()

    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model_cfg = {
        'embed_dim': embed_dim,
        'image_resolution': image_resolution,
        'vision_layers': vision_layers,
        'vision_width': vision_width,
        'vision_patch_size': vision_patch_size,
        'context_length': context_length,
        'vocab_size': vocab_size,
        'transformer_width': transformer_width,
        'transformer_heads': transformer_heads,
        'transformer_layers': transformer_layers
    }

    # modify the image resolution to adapt to the Re-ID task
    model_cfg['image_resolution'] = image_size
    model_cfg['stride_size'] = stride_size
    logger.info(f"Load pretrained {name} CLIP model with model config: {model_cfg}")
    model = CLIP(**model_cfg)

    # convert the model to fp16
    # convert_weights(model)

    # resize the modified positional embeddings while loading the pretrained weights
    model.load_param(state_dict)
    return model, model_cfg

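
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training code. It assumes network
    # access (or a checkpoint cached under ~/.cache/clip) and the 384x128 / stride-16
    # setting commonly used for text-to-image person Re-ID; adjust to your setup.
    model, cfg = build_CLIP_from_openai_pretrained("ViT-B/16", image_size=(384, 128), stride_size=16)
    model.eval()
    with torch.no_grad():
        images = torch.randn(2, 3, 384, 128)  # dummy image batch
        tokens = torch.randint(0, cfg['vocab_size'], (2, cfg['context_length']))  # dummy token ids
        image_feats, text_feats = model(images, tokens)
    print(image_feats.shape, text_feats.shape)  # e.g. [2, 193, 512] and [2, 77, 512] for ViT-B/16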