File size: 3,849 Bytes
cb64143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from PIL import Image
from io import BytesIO

def _parse_sequence(text: str, name: str) -> list:
    """
    Parse a comma-separated string of numbers into a list of floats.

    Blank entries (from a trailing comma, a doubled comma, or stray
    whitespace) are skipped, so "1, 2, 3," parses as [1.0, 2.0, 3.0].

    Parameters:
    - text (str): The raw textbox contents; None is treated as empty.
    - name (str): Which sequence this is ("x" or "y"), used in error messages.

    Returns:
    - list: The parsed float values (never empty).

    Raises:
    - gr.Error: If an entry is not a number or no numbers were given.
    """
    tokens = [token.strip() for token in (text or "").split(",")]
    try:
        values = [float(token) for token in tokens if token]
    except ValueError:
        raise gr.Error(
            "Invalid input. Please enter sequences of numbers separated by commas."
        ) from None
    if not values:
        raise gr.Error(
            f"The {name} sequence is empty. Please enter at least one number."
        )
    return values


def _confusion_matrix_figure(x_data: list, y_data: list, width: int, height: int):
    """
    Build a 2x2 confusion-matrix heatmap from the two sequences.

    The x values are treated as the actual classes and the y values as the
    predicted classes; values above 0.5 count as class 1, the rest as class 0.
    """
    y_true = (np.asarray(x_data) > 0.5).astype(int)
    y_pred = (np.asarray(y_data) > 0.5).astype(int)
    # Fixing the label set keeps the matrix 2x2 even if only one class occurs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # confusion_matrix puts actual classes on rows and predictions on columns.
    return px.imshow(
        cm,
        text_auto=True,
        title="Confusion Matrix",
        labels={"x": "Predicted (y > 0.5)", "y": "Actual (x > 0.5)", "color": "Count"},
        x=["0", "1"],
        y=["0", "1"],
        width=width,
        height=height,
    )


def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> Image.Image:
    """
    Generate a plot based on the provided x and y sequences and plot type.

    Parameters:
    - x_sequence (str): A comma-separated string of x values.
    - y_sequence (str): A comma-separated string of y values.
    - plot_type (str): The type of plot to generate ('Bar', 'Scatter', 'Confusion Matrix').
      For 'Confusion Matrix', x values are the actual classes and y values the
      predicted classes, each binarized at 0.5.
    - x_label (str): Label for the x-axis (Bar/Scatter); defaults to "x" when empty.
    - y_label (str): Label for the y-axis (Bar/Scatter); defaults to "y" when empty.
    - width (int): Width of the plot in pixels; defaults to 800 when empty or 0.
    - height (int): Height of the plot in pixels; defaults to 600 when empty or 0.

    Returns:
    - Image.Image: A PIL Image of the rendered plot (PNG, rendered at scale 2).

    Raises:
    - gr.Error: On unparsable or mismatched sequences, a size below 10 pixels,
      or an unknown plot type. Gradio shows the message to the user; returning
      a string would instead be misread by the gr.Image output as a file path.
    """
    x_data = _parse_sequence(x_sequence, "x")
    y_data = _parse_sequence(y_sequence, "y")
    if len(x_data) != len(y_data):
        raise gr.Error("The x and y sequences must have the same length.")

    # gr.Number can deliver floats (e.g. 800.0) or None when the field is cleared.
    width = int(width) if width else 800
    height = int(height) if height else 600
    # Plotly rejects layout sizes below 10 pixels with an opaque ValueError.
    if width < 10 or height < 10:
        raise gr.Error("Width and height must be at least 10 pixels.")

    df = pd.DataFrame({"x": x_data, "y": y_data})
    # Empty label fields would blank the axis titles; fall back to column names.
    xy_kwargs = {
        "x": "x",
        "y": "y",
        "labels": {"x": x_label or "x", "y": y_label or "y"},
        "width": width,
        "height": height,
    }

    if plot_type == "Bar":
        fig = px.bar(df, title="Bar Plot", **xy_kwargs)
    elif plot_type == "Scatter":
        fig = px.scatter(df, title="Scatter Plot", **xy_kwargs)
    elif plot_type == "Confusion Matrix":
        fig = _confusion_matrix_figure(x_data, y_data, width, height)
    else:
        raise gr.Error("Invalid plot type selected.")

    # Render the figure to PNG bytes and hand Gradio a PIL image.
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))


# Build the Gradio form; the input components are passed positionally to
# generate_plot, so their order must match its parameter order.
app = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        # precision=0 makes Gradio pass whole-pixel ints rather than floats.
        gr.Number(value=800, label="Width", precision=0),
        gr.Number(value=600, label="Height", precision=0),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)

# Launch only when run as a script, so importing this module (tests,
# reload tooling) does not start a web server as a side effect.
if __name__ == "__main__":
    app.launch()