Spaces:
Runtime error
Runtime error
Bastien Dechamps
commited on
Commit
·
1791df2
1
Parent(s):
9ed0050
[ADD] Average embedder
Browse files- app.py +18 -3
- geoguessr_bot/guessr/__init__.py +2 -1
- geoguessr_bot/guessr/abstract_guessr.py +1 -1
- geoguessr_bot/guessr/average_neighbor_embedder_guessr.py +64 -0
- geoguessr_bot/guessr/{global_embedder_guessr.py → nearest_neighbor_embedder_guessr.py} +4 -5
- geoguessr_bot/interfaces.py +14 -0
- geoguessr_bot/retriever/retriever.py +5 -6
- requirements.txt +1 -0
app.py
CHANGED
@@ -4,18 +4,20 @@ import numpy as np
|
|
4 |
import gradio as gr
|
5 |
import plotly.graph_objects as go
|
6 |
|
7 |
-
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr,
|
|
|
8 |
from geoguessr_bot.retriever import DinoV2Embedder, Retriever
|
9 |
|
10 |
|
11 |
ALL_GUESSR_CLASS = {
|
12 |
"random": RandomGuessr,
|
13 |
-
"
|
|
|
14 |
}
|
15 |
|
16 |
ALL_GUESSR_ARGS = {
|
17 |
"random": {},
|
18 |
-
"
|
19 |
"embedder": DinoV2Embedder(
|
20 |
device="cpu"
|
21 |
),
|
@@ -25,6 +27,19 @@ ALL_GUESSR_ARGS = {
|
|
25 |
),
|
26 |
"metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
27 |
"resources/metadatav3.csv"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
}
|
29 |
}
|
30 |
|
|
|
4 |
import gradio as gr
|
5 |
import plotly.graph_objects as go
|
6 |
|
7 |
+
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
|
8 |
+
AverageNeighborsEmbedderGuessr
|
9 |
from geoguessr_bot.retriever import DinoV2Embedder, Retriever
|
10 |
|
11 |
|
12 |
ALL_GUESSR_CLASS = {
|
13 |
"random": RandomGuessr,
|
14 |
+
"nearestNeighborEmbedder": NearestNeighborEmbedderGuessr,
|
15 |
+
"averageNeighborsEmbedder": AverageNeighborsEmbedderGuessr,
|
16 |
}
|
17 |
|
18 |
ALL_GUESSR_ARGS = {
|
19 |
"random": {},
|
20 |
+
"nearestNeighborEmbedder": {
|
21 |
"embedder": DinoV2Embedder(
|
22 |
device="cpu"
|
23 |
),
|
|
|
27 |
),
|
28 |
"metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
29 |
"resources/metadatav3.csv"),
|
30 |
+
},
|
31 |
+
"averageNeighborsEmbedder": {
|
32 |
+
"embedder": DinoV2Embedder(
|
33 |
+
device="cpu"
|
34 |
+
),
|
35 |
+
"retriever": Retriever(
|
36 |
+
embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
37 |
+
"resources/embeddings.npy"),
|
38 |
+
),
|
39 |
+
"metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
40 |
+
"resources/metadatav3.csv"),
|
41 |
+
"n_neighbors": 2000,
|
42 |
+
"dbscan_eps": 0.5
|
43 |
}
|
44 |
}
|
45 |
|
geoguessr_bot/guessr/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
from .abstract_guessr import AbstractGuessr
|
2 |
from .random_guessr import RandomGuessr
|
3 |
-
from .
|
|
|
|
1 |
from .abstract_guessr import AbstractGuessr
|
2 |
from .random_guessr import RandomGuessr
|
3 |
+
from .nearest_neighbor_embedder_guessr import NearestNeighborEmbedderGuessr
|
4 |
+
from .average_neighbor_embedder_guessr import AverageNeighborsEmbedderGuessr
|
geoguessr_bot/guessr/abstract_guessr.py
CHANGED
@@ -25,7 +25,7 @@ class AbstractGuessr:
|
|
25 |
"""Create an interactive map showing a coordinate
|
26 |
"""
|
27 |
fig = go.Figure(go.Scattermapbox(
|
28 |
-
customdata=[guess_coordinate
|
29 |
lat=[guess_coordinate.latitude] if guess_coordinate is not None else None,
|
30 |
lon=[guess_coordinate.longitude] if guess_coordinate is not None else None,
|
31 |
mode="markers",
|
|
|
25 |
"""Create an interactive map showing a coordinate
|
26 |
"""
|
27 |
fig = go.Figure(go.Scattermapbox(
|
28 |
+
customdata=[str(guess_coordinate)] if guess_coordinate is not None else None,
|
29 |
lat=[guess_coordinate.latitude] if guess_coordinate is not None else None,
|
30 |
lon=[guess_coordinate.longitude] if guess_coordinate is not None else None,
|
31 |
mode="markers",
|
geoguessr_bot/guessr/average_neighbor_embedder_guessr.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import Counter
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.cluster import DBSCAN
|
6 |
+
from sklearn.metrics.pairwise import haversine_distances
|
7 |
+
from PIL import Image
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from geoguessr_bot.guessr import AbstractGuessr
|
11 |
+
from geoguessr_bot.interfaces import Coordinate
|
12 |
+
from geoguessr_bot.retriever import AbstractImageEmbedder
|
13 |
+
from geoguessr_bot.retriever import Retriever
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
|
18 |
+
"""Guesses a coordinate using an Embedder and a retriever followed by NN.
|
19 |
+
"""
|
20 |
+
embedder: AbstractImageEmbedder
|
21 |
+
retriever: Retriever
|
22 |
+
metadata_path: str
|
23 |
+
n_neighbors: int = 1000
|
24 |
+
dbscan_eps: float = 0.05
|
25 |
+
|
26 |
+
def __post_init__(self):
|
27 |
+
"""Load metadata
|
28 |
+
"""
|
29 |
+
metadata = pd.read_csv(self.metadata_path)
|
30 |
+
self.image_to_coordinate = {
|
31 |
+
image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
|
32 |
+
for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
|
33 |
+
}
|
34 |
+
# DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
|
35 |
+
self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distances)
|
36 |
+
|
37 |
+
def guess(self, image: Image) -> Coordinate:
|
38 |
+
"""Guess a coordinate from an image
|
39 |
+
"""
|
40 |
+
# Embed image
|
41 |
+
image = Image.fromarray(image)
|
42 |
+
image_embedding = self.embedder.embed(image)[None, :]
|
43 |
+
|
44 |
+
# Retrieve nearest neighbors
|
45 |
+
nearest_neighbors, distances = self.retriever.retrieve(image_embedding, self.n_neighbors)
|
46 |
+
nearest_neighbors = nearest_neighbors[0]
|
47 |
+
distances = distances[0]
|
48 |
+
|
49 |
+
# Get coordinates of neighbors
|
50 |
+
neighbors_coordinates = [self.image_to_coordinate[nn].to_radians() for nn in nearest_neighbors]
|
51 |
+
neighbors_coordinates = np.array([[nn.latitude, nn.longitude] for nn in neighbors_coordinates])
|
52 |
+
|
53 |
+
# Use DBSCAN to find the biggest cluster and potentially remove outliers
|
54 |
+
clustering = self.dbscan.fit(neighbors_coordinates)
|
55 |
+
labels = clustering.labels_
|
56 |
+
biggest_cluster = max(Counter(labels))
|
57 |
+
neighbors_coordinates = neighbors_coordinates[labels == biggest_cluster]
|
58 |
+
distances = distances[labels == biggest_cluster]
|
59 |
+
|
60 |
+
# Guess coordinate as the closest image among the cluster regarding retrieving distance
|
61 |
+
guess_coordinate = neighbors_coordinates[np.argmin(distances)]
|
62 |
+
guess_coordinate = Coordinate.from_radians(guess_coordinate[0], guess_coordinate[1])
|
63 |
+
return guess_coordinate
|
64 |
+
|
geoguessr_bot/guessr/{global_embedder_guessr.py → nearest_neighbor_embedder_guessr.py}
RENAMED
@@ -10,10 +10,9 @@ from geoguessr_bot.retriever import Retriever
|
|
10 |
|
11 |
|
12 |
@dataclass
|
13 |
-
class
|
14 |
-
"""Guesses a coordinate using an Embedder and a retriever
|
15 |
"""
|
16 |
-
|
17 |
embedder: AbstractImageEmbedder
|
18 |
retriever: Retriever
|
19 |
metadata_path: str
|
@@ -35,9 +34,9 @@ class GlobalEmbedderGuessr(AbstractGuessr):
|
|
35 |
image = Image.fromarray(image)
|
36 |
image_embedding = self.embedder.embed(image)[None, :]
|
37 |
|
38 |
-
# Retrieve nearest
|
39 |
nearest_neighbors = self.retriever.retrieve(image_embedding)
|
40 |
-
nearest_neighbor = nearest_neighbors[0][0]
|
41 |
|
42 |
# Guess coordinate
|
43 |
guess_coordinate = self.image_to_coordinate[nearest_neighbor]
|
|
|
10 |
|
11 |
|
12 |
@dataclass
|
13 |
+
class NearestNeighborEmbedderGuessr(AbstractGuessr):
|
14 |
+
"""Guesses a coordinate using an Embedder and a retriever followed by NN.
|
15 |
"""
|
|
|
16 |
embedder: AbstractImageEmbedder
|
17 |
retriever: Retriever
|
18 |
metadata_path: str
|
|
|
34 |
image = Image.fromarray(image)
|
35 |
image_embedding = self.embedder.embed(image)[None, :]
|
36 |
|
37 |
+
# Retrieve nearest neighbor
|
38 |
nearest_neighbors = self.retriever.retrieve(image_embedding)
|
39 |
+
nearest_neighbor = nearest_neighbors[0][0][0]
|
40 |
|
41 |
# Guess coordinate
|
42 |
guess_coordinate = self.image_to_coordinate[nearest_neighbor]
|
geoguessr_bot/interfaces.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from pydantic.main import BaseModel
|
2 |
|
3 |
|
@@ -7,3 +8,16 @@ class Coordinate(BaseModel):
|
|
7 |
|
8 |
def __str__(self):
|
9 |
return f"({round(self.latitude, 6)}, {round(self.longitude, 6)})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
from pydantic.main import BaseModel
|
3 |
|
4 |
|
|
|
8 |
|
9 |
def __str__(self):
|
10 |
return f"({round(self.latitude, 6)}, {round(self.longitude, 6)})"
|
11 |
+
|
12 |
+
def to_radians(self) -> 'Coordinate':
|
13 |
+
return Coordinate(
|
14 |
+
latitude=self.latitude * np.pi / 180.,
|
15 |
+
longitude=self.longitude * np.pi / 180.
|
16 |
+
)
|
17 |
+
|
18 |
+
@staticmethod
|
19 |
+
def from_radians(latitude: float, longitude: float) -> 'Coordinate':
|
20 |
+
return Coordinate(
|
21 |
+
latitude=latitude * 180. / np.pi,
|
22 |
+
longitude=longitude * 180. / np.pi
|
23 |
+
)
|
geoguessr_bot/retriever/retriever.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
-
from typing import Dict, List
|
2 |
|
3 |
import numpy as np
|
4 |
import faiss
|
5 |
|
6 |
|
7 |
class Retriever:
|
8 |
-
def __init__(self, embeddings_path: str
|
9 |
self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
|
10 |
-
self.n_neighbors = n_neighbors
|
11 |
|
12 |
# Keep track of image names
|
13 |
self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
|
@@ -25,8 +24,8 @@ class Retriever:
|
|
25 |
"""
|
26 |
return np.load(embeddings_path, allow_pickle=True).item()
|
27 |
|
28 |
-
def retrieve(self, queries: np.ndarray) -> List[List[str]]:
|
29 |
"""Retrieve nearest neighbors indexes from queries
|
30 |
"""
|
31 |
-
|
32 |
-
return [[self.index_to_image[i] for i in index] for index in indexes]
|
|
|
1 |
+
from typing import Dict, List, Tuple
|
2 |
|
3 |
import numpy as np
|
4 |
import faiss
|
5 |
|
6 |
|
7 |
class Retriever:
|
8 |
+
def __init__(self, embeddings_path: str):
|
9 |
self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
|
|
|
10 |
|
11 |
# Keep track of image names
|
12 |
self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
|
|
|
24 |
"""
|
25 |
return np.load(embeddings_path, allow_pickle=True).item()
|
26 |
|
27 |
+
def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> Tuple[List[List[str]], List[List[float]]]:
|
28 |
"""Retrieve nearest neighbors indexes from queries
|
29 |
"""
|
30 |
+
distances, indexes = self.index.search(queries, n_neighbors)
|
31 |
+
return [[self.index_to_image[i] for i in index] for index in indexes], distances
|
requirements.txt
CHANGED
@@ -12,3 +12,4 @@ torchvision
|
|
12 |
tqdm
|
13 |
configue
|
14 |
fire
|
|
|
|
12 |
tqdm
|
13 |
configue
|
14 |
fire
|
15 |
+
scikit-learn
|