"""Gradio demo that guesses a map coordinate from an uploaded image."""
import functools
import os.path

import numpy as np
import gradio as gr
import plotly.graph_objects as go

from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
    AverageNeighborsEmbedderGuessr
from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder

# Precomputed embeddings and their metadata live next to this file.
RESOURCES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources")
EMBEDDINGS_PATH = os.path.join(RESOURCES_DIR, "embeddings.npy")
METADATA_PATH = os.path.join(RESOURCES_DIR, "metadatav3.csv")

ALL_GUESSR_CLASS = {
    "random": RandomGuessr,
    "nearestNeighborEmbedder": NearestNeighborEmbedderGuessr,
    "averageNeighborsEmbedder": AverageNeighborsEmbedderGuessr,
}


@functools.lru_cache(maxsize=None)
def _shared_embedder() -> DinoV2Embedder:
    """Build the image embedder once; later calls return the same instance."""
    # For quick local testing without model weights, RandomEmbedder(n_dim=384)
    # can be returned here instead.
    return DinoV2Embedder(device="cpu")


@functools.lru_cache(maxsize=None)
def _shared_retriever() -> Retriever:
    """Build the embedding retriever once; later calls return the same instance."""
    return Retriever(embeddings_path=EMBEDDINGS_PATH)


def _embedder_guessr_args(**extra) -> dict:
    """Return constructor kwargs shared by the embedding-based guessrs.

    :param extra: guessr-specific keyword arguments appended to the common ones
    """
    # NOTE(review): the embedder/retriever instances are shared between guessr
    # types; this assumes they are read-only after construction — confirm.
    return {
        "embedder": _shared_embedder(),
        "retriever": _shared_retriever(),
        "metadata_path": METADATA_PATH,
        **extra,
    }


# Zero-argument factories producing the constructor kwargs of each guessr type.
# They are only called when a guessr is first used, so model weights and
# embeddings are not loaded at import time (and are loaded only once).
ALL_GUESSR_ARGS_FACTORIES = {
    "random": dict,
    "nearestNeighborEmbedder": _embedder_guessr_args,
    "averageNeighborsEmbedder": functools.partial(
        _embedder_guessr_args, n_neighbors=100, dbscan_eps=0.5
    ),
}

# Guessr instances, created on first use and then reused across requests.
ALL_GUESSR = {}


def _get_guessr(guessr: str) -> AbstractGuessr:
    """Return the guessr named ``guessr``, instantiating it on first use.

    :param guessr: a key of ``ALL_GUESSR_CLASS``
    :raises gr.Error: if ``guessr`` is not a known guessr type
    """
    if guessr not in ALL_GUESSR_CLASS:
        raise gr.Error(f"Unknown guessr type: {guessr!r}")
    if guessr not in ALL_GUESSR:
        ALL_GUESSR[guessr] = ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS_FACTORIES[guessr]())
    return ALL_GUESSR[guessr]


def create_map(guessr: str) -> go.Figure:
    """Create an interactive map with no guess on it.

    The selected guessr is instantiated here so that its resources are loaded
    when the page loads rather than on the first "Guess" click.

    :param guessr: name of the guessr selected in the dropdown
    :return: the empty map figure
    """
    _get_guessr(guessr)
    return AbstractGuessr.create_map()


def guess(guessr: str, uploaded_image) -> go.Figure:
    """Guess a coordinate from an image uploaded in the Gradio interface.

    :param guessr: name of the guessr selected in the dropdown
    :param uploaded_image: image from the ``gr.Image`` component, or ``None``
        when nothing has been uploaded
    :return: a map figure showing the guessed coordinate
    :raises gr.Error: if no image was uploaded or the guessr name is unknown
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image before guessing.")
    model = _get_guessr(guessr)
    guess_coordinate = model.guess(np.asarray(uploaded_image))
    return model.create_map(guess_coordinate)


def build_demo() -> gr.Blocks:
    """Build the Gradio interface: guessr picker, image upload, and result map."""
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                guessr_dropdown = gr.Dropdown(
                    list(ALL_GUESSR_CLASS.keys()),
                    value="nearestNeighborEmbedder",
                    label="Guessr type",
                    info="More Guessr types will be added soon!",
                )
                image = gr.Image(shape=(800, 800))
                # A button's label is its `value`; an unknown `text=` kwarg is
                # ignored by Gradio and leaves the default label.
                button = gr.Button(value="Guess")
            interactive_map = gr.Plot()
        demo.load(create_map, [guessr_dropdown], interactive_map)
        button.click(guess, [guessr_dropdown, image], interactive_map)
    return demo


if __name__ == "__main__":
    # Launch demo 🚀
    build_demo().launch()