import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import timm


class ViTWrapper(nn.Module):
    """Wrapper to make ViT compatible with feature extraction for ImageTagger"""

    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model
        self.out_indices = (-1,)  # mimic timm features_only
        # Get patch size and embedding dim from the model
        self.patch_size = vit_model.patch_embed.patch_size[0]
        self.embed_dim = vit_model.embed_dim

    def forward(self, x):
        B = x.size(0)
        # → patch tokens
        x = self.vit.patch_embed(x)  # (B, N, C)
        # → prepend CLS
        cls_tok = self.vit.cls_token.expand(B, -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tok, x), dim=1)  # (B, 1+N, C)
        # → add positional encodings (full, incl. CLS)
        if self.vit.pos_embed is not None:
            x = x + self.vit.pos_embed[:, : x.size(1), :]
        x = self.vit.pos_drop(x)
        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)  # (B, 1+N, C)
        # → split back out
        cls_final = x[:, 0]  # (B, C)
        patch_tokens = x[:, 1:]  # (B, N, C)
        # → reshape patches to (B, C, H, W)
        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)  # square assumption
        patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)
        # Return both: (patch map, CLS)
        return patch_features, cls_final

    def set_grad_checkpointing(self, enable=True):
        """Enable gradient checkpointing if supported"""
        if hasattr(self.vit, 'set_grad_checkpointing'):
            self.vit.set_grad_checkpointing(enable)
            return True
        return False
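

# A minimal usage sketch of ViTWrapper (illustrative only; the model name,
# input shape, and the _demo_vit_wrapper helper itself are assumptions, not
# part of the original model code).
def _demo_vit_wrapper():
    vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
    wrapper = ViTWrapper(vit)
    x = torch.randn(2, 3, 224, 224)
    patch_map, cls_tok = wrapper(x)
    # With 16px patches and 224px input: 224 / 16 = 14 patches per side,
    # so patch_map is (2, 768, 14, 14) and cls_tok is (2, 768).
    print(patch_map.shape, cls_tok.shape)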


class ImageTagger(nn.Module):
    """
    ImageTagger with Vision Transformer backbone
    """

    def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
                 num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
                 use_gradient_checkpointing=False, img_size=224):
        super().__init__()
        # Store checkpointing config
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.model_name = model_name
        self.img_size = img_size
        self.pretrained = pretrained
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': True
        }
        # Core model config
        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.total_tags = total_tags
        print(f"Building ImageTagger with ViT backbone and {total_tags} tags")
        print(f"  Backbone: {model_name}")
        print(f"  Image size: {img_size}x{img_size}")
        print(f"  Tag context size: {tag_context_size}")
        print(f"  Gradient checkpointing: {use_gradient_checkpointing}")
        print("  Custom embeddings, PyTorch native attention, no ground-truth inclusion")
        # 1. Vision Transformer backbone
        print("Loading Vision Transformer backbone...")
        self._load_vit_backbone()
        # Get backbone dimensions by running a test forward pass
        self._determine_backbone_dimensions()
        self.embedding_dim = self.backbone.embed_dim
        # 2. Custom tag embeddings (no CLIP)
        print("Using custom tag embeddings (no CLIP)")
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        # 3. Shared-weights approach - tag bias for initial predictions
        print("Using shared weights between initial head and tag embeddings")
        self.tag_bias = nn.Parameter(torch.zeros(total_tags))
        # 4. Image token extraction (for attention AND global pooling)
        self.image_token_proj = nn.Identity()
        # 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True  # Use (batch, seq, feature) format
        )
        self.cross_norm = nn.LayerNorm(self.embedding_dim)
        # Initialize weights
        self._init_weights()
        # Enable gradient checkpointing
        if self.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()
        print("ImageTagger with ViT initialized!")
        self._print_parameter_count()

    def _load_vit_backbone(self):
        """Load Vision Transformer model from timm"""
        print(f"  Loading from timm: {self.model_name}")
        # Load the ViT model (not features_only; we want the full model for token extraction)
        vit_model = timm.create_model(
            self.model_name,
            pretrained=self.pretrained,
            img_size=self.img_size,
            num_classes=0  # Remove classification head
        )
        # Wrap it in our compatibility layer
        self.backbone = ViTWrapper(vit_model)
        print("  ViT loaded successfully")
        print(f"  Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
        print(f"  Embed dim: {self.backbone.embed_dim}")

    def _determine_backbone_dimensions(self):
        """Determine backbone output dimensions"""
        print("  Determining backbone dimensions...")
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
            # Create a dummy input
            dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
            # Get features
            backbone_features, cls_dummy = self.backbone(dummy_input)
            feature_tensor = backbone_features
            self.backbone_dim = feature_tensor.shape[1]
            self.feature_map_size = feature_tensor.shape[2]
        print(f"  Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
        print(f"  Total patch tokens: {self.feature_map_size * self.feature_map_size}")

    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency"""
        print("Enabling gradient checkpointing...")
        # Enable checkpointing for the ViT backbone
        if self.backbone.set_grad_checkpointing(True):
            print("  ViT backbone checkpointing enabled")
        else:
            print("  ViT backbone doesn't support built-in checkpointing; will checkpoint manually")

    def _checkpoint_backbone(self, x):
        """Wrapper for backbone with gradient checkpointing"""
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
        else:
            return self.backbone(x)

    def _checkpoint_image_proj(self, x):
        """Wrapper for image projection with gradient checkpointing"""
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
        else:
            return self.image_token_proj(x)

    def _checkpoint_cross_attention(self, query, key, value):
        """Wrapper for cross attention with gradient checkpointing"""
        def _attention_forward(q, k, v):
            attended_features, _ = self.cross_attention(query=q, key=k, value=v)
            return self.cross_norm(attended_features)
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
        else:
            return _attention_forward(query, key, value)
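
    # Note on the checkpointing wrappers above: with use_reentrant=False,
    # torch.utils.checkpoint.checkpoint discards the wrapped module's
    # intermediate activations in the forward pass and recomputes them during
    # backward, trading extra compute for lower peak memory. The wrappers fall
    # back to a plain call outside training because checkpointing only pays
    # off when gradients are needed.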

    def _checkpoint_candidate_selection(self, initial_logits):
        """Wrapper for candidate selection with gradient checkpointing"""
        def _candidate_forward(logits):
            return self._get_candidate_tags(logits)
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
        else:
            return _candidate_forward(initial_logits)

    def _checkpoint_final_scoring(self, attended_features, candidate_indices):
        """Wrapper for final scoring with gradient checkpointing"""
        def _scoring_forward(features, indices):
            emb = self.tag_embedding(indices)
            # BF16 in, BF16 out
            return (features * emb).sum(dim=-1)
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
        else:
            return _scoring_forward(attended_features, candidate_indices)
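
    # The scoring above is an elementwise multiply-and-sum, i.e. a per-tag dot
    # product: for features [B, K, C] and embeddings [B, K, C] it yields one
    # logit per candidate, equivalent to
    # torch.einsum('bkc,bkc->bk', features, emb).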

    def _init_weights(self):
        """Initialize weights for new modules"""
        def _init_layer(layer):
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Embedding):
                nn.init.normal_(layer.weight, mean=0, std=0.02)
        # Initialize new components
        self.image_token_proj.apply(_init_layer)
        # Initialize tag embeddings with normal distribution
        nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
        # Initialize tag bias
        nn.init.zeros_(self.tag_bias)

    def _print_parameter_count(self):
        """Print parameter statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        print("Parameter statistics:")
        print(f"  Total parameters: {total_params/1e6:.1f}M")
        print(f"  Trainable parameters: {trainable_params/1e6:.1f}M")
        print(f"  Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
        print(f"  Backbone parameters: {backbone_params/1e6:.1f}M")
        if self.use_gradient_checkpointing:
            print("  Gradient checkpointing enabled for memory efficiency")

    def debug(self):
        return self._flags['debug']

    def model_stats(self):
        return self._flags['model_stats']

    def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
        """Select candidate tags - no ground-truth inclusion"""
        # Simply select the top-K candidates based on initial predictions
        _, top_indices = torch.topk(
            torch.sigmoid(initial_logits),
            k=min(self.tag_context_size, self.total_tags),
            dim=1, largest=True, sorted=True
        )
        return top_indices
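
    # Illustrative example (assumed numbers): with total_tags = 10000 and
    # tag_context_size = 256, a [B, 10000] logit tensor yields top_indices of
    # shape [B, 256] -- the 256 most confident tags per image. Since sigmoid
    # is monotonic, the selection is identical to a top-k over the raw logits;
    # the sigmoid is applied only for readability of the scores.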

    def _analyze_predictions(self, predictions, tag_indices):
        """Analyze prediction patterns"""
        if not self.model_stats():
            return {}
        if torch._dynamo.is_compiling():
            return {}
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
            probs = torch.sigmoid(predictions)
            relevant_probs = torch.gather(probs, 1, tag_indices)
            return {
                'prediction_confidence': relevant_probs.mean().item(),
                'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
                'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
                'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
            }

    def forward(self, x, targets=None, hard_negatives=None):
        """
        Forward pass with ViT backbone, CLS token support and gradient checkpointing.
        All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
        FP32 otherwise). Anything that must mix dtypes is cast to match.
        """
        model_stats = {}
        # ------------------------------------------------------------------
        # 1. Backbone → patch map + CLS token
        # ------------------------------------------------------------------
        patch_map, cls_token = self._checkpoint_backbone(x)  # patch_map: [B, C, H, W]
                                                             # cls_token: [B, C]
        # ------------------------------------------------------------------
        # 2. Tokens → global image vector
        # ------------------------------------------------------------------
        image_tokens_4d = self._checkpoint_image_proj(patch_map)   # [B, C, H, W]
        image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)  # [B, N, C]
        # "Dual pool": average of mean-pooled patches and the CLS token
        global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token)  # [B, C]
        compute_dtype = global_features.dtype  # BF16 or FP32
        # ------------------------------------------------------------------
        # 3. Initial logits (shared weights)
        # ------------------------------------------------------------------
        tag_weights = self.tag_embedding.weight.to(compute_dtype)      # [T, C]
        tag_bias = self.tag_bias.to(compute_dtype)                     # [T]
        initial_logits = global_features @ tag_weights.t() + tag_bias  # [B, T]
        initial_preds = initial_logits  # alias
        # ------------------------------------------------------------------
        # 4. Candidate set
        # ------------------------------------------------------------------
        candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
        tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
        attended_features = self._checkpoint_cross_attention(                     # [B, K, C]
            tag_embeddings, image_tokens, image_tokens
        )
        # ------------------------------------------------------------------
        # 5. Score candidates & scatter back
        # ------------------------------------------------------------------
        candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)  # [B, K]
        # --- align dtypes so scatter never throws ---
        if candidate_logits.dtype != initial_logits.dtype:
            candidate_logits = candidate_logits.to(initial_logits.dtype)
        refined_logits = initial_logits.clone()
        refined_logits.scatter_(1, candidate_indices, candidate_logits)
        refined_preds = refined_logits
        # ------------------------------------------------------------------
        # 6. Optional stats
        # ------------------------------------------------------------------
        if self.model_stats() and targets is not None and not torch._dynamo.is_compiling():
            model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds,
                                                                                candidate_indices)
            model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds,
                                                                                candidate_indices)
        return {
            'initial_predictions': initial_preds,
            'refined_predictions': refined_preds,
            'selected_candidates': candidate_indices,
            'model_stats': model_stats
        }
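
    # Usage sketch (illustrative; the loss wiring below is an assumption about
    # training, not taken from this file): the refined logits line up with the
    # full tag vocabulary, so a multi-label loss can be applied directly, e.g.
    #   out = model(images)                      # images: [B, 3, H, W]
    #   loss = F.binary_cross_entropy_with_logits(
    #       out['refined_predictions'], targets  # targets: [B, total_tags]
    #   )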

    def predict