Spaces:
Runtime error
Runtime error
File size: 3,240 Bytes
d9b35be dfbe385 1791df2 ed8157d dfbe385 4388025 1791df2 dfbe385 bd2d69d 1791df2 bd2d69d ed8157d bd2d69d d9b35be 4cf404a 1791df2 ed8157d 1791df2 ed8157d 1791df2 bd2d69d dfbe385 06ee965 dfbe385 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import os.path
import numpy as np
import gradio as gr
import plotly.graph_objects as go
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
AverageNeighborsEmbedderGuessr
from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder
# Registry of the available guessr implementations, keyed by the name that is
# listed in the UI dropdown and used to look up constructor arguments.
ALL_GUESSR_CLASS = {
    "random": RandomGuessr,
    "nearestNeighborEmbedder": NearestNeighborEmbedderGuessr,
    "averageNeighborsEmbedder": AverageNeighborsEmbedderGuessr,
}
# Directory containing this file; the "resources/..." paths below are resolved
# against it so they do not depend on the current working directory.
_RESOURCES_ROOT = os.path.dirname(os.path.abspath(__file__))


def _embedder_guessr_args(**extra_args):
    """Build the constructor kwargs used by the embedder-based guessrs.

    Every call creates a new ``DinoV2Embedder`` on CPU and a new ``Retriever``
    reading ``resources/embeddings.npy``, then adds the metadata CSV path.
    Any ``extra_args`` are appended after those three entries.
    """
    guessr_args = {
        "embedder": DinoV2Embedder(device="cpu"),
        # Alternative embedder left in the original source:
        # RandomEmbedder(n_dim=384)
        "retriever": Retriever(
            embeddings_path=os.path.join(_RESOURCES_ROOT, "resources/embeddings.npy"),
        ),
        "metadata_path": os.path.join(_RESOURCES_ROOT, "resources/metadatav3.csv"),
    }
    guessr_args.update(extra_args)
    return guessr_args


# Constructor keyword arguments for each guessr, keyed like ALL_GUESSR_CLASS.
# NOTE: the embedders and retrievers are created here, at import time, once
# per embedder-based entry.
ALL_GUESSR_ARGS = {
    "random": {},
    "nearestNeighborEmbedder": _embedder_guessr_args(),
    "averageNeighborsEmbedder": _embedder_guessr_args(
        n_neighbors=100,
        dbscan_eps=0.5,
    ),
}
# Cache of instantiated guessrs, keyed by guessr name. It starts empty and is
# filled on first use by create_map() and guess(), so a guessr object is only
# constructed once it is actually selected.
# NOTE: only the guessr objects are deferred; the constructor arguments held
# in ALL_GUESSR_ARGS (embedders, retrievers) are already built at import time.
ALL_GUESSR = {}
def create_map(guessr: str) -> go.Figure:
    """Create the interactive map and make sure the chosen guessr exists.

    Args:
        guessr: Name of the guessr; must be a key of ``ALL_GUESSR_CLASS`` and
            ``ALL_GUESSR_ARGS`` (a ``KeyError`` is raised otherwise).

    Returns:
        Whatever ``AbstractGuessr.create_map()`` returns when called with no
        arguments.
    """
    already_built = guessr in ALL_GUESSR
    if not already_built:
        # First use of this guessr: build it and keep it in the cache.
        guessr_cls = ALL_GUESSR_CLASS[guessr]
        ALL_GUESSR[guessr] = guessr_cls(**ALL_GUESSR_ARGS[guessr])

    # NOTE(review): the map is built through the AbstractGuessr class, not the
    # cached instance - presumably create_map is a static/class method with an
    # optional coordinate; confirm against geoguessr_bot.guessr.
    return AbstractGuessr.create_map()
def guess(guessr: str, uploaded_image) -> go.Figure:
    """Guess a coordinate from an image uploaded in the Gradio interface.

    Args:
        guessr: Name of the guessr to use; must be a key of
            ``ALL_GUESSR_CLASS`` and ``ALL_GUESSR_ARGS``.
        uploaded_image: Image delivered by the Gradio image component. It is
            ``None`` when the button is clicked before an image was provided.

    Returns:
        The figure returned by the guessr's ``create_map`` for the guessed
        coordinate.

    Raises:
        gr.Error: If no image was provided.
        KeyError: If ``guessr`` is not a registered guessr name.
    """
    # Fail early with a readable message instead of feeding np.array(None)
    # (a 0-d object array) to the guessr.
    if uploaded_image is None:
        raise gr.Error("Please upload an image before guessing.")

    # Instantiate guessr if not already done
    if guessr not in ALL_GUESSR:
        ALL_GUESSR[guessr] = ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS[guessr])
    selected_guessr = ALL_GUESSR[guessr]

    # Convert image to numpy array
    image_array = np.array(uploaded_image)

    # Guess the coordinate, then draw it on the map
    guess_coordinate = selected_guessr.guess(image_array)
    return selected_guessr.create_map(guess_coordinate)
if __name__ == "__main__":
    # Create & launch Gradio interface.
    # NOTE(review): the original indentation was lost; the nesting below
    # (controls in a column, map next to it in the same row) is the most
    # plausible reconstruction - confirm against the intended layout.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                guessr_dropdown = gr.Dropdown(
                    list(ALL_GUESSR_CLASS),
                    value="nearestNeighborEmbedder",
                    label="Guessr type",
                    info="More Guessr types will be added soon!",
                )
                image = gr.Image(shape=(800, 800))
                # The button label is passed as `value`; `text` is not a
                # Button parameter, so the label would not be applied.
                button = gr.Button(value="Guess")
            interactive_map = gr.Plot()

        # Show the map on page load, using the guessr selected in the dropdown.
        demo.load(create_map, [guessr_dropdown], interactive_map)
        # Replace the map with the prediction when the button is clicked.
        button.click(guess, [guessr_dropdown, image], interactive_map)

    # Launch demo
    demo.launch()
|