Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,83 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
metrics:
|
4 |
+
- accuracy
|
5 |
+
base_model:
|
6 |
+
- google/efficientnet-b4
|
7 |
+
pipeline_tag: image-classification
|
8 |
+
library_name: timm
|
9 |
+
tags:
|
10 |
+
- art
|
11 |
+
- pytorch
|
12 |
+
- images
|
13 |
+
- ai
|
14 |
+
---
|
15 |
+
|
16 |
+
# AI Image Detection
|
17 |
+
|
18 |
+
## Dataset
|
19 |
+
- **AI**: ≈100,000 Images
|
20 |
+
- **Human**: ≈100,000 Images
|
21 |
+
|
22 |
+
## Model
|
23 |
+
- **Architecture**: EfficientNet-B4
|
24 |
+
- **Framework**: PyTorch
|
25 |
+
|
26 |
+
## Evaluation Metrics
|
27 |
+
- **Training Accuracy**: 99.75%
|
28 |
+
- **Validation Accuracy**: 98.59%
|
29 |
+
- **Training Loss**: 0.0072
|
30 |
+
- **Validation Loss**: 0.0553
|
31 |
+
|
32 |
+
|
33 |
+
## Usage
|
34 |
+
|
35 |
+
```
|
36 |
+
pip install torch torchvision timm huggingface_hub pillow
|
37 |
+
```
|
38 |
+
|
39 |
+
### Example Code
|
40 |
+
```python
|
41 |
+
import torch
|
42 |
+
from torchvision import transforms
|
43 |
+
from PIL import Image
|
44 |
+
from timm import create_model
|
45 |
+
from huggingface_hub import hf_hub_download
|
46 |
+
|
47 |
+
# Parameters
|
48 |
+
IMG_SIZE = 380
|
49 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
50 |
+
LABEL_MAPPING = {1: "human", 0: "ai"}
|
51 |
+
|
52 |
+
# Download model from HuggingFace Hub
|
53 |
+
MODEL_PATH = hf_hub_download(repo_id="Dafilab/ai-vs-human-image-detection", filename="model_epoch_8_acc_0.9859.pth")
|
54 |
+
|
55 |
+
# Preprocessing
|
56 |
+
transform = transforms.Compose([
|
57 |
+
transforms.Resize(IMG_SIZE + 20),
|
58 |
+
transforms.CenterCrop(IMG_SIZE),
|
59 |
+
transforms.ToTensor(),
|
60 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
61 |
+
])
|
62 |
+
|
63 |
+
# Load model
|
64 |
+
model = create_model('efficientnet_b4', pretrained=False, num_classes=2)
|
65 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
|
66 |
+
model.to(DEVICE).eval()
|
67 |
+
|
68 |
+
# Prediction function
|
69 |
+
def predict_image(image_path):
|
70 |
+
img = Image.open(image_path).convert("RGB")
|
71 |
+
img = transform(img).unsqueeze(0).to(DEVICE)
|
72 |
+
with torch.no_grad():
|
73 |
+
logits = model(img)
|
74 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
75 |
+
predicted_class = torch.argmax(probs, dim=1).item()
|
76 |
+
confidence = probs[0, predicted_class].item()
|
77 |
+
return LABEL_MAPPING[predicted_class], confidence
|
78 |
+
|
79 |
+
# Example usage
|
80 |
+
image_path = "path/to/image.jpg"
|
81 |
+
label, confidence = predict_image(image_path)
|
82 |
+
print(f"Label: {label}, Confidence: {confidence:.2f}")
|
83 |
+
```
|