import gradio as gr
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from PIL import Image
from io import BytesIO

# Canvas size used when the width/height inputs are left empty or set to 0.
DEFAULT_WIDTH = 800
DEFAULT_HEIGHT = 600
# Plotly rejects layout sizes smaller than 10 px.
MIN_SIZE = 10

# Plot types that draw the raw (x, y) pairs: plotting function and figure title.
_XY_PLOTTERS = {
    "Bar": (px.bar, "Bar Plot"),
    "Scatter": (px.scatter, "Scatter Plot"),
}


def _parse_sequence(text: str, name: str) -> list:
    """Parse a comma-separated string of numbers into a list of floats.

    Blank tokens (e.g. from a trailing comma) are ignored.

    Raises:
        gr.Error: if the sequence is empty or contains a non-numeric token.
    """
    tokens = [token.strip() for token in (text or "").split(",")]
    tokens = [token for token in tokens if token]
    if not tokens:
        raise gr.Error(
            f"The {name} sequence is empty. Please enter numbers separated by commas."
        )
    try:
        return [float(token) for token in tokens]
    except ValueError:
        raise gr.Error(
            "Invalid input. Please enter sequences of numbers separated by commas."
        ) from None


def _resolve_size(value, default: int, name: str) -> int:
    """Return *value* as an integer pixel size, or *default* when empty/0.

    gr.Number delivers floats, so the value is truncated to an int.

    Raises:
        gr.Error: if the size is below Plotly's minimum of MIN_SIZE pixels.
    """
    if not value:
        return default
    size = int(value)
    if size < MIN_SIZE:
        raise gr.Error(f"{name} must be at least {MIN_SIZE} pixels.")
    return size


def _confusion_matrix_figure(x_data, y_data, x_label, y_label, width, height):
    """Build a confusion-matrix heatmap from true (x) and predicted (y) labels.

    Raises:
        gr.Error: if any value is not an integer class label.
    """
    # sklearn treats non-integral floats as continuous targets and rejects them.
    if not all(value.is_integer() for value in x_data + y_data):
        raise gr.Error(
            "Confusion Matrix requires integer class labels in both sequences."
        )
    y_true = [int(value) for value in x_data]
    y_pred = [int(value) for value in y_data]
    classes = sorted(set(y_true) | set(y_pred))
    # Transpose so columns follow the x sequence (true labels) and rows follow
    # the y sequence (predicted labels), matching the X/Y input naming.
    matrix = confusion_matrix(y_true, y_pred, labels=classes).T
    tick_text = [str(cls) for cls in classes]
    fig = px.imshow(
        matrix,
        x=tick_text,
        y=tick_text,
        text_auto=True,
        title="Confusion Matrix",
        labels={
            "x": x_label or "True",
            "y": y_label or "Predicted",
            "color": "Count",
        },
        width=width,
        height=height,
    )
    # Numeric-looking strings would otherwise be auto-typed as a linear axis.
    fig.update_xaxes(type="category")
    fig.update_yaxes(type="category")
    return fig


def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> Image.Image:
    """
    Generate a plot based on the provided x and y sequences and plot type.

    Parameters:
    - x_sequence (str): A comma-separated string of x values. For the
      confusion matrix these are the true integer class labels.
    - y_sequence (str): A comma-separated string of y values. For the
      confusion matrix these are the predicted integer class labels.
    - plot_type (str): The type of plot to generate ('Bar', 'Scatter',
      'Confusion Matrix').
    - x_label (str): Label for the x-axis (blank uses the default).
    - y_label (str): Label for the y-axis (blank uses the default).
    - width (int): Width of the plot in pixels (empty/0 -> 800).
    - height (int): Height of the plot in pixels (empty/0 -> 600).

    Returns:
    - Image.Image: A PIL image of the rendered plot (rendered at scale 2).

    Raises:
    - gr.Error: on invalid input. Gradio shows the message to the user,
      which a string return value could not do with an Image output.
    """
    x_data = _parse_sequence(x_sequence, "x")
    y_data = _parse_sequence(y_sequence, "y")

    # Ensure the x and y sequences have the same length
    if len(x_data) != len(y_data):
        raise gr.Error("The x and y sequences must have the same length.")

    width = _resolve_size(width, DEFAULT_WIDTH, "Width")
    height = _resolve_size(height, DEFAULT_HEIGHT, "Height")

    if plot_type in _XY_PLOTTERS:
        plot_fn, title = _XY_PLOTTERS[plot_type]
        df = pd.DataFrame({"x": x_data, "y": y_data})
        # Only pass non-blank labels so Plotly falls back to the column names.
        axis_labels = {
            axis: label
            for axis, label in (("x", x_label), ("y", y_label))
            if label
        }
        fig = plot_fn(
            df,
            x="x",
            y="y",
            title=title,
            labels=axis_labels,
            width=width,
            height=height,
        )
    elif plot_type == "Confusion Matrix":
        fig = _confusion_matrix_figure(
            x_data, y_data, x_label, y_label, width, height
        )
    else:
        raise gr.Error("Invalid plot type selected.")

    # Convert the plot to a PNG image
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))


# Define the Gradio interface using the new syntax
app = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        gr.Number(
            value=800,
            label="Width",
        ),
        gr.Number(value=600, label="Height"),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()