Spaces:
Running
Running
File size: 3,849 Bytes
cb64143 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import functools
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from PIL import Image
from sklearn.metrics import confusion_matrix
def _parse_sequence(text):
    """Parse a comma-separated string of numbers into a list of floats.

    Blank entries (for example from a trailing or doubled comma) are skipped,
    and ``None`` is treated like an empty string.

    Raises:
        ValueError: if a non-blank entry is not a number, or if the string
            contains no numbers at all.
    """
    values = [float(token) for token in (text or "").split(",") if token.strip()]
    if not values:
        raise ValueError("sequence contains no numbers")
    return values


def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> "Image.Image | str":
    """
    Generate a plot based on the provided x and y sequences and plot type.

    Parameters:
    - x_sequence (str): A comma-separated string of x values. For
      'Confusion Matrix' these are the actual labels, binarized at 0.5.
    - y_sequence (str): A comma-separated string of y values. For
      'Confusion Matrix' these are the predicted labels, binarized at 0.5.
    - plot_type (str): The type of plot to generate ('Bar', 'Scatter',
      'Confusion Matrix').
    - x_label (str): Label for the x-axis (Bar/Scatter); blank falls back to "x".
    - y_label (str): Label for the y-axis (Bar/Scatter); blank falls back to "y".
    - width (int): Width of the plot; falsy values fall back to 800.
    - height (int): Height of the plot; falsy values fall back to 600.
      The PNG is rendered at scale=2, so its pixel size is twice width x height.

    Returns:
    - Image.Image | str: A PIL Image of the generated plot, or an error
      message string when the input is invalid.
    """
    # Convert the input sequences to lists of numbers.
    try:
        x_data = _parse_sequence(x_sequence)
        y_data = _parse_sequence(y_sequence)
    except ValueError:
        return "Invalid input. Please enter sequences of numbers separated by commas."
    # Ensure the x and y sequences have the same length.
    if len(x_data) != len(y_data):
        return "The x and y sequences must have the same length."

    # Fall back to defaults when the width/height fields are empty or zero.
    width = width if width else 800
    height = height if height else 600
    # Plotly only accepts layout sizes of at least 10; report a clear
    # message instead of letting the figure constructor raise.
    if width < 10 or height < 10:
        return "Width and height must be at least 10 pixels."

    # Blank optional labels would otherwise leave the axes untitled.
    axis_labels = {"x": x_label or "x", "y": y_label or "y"}

    if plot_type == "Bar":
        df = pd.DataFrame({"x": x_data, "y": y_data})
        fig = px.bar(
            df,
            x="x",
            y="y",
            title="Bar Plot",
            labels=axis_labels,
            width=width,
            height=height,
        )
    elif plot_type == "Scatter":
        df = pd.DataFrame({"x": x_data, "y": y_data})
        fig = px.scatter(
            df,
            x="x",
            y="y",
            title="Scatter Plot",
            labels=axis_labels,
            width=width,
            height=height,
        )
    elif plot_type == "Confusion Matrix":
        # X holds the actual labels and Y the predictions; both are binarized
        # at 0.5 so the result is deterministic and uses both inputs.
        y_true = (np.asarray(x_data) > 0.5).astype(int)
        y_pred = (np.asarray(y_data) > 0.5).astype(int)
        # Pin the label set so the matrix is always 2x2, even if one class
        # never appears in the data.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        # confusion_matrix rows are actual classes and columns are predicted
        # classes; imshow draws rows on the y-axis and columns on the x-axis.
        fig = px.imshow(
            cm,
            text_auto=True,
            x=["0", "1"],
            y=["0", "1"],
            labels={"x": "Predicted", "y": "Actual", "color": "Count"},
            title="Confusion Matrix",
            width=width,
            height=height,
        )
    else:
        return "Invalid plot type selected."

    # Render the figure to PNG bytes with kaleido and decode them as a PIL image.
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))
# Define the Gradio interface. The seven inputs map positionally onto
# generate_plot's parameters: x_sequence, y_sequence, plot_type, x_label,
# y_label, width, height.


@functools.wraps(generate_plot)
def _generate_plot_or_raise(*args):
    """Call generate_plot, surfacing its error strings as Gradio errors.

    generate_plot reports invalid input by returning a message string, but
    the output component expects a PIL image, so a returned string would be
    handled as an image value rather than shown to the user. Raising
    gr.Error displays the message in the UI instead. functools.wraps keeps
    generate_plot's signature visible to Gradio.
    """
    result = generate_plot(*args)
    if isinstance(result, str):
        raise gr.Error(result)
    return result


app = gr.Interface(
    fn=_generate_plot_or_raise,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        # precision=0 makes Gradio round and pass ints, matching the
        # `width: int` / `height: int` parameters of generate_plot.
        gr.Number(value=800, label="Width", precision=0),
        gr.Number(value=600, label="Height", precision=0),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)

# Launch the app (starts the Gradio server).
app.launch()
|