import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

# Breed label names. Nothing in this file indexes into this list; presumably it
# is the index -> label mapping for the classifier logits used elsewhere
# (verify against the training/inference code before reordering or editing).
dog_breeds = [
    "Afghan_Hound", "African_Hunting_Dog", "Airedale",
    "American_Staffordshire_Terrier", "Appenzeller", "Australian_Terrier",
    "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
    "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull",
    "Bouvier_Des_Flandres", "Brabancon_Griffon", "Brittany_Spaniel",
    "Cardigan", "Chesapeake_Bay_Retriever", "Chihuahua", "Dachshund",
    "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
    "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog",
    "German_Shepherd", "German_Short-Haired_Pointer", "Gordon_Setter",
    "Great_Dane", "Great_Pyrenees", "Greater_Swiss_Mountain_Dog",
    "Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
    "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound",
    "Japanese_Spaniel", "Kerry_Blue_Terrier", "Labrador_Retriever",
    "Lakeland_Terrier", "Leonberg", "Lhasa", "Maltese_Dog",
    "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier",
    "Norwegian_Elkhound", "Norwich_Terrier", "Old_English_Sheepdog",
    "Pekinese", "Pembroke", "Pomeranian", "Rhodesian_Ridgeback",
    "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed", "Scotch_Terrier",
    "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog",
    "Shiba_Inu", "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier",
    "Sussex_Spaniel", "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound",
    "Weimaraner", "Welsh_Springer_Spaniel", "West_Highland_White_Terrier",
    "Yorkshire_Terrier", "Affenpinscher", "Basenji", "Basset", "Beagle",
    "Black-and-Tan_Coonhound", "Bloodhound", "Bluetick", "Borzoi", "Boxer",
    "Briard", "Bull_Mastiff", "Cairn", "Chow", "Clumber", "Cocker_Spaniel",
    "Collie", "Curly-Coated_Retriever", "Dhole", "Dingo",
    "Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever",
    "Groenendael", "Keeshond", "Kelpie", "Komondor", "Kuvasz", "Malamute",
    "Malinois", "Miniature_Pinscher", "Miniature_Poodle",
    "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
    "Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier",
    "Standard_Poodle", "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier",
    "Vizsla", "Whippet", "Wire-Haired_Fox_Terrier",
]


def _conv_bn_relu(in_channels, out_channels, kernel_size):
    """Return ``[Conv2d, BatchNorm2d, ReLU]`` with same-size padding for odd kernels.

    The layers are returned as a plain list so several stages can be splatted
    into one ``nn.Sequential``. This keeps the flat child indices (0, 1, 2, 3, ...)
    and therefore the state_dict keys identical to a hand-written Sequential.
    """
    return [
        nn.Conv2d(in_channels, out_channels,
                  kernel_size=kernel_size, padding=kernel_size // 2),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


class MorphologicalFeatureExtractor(nn.Module):
    """Refine a pooled feature vector through several conv branches, with a residual.

    The input vector is projected by a Linear layer and reshaped into a small
    2-D map. Note that this map is a learned reshaping of a pooled vector, not
    image-space geometry. Each branch in ``morphological_analyzers`` processes
    the map and is average-pooled to a vector. The branch vectors, plus a
    "relation" vector computed from all of them, are projected back to
    ``in_features`` and added to the input.
    """

    _SPATIAL_CHANNELS = 64   # channels of the reshaped 2-D map
    _BRANCH_CHANNELS = 128   # output channels of every analyzer branch

    def __init__(self, in_features):
        """
        Args:
            in_features (int): Width of the input vector; also the output width.
        """
        super().__init__()
        c_in, c_out = self._SPATIAL_CHANNELS, self._BRANCH_CHANNELS

        # Base dimensions. spatial_size is 7 unless in_features >= 16384.
        self.reduced_dim = in_features // 4
        self.spatial_size = max(7, int(np.sqrt(self.reduced_dim // 64)))
        spatial_numel = self.spatial_size * self.spatial_size * c_in

        # 1. Map the 1-D feature vector to a (c_in, S, S) spatial representation.
        self.dimension_transformer = nn.Sequential(
            nn.Linear(in_features, spatial_numel),
            nn.LayerNorm(spatial_numel),
            nn.ReLU()
        )

        # 2. Analyzer branches. Names and kernel sizes reflect design intent
        #    (body / head / tail / fur / color); nothing in the code constrains
        #    what each branch actually learns.
        self.morphological_analyzers = nn.ModuleDict({
            # Large 7x7 kernel for overall shape, then a 3x3 refinement.
            'body_proportion': nn.Sequential(
                *_conv_bn_relu(c_in, c_out, 7),
                *_conv_bn_relu(c_out, c_out, 3),
            ),
            # Medium 5x5 kernel, then 3x3 for detail.
            'head_features': nn.Sequential(
                *_conv_bn_relu(c_in, c_out, 5),
                *_conv_bn_relu(c_out, c_out, 3),
            ),
            'tail_features': nn.Sequential(
                *_conv_bn_relu(c_in, c_out, 5),
                *_conv_bn_relu(c_out, c_out, 3),
            ),
            # Stacked small 3x3 kernels (intended for texture).
            'fur_features': nn.Sequential(
                *_conv_bn_relu(c_in, c_out, 3),
                *_conv_bn_relu(c_out, c_out, 3),
            ),
            # Two 3x3 stages, then a 1x1 conv mixing channels.
            'color_pattern': nn.Sequential(
                *_conv_bn_relu(c_in, c_out, 3),
                *_conv_bn_relu(c_out, c_out, 3),
                *_conv_bn_relu(c_out, c_out, 1),
            ),
        })
        num_branches = len(self.morphological_analyzers)  # 5

        # 3. NOTE: registered but intentionally NOT called in forward(). The
        #    original forward computed its output and discarded it (so it also
        #    never received gradient). It is kept so existing state_dicts still
        #    load with strict=True. Wiring it in would change the outputs of
        #    trained models.
        self.feature_attention = nn.MultiheadAttention(
            embed_dim=c_out,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # 4. Relation analyzer over the concatenation of all branch outputs.
        self.relation_analyzer = nn.Sequential(
            nn.Linear(c_out * num_branches, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, c_out),
            nn.LayerNorm(c_out),
            nn.ReLU()
        )

        # 5. Integrator: branch outputs + 1 relation vector -> in_features.
        self.feature_integrator = nn.Sequential(
            nn.Linear(c_out * (num_branches + 1), in_features),
            nn.LayerNorm(in_features),
            nn.ReLU()
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, in_features).

        Returns:
            Tensor: Shape (B, in_features); integrated features plus the input
            (residual).
        """
        batch_size = x.size(0)

        # 1. Vector -> (B, 64, S, S) spatial map.
        spatial_features = self.dimension_transformer(x).view(
            batch_size, self._SPATIAL_CHANNELS, self.spatial_size, self.spatial_size
        )

        # 2. Run every branch, global-average-pool to (B, 128), concatenate in
        #    ModuleDict order.
        combined_features = torch.cat(
            [
                torch.flatten(F.adaptive_avg_pool2d(analyzer(spatial_features), (1, 1)), 1)
                for analyzer in self.morphological_analyzers.values()
            ],
            dim=1,
        )

        # 3. Cross-branch relation vector.
        relation_features = self.relation_analyzer(combined_features)

        # 4. Integrate branch + relation features and add the residual.
        integrated_features = self.feature_integrator(
            torch.cat([combined_features, relation_features], dim=1)
        )
        return integrated_features + x


class MultiHeadAttention(nn.Module):
    """Self-attention among the head-chunks of a single feature vector.

    The input is 2-D (N, D), so there is no token/sequence axis. After projecting
    to ``num_heads * head_dim``, each vector is split into ``num_heads`` chunks,
    and the chunks attend to one another. The per-sample attention matrix is
    (num_heads, num_heads).
    """

    def __init__(self, in_dim, num_heads=8, legacy_attention=False):
        """
        Initializes the MultiHeadAttention module.

        Args:
            in_dim (int): Dimension of the input features.
            num_heads (int): Number of attention heads. Defaults to 8.
            legacy_attention (bool): Reproduce the original, buggy aggregation.
                There the attention weights were summed out, so every head
                received the sum of all value heads. Only needed to run
                checkpoints trained with that code. Defaults to False.
        """
        super().__init__()
        self.num_heads = num_heads
        # Floor division may drop a remainder; fc_in/fc_out bridge in_dim and
        # scaled_dim, so in_dim need not be divisible by num_heads.
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads
        self.legacy_attention = legacy_attention

        self.fc_in = nn.Linear(in_dim, self.scaled_dim)
        self.query = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.key = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.value = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.fc_out = nn.Linear(self.scaled_dim, in_dim)

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (N, D) with D == in_dim.

        Returns:
            Tensor: Shape (N, in_dim).
        """
        n = x.shape[0]
        x = self.fc_in(x)
        q = self.query(x).view(n, self.num_heads, self.head_dim)
        k = self.key(x).view(n, self.num_heads, self.head_dim)
        v = self.value(x).view(n, self.num_heads, self.head_dim)

        # Scaled dot-product scores between head-chunks: (N, H, H).
        energy = torch.einsum("nqd,nkd->nqk", q, k)
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)

        if self.legacy_attention:
            # Original equation: 'k' and 'v' are distinct labels, so both are
            # summed out. The result is sum_k(attention) * sum_v(v), which
            # equals the sum of the value heads for every query.
            out = torch.einsum("nqk,nvd->nqd", attention, v)
        else:
            # Weight each query's values by its attention row (summed axis is k).
            out = torch.einsum("nqk,nkd->nqd", attention, v)

        out = out.reshape(n, self.scaled_dim)  # concatenate heads
        return self.fc_out(out)


class BaseModel(nn.Module):
    """Dog-breed classifier built on a pretrained ConvNeXt V2 backbone.

    Pipeline: backbone features -> MorphologicalFeatureExtractor ->
    fusion of [features, morphological, features * morphological] ->
    MultiHeadAttention -> LayerNorm/Dropout/Linear classifier.

    Checkpoints trained before the MultiHeadAttention einsum fix must be
    constructed with ``legacy_attention=True`` to reproduce their outputs.
    """

    def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu',
                 legacy_attention=False):
        """
        Args:
            num_classes (int): Number of output logits.
            device (str): Device that forward() moves its input to. The module
                itself is not moved here; the caller is expected to call
                ``.to(device)``.
            legacy_attention (bool): Forwarded to MultiHeadAttention (see there).
        """
        super().__init__()
        self.device = device

        # 1. Pretrained backbone with the classification head removed.
        self.backbone = timm.create_model(
            'convnextv2_base',
            pretrained=True,
            num_classes=0
        )

        # 2. Probe with a dummy image to find the backbone's output width. Run
        #    in eval mode so any running statistics are not updated by noise,
        #    then restore the previous mode.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self._pool_features(self.backbone(dummy_input))
        self.backbone.train(was_training)
        self.feature_dim = features.shape[1]
        print(f"Feature Dimension from V2 backbone: {self.feature_dim}")

        # 3. Attention over the fused vector (1..8 heads, ~64 dims per head).
        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(
            self.feature_dim, num_heads=self.num_heads,
            legacy_attention=legacy_attention,
        )

        # 4. Classifier head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes)
        )

        self.morphological_extractor = MorphologicalFeatureExtractor(
            in_features=self.feature_dim
        )

        # Fuses the three concatenated feature_dim-wide vectors built in forward().
        self.feature_fusion = nn.Sequential(
            nn.Linear(self.feature_dim * 3, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU(),
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU()
        )

    @staticmethod
    def _pool_features(features):
        """Average over the last two dims if the tensor has more than two; else pass through."""
        return features.mean([-2, -1]) if features.dim() > 2 else features

    def forward(self, x):
        """
        Args:
            x (Tensor): Image batch of shape [batch_size, channels, height, width].

        Returns:
            Tuple[Tensor, Tensor]: Classification logits and the attended
            feature vectors fed to the classifier.
        """
        x = x.to(self.device)

        # 1. Backbone features, pooled to (B, feature_dim) if needed.
        features = self._pool_features(self.backbone(x))

        # 2. Morphological refinement (same width, residual inside).
        morphological_features = self.morphological_extractor(features)

        # 3. Fuse original, morphological, and elementwise-interaction features.
        combined_features = torch.cat([
            features,
            morphological_features,
            features * morphological_features
        ], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # 4. Attention, then 5. classifier.
        attended_features = self.attention(fused_features)
        logits = self.classifier(attended_features)
        return logits, attended_features