# PawMatchAI / model_architecture.py
# Hugging Face file-page residue, kept for provenance:
#   "DawnC's picture" / "Upload 3 files" / commit da003cc
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
# Breed labels known to this module (124 entries).
# NOTE(review): presumably index-aligned with the classifier's output units,
# so the order must not change -- confirm against the training code.
dog_breeds = [
    "Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
    "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog",
    "Bichon_Frise", "Blenheim_Spaniel", "Border_Collie", "Border_Terrier",
    "Boston_Bull", "Bouvier_Des_Flandres", "Brabancon_Griffon", "Brittany_Spaniel",
    "Cardigan", "Chesapeake_Bay_Retriever", "Chihuahua", "Dachshund",
    "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
    "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog",
    "German_Shepherd", "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane",
    "Great_Pyrenees", "Greater_Swiss_Mountain_Dog", "Havanese", "Ibizan_Hound",
    "Irish_Setter", "Irish_Terrier", "Irish_Water_Spaniel", "Irish_Wolfhound",
    "Italian_Greyhound", "Japanese_Spaniel", "Kerry_Blue_Terrier", "Labrador_Retriever",
    "Lakeland_Terrier", "Leonberg", "Lhasa", "Maltese_Dog",
    "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
    "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke",
    "Pomeranian", "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard",
    "Saluki", "Samoyed", "Scotch_Terrier", "Scottish_Deerhound",
    "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu", "Shih-Tzu",
    "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel", "Tibetan_Mastiff",
    "Tibetan_Terrier", "Walker_Hound", "Weimaraner", "Welsh_Springer_Spaniel",
    "West_Highland_White_Terrier", "Yorkshire_Terrier", "Affenpinscher", "Basenji",
    "Basset", "Beagle", "Black-and-Tan_Coonhound", "Bloodhound",
    "Bluetick", "Borzoi", "Boxer", "Briard",
    "Bull_Mastiff", "Cairn", "Chow", "Clumber",
    "Cocker_Spaniel", "Collie", "Curly-Coated_Retriever", "Dhole",
    "Dingo", "Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever",
    "Groenendael", "Keeshond", "Kelpie", "Komondor",
    "Kuvasz", "Malamute", "Malinois", "Miniature_Pinscher",
    "Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon",
    "Pug", "Redbone", "Schipperke", "Silky_Terrier",
    "Soft-Coated_Wheaten_Terrier", "Standard_Poodle", "Standard_Schnauzer", "Toy_Poodle",
    "Toy_Terrier", "Vizsla", "Whippet", "Wire-Haired_Fox_Terrier",
]
class MorphologicalFeatureExtractor(nn.Module):
    """Refine a flat feature vector through five parallel convolutional branches.

    Pipeline implemented by ``forward``:

    1. Project the ``(batch, in_features)`` input to a
       ``(batch, 64, spatial_size, spatial_size)`` map.
    2. Run five convolutional "analyzer" branches on that map and
       global-average-pool each one to a 128-d vector.
    3. Derive a 128-d "relation" vector from the concatenated branch vectors.
    4. Project the ``6 * 128`` concatenation back to ``in_features`` and add
       the input as a residual.

    The branch names (``body_proportion``, ``head_features``, ...) are the
    author's labels for the trait each branch is meant to capture; in code the
    branches differ only in kernel sizes and depth.

    ``feature_attention`` is registered but intentionally not evaluated in
    ``forward``: the original code computed it and discarded the result. The
    module is kept so the ``state_dict`` layout stays unchanged.

    Args:
        in_features (int): Width of the input (and output) feature vector.
    """

    # Channel width of the spatial map fed to every analyzer branch.
    _SPATIAL_CHANNELS = 64
    # Channel width produced by every analyzer branch.
    _BRANCH_CHANNELS = 128
    # Kernel sizes per branch, in registration order (order fixes both the
    # state_dict layout and the concatenation order in ``forward``).
    _ANALYZER_KERNELS = (
        ('body_proportion', (7, 3)),   # large kernel first, then a refinement
        ('head_features', (5, 3)),     # medium kernel, then fine detail
        ('tail_features', (5, 3)),
        ('fur_features', (3, 3)),      # small kernels only
        ('color_pattern', (3, 3, 1)),  # two 3x3 layers plus a 1x1 mixing layer
    )

    def __init__(self, in_features):
        super().__init__()
        branch = self._BRANCH_CHANNELS
        n_branches = len(self._ANALYZER_KERNELS)

        # Base dimension settings.
        self.reduced_dim = in_features // 4
        self.spatial_size = max(7, int(np.sqrt(self.reduced_dim // 64)))
        spatial_dim = self.spatial_size * self.spatial_size * self._SPATIAL_CHANNELS

        # 1. Feature-space transformer: flat vector -> flattened 2-D map.
        self.dimension_transformer = nn.Sequential(
            nn.Linear(in_features, spatial_dim),
            nn.LayerNorm(spatial_dim),
            nn.ReLU()
        )

        # 2. Analyzer branches (Conv -> BatchNorm -> ReLU stacks).
        self.morphological_analyzers = nn.ModuleDict({
            name: self._conv_stack(kernels)
            for name, kernels in self._ANALYZER_KERNELS
        })

        # 3. Attention over branch vectors. Not used by ``forward``; retained
        #    only so existing checkpoints keep loading with strict key matching.
        self.feature_attention = nn.MultiheadAttention(
            embed_dim=branch,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # 4. Relation analyzer: consumes the outputs of all five branches.
        self.relation_analyzer = nn.Sequential(
            nn.Linear(branch * n_branches, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, branch),
            nn.LayerNorm(branch),
            nn.ReLU()
        )

        # 5. Integrator: five branch vectors + one relation vector.
        self.feature_integrator = nn.Sequential(
            nn.Linear(branch * (n_branches + 1), in_features),
            nn.LayerNorm(in_features),
            nn.ReLU()
        )

    @classmethod
    def _conv_stack(cls, kernel_sizes):
        """Build a flat Conv2d/BatchNorm2d/ReLU stack, one triple per kernel size.

        Padding is ``k // 2`` so odd kernels preserve the spatial size.
        """
        layers = []
        channels = cls._SPATIAL_CHANNELS
        for k in kernel_sizes:
            layers.append(nn.Conv2d(channels, cls._BRANCH_CHANNELS, kernel_size=k, padding=k // 2))
            layers.append(nn.BatchNorm2d(cls._BRANCH_CHANNELS))
            layers.append(nn.ReLU())
            channels = cls._BRANCH_CHANNELS
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return refined features with the same shape as ``x``.

        Args:
            x (Tensor): Input of shape ``(batch, in_features)``.

        Returns:
            Tensor: ``feature_integrator(...) + x``, shape ``(batch, in_features)``.
        """
        batch_size = x.size(0)

        # 1. Reshape the projected vector into a spatial map.
        spatial_features = self.dimension_transformer(x).view(
            batch_size, self._SPATIAL_CHANNELS, self.spatial_size, self.spatial_size
        )

        # 2. One pooled 128-d vector per branch, in registration order.
        branch_vectors = [
            F.adaptive_avg_pool2d(analyzer(spatial_features), (1, 1)).flatten(1)
            for analyzer in self.morphological_analyzers.values()
        ]

        # 3. Relation vector from the concatenated branch vectors.
        combined_features = torch.cat(branch_vectors, dim=1)
        relation_features = self.relation_analyzer(combined_features)

        # 4. Integrate branch + relation features and add the residual.
        integrated_features = self.feature_integrator(
            torch.cat([combined_features, relation_features], dim=1)
        )
        return integrated_features + x
class MultiHeadAttention(nn.Module):
    """Self-attention across the heads of a single feature vector.

    The input ``(N, in_dim)`` is projected to ``num_heads * head_dim`` and
    split into ``num_heads`` chunks. Those chunks play the role of sequence
    positions: every head attends over all heads of the same sample.

    Args:
        in_dim (int): Dimension of the input features.
        num_heads (int): Number of attention heads. Defaults to 8.
        legacy_value_sum (bool): If True, reproduce the original value
            contraction ``"nqk,nvd->nqd"``, in which the attention weights
            cancel out and every head receives the plain sum of all value
            heads. Only useful for weights trained with that behaviour.
            Defaults to False (real attention).
    """

    def __init__(self, in_dim, num_heads=8, legacy_value_sum=False):
        super().__init__()
        self.num_heads = num_heads
        # in_dim need not be divisible by num_heads: fc_in maps to the nearest
        # width that is (at least one unit per head).
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads
        self.legacy_value_sum = legacy_value_sum

        self.fc_in = nn.Linear(in_dim, self.scaled_dim)
        self.query = nn.Linear(self.scaled_dim, self.scaled_dim)  # Query projection
        self.key = nn.Linear(self.scaled_dim, self.scaled_dim)  # Key projection
        self.value = nn.Linear(self.scaled_dim, self.scaled_dim)  # Value projection
        self.fc_out = nn.Linear(self.scaled_dim, in_dim)  # Project back to in_dim

    def forward(self, x):
        """Apply head-wise self-attention.

        Args:
            x (Tensor): Input of shape ``(N, in_dim)``; N is the batch size.

        Returns:
            Tensor: Output of shape ``(N, in_dim)``.
        """
        batch = x.shape[0]
        head_shape = (batch, self.num_heads, self.head_dim)

        projected = self.fc_in(x)
        q = self.query(projected).view(head_shape)
        k = self.key(projected).view(head_shape)
        v = self.value(projected).view(head_shape)

        # Scaled dot-product scores between every pair of heads: (N, H, H).
        scores = torch.matmul(q, k.transpose(1, 2)) / (self.head_dim ** 0.5)
        attention = F.softmax(scores, dim=2)

        if self.legacy_value_sum:
            # Original contraction: ``k`` and ``v`` are summed independently,
            # so (softmax rows summing to 1) the result is sum_v v[n, v, d].
            out = torch.einsum("nqk,nvd->nqd", attention, v)
        else:
            # Weighted sum of value heads: contract the key axis of the
            # attention matrix with the head axis of ``v``.
            out = torch.matmul(attention, v)

        out = out.reshape(batch, self.scaled_dim)  # Concatenate all heads
        return self.fc_out(out)
class BaseModel(nn.Module):
    """ConvNeXt V2 backbone followed by morphological refinement, fusion,
    head-wise attention and a linear classifier.

    Args:
        num_classes (int): Number of output logits.
        device (str): Preferred device, stored on ``self.device``. The
            constructor does not move the module; call ``.to(device)`` as usual.
        pretrained (bool): Passed to ``timm.create_model``. Set to False to
            skip fetching pretrained backbone weights, e.g. when a full
            checkpoint is loaded right after construction. Defaults to True.
    """

    def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu',
                 pretrained=True):
        super().__init__()
        self.device = device

        # 1. Initialize backbone (num_classes=0 -> no classification head).
        self.backbone = timm.create_model(
            'convnextv2_base',
            pretrained=pretrained,
            num_classes=0
        )

        # 2. Use a dummy forward pass to determine the actual feature width.
        self.feature_dim = self._probe_feature_dim()
        print(f"Feature Dimension from V2 backbone: {self.feature_dim}")

        # 3. Setup multi-head attention layer (1..8 heads, ~64 dims per head).
        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        # 4. Setup classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes)
        )

        self.morphological_extractor = MorphologicalFeatureExtractor(
            in_features=self.feature_dim
        )

        # Fuses [features, morphological, features * morphological].
        self.feature_fusion = nn.Sequential(
            nn.Linear(self.feature_dim * 3, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU(),
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU()
        )

    @staticmethod
    def _pool_spatial(features):
        """Average over the last two axes if the backbone returned a feature map."""
        if features.dim() > 2:
            features = features.mean([-2, -1])
        return features

    def _probe_feature_dim(self):
        """Return the backbone's output width, measured on a 1x3x224x224 dummy input."""
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self._pool_spatial(self.backbone(dummy_input))
        return features.shape[1]

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x (Tensor): Input image tensor of shape
                [batch_size, channels, height, width].

        Returns:
            Tuple[Tensor, Tensor]: Classification logits and the attended
            features they were computed from.
        """
        # Follow the weights rather than ``self.device``: the module may never
        # have been moved to that device, or may have been moved elsewhere
        # since construction, and a mismatch would fail in the first layer.
        first_param = next(self.parameters(), None)
        target_device = first_param.device if first_param is not None else self.device
        x = x.to(target_device)

        # 1. Extract base features
        features = self._pool_spatial(self.backbone(x))

        # 2. Extract morphological features
        morphological_features = self.morphological_extractor(features)

        # 3. Feature fusion: raw, refined, and their element-wise interaction.
        combined_features = torch.cat([
            features,
            morphological_features,
            features * morphological_features
        ], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # 4. Apply attention mechanism
        attended_features = self.attention(fused_features)

        # 5. Final classifier
        logits = self.classifier(attended_features)
        return logits, attended_features