Spaces:
Runtime error
Runtime error
import dataclasses
import enum
import functools
import json
import os
import re
import types
from typing import Callable, Union

import einops
import imageio
import numpy as np
import torch.utils.data
import torchvision
import tqdm

from config import CONFIG
from utils import load_pickle_or_build_object_and_save
class Source(enum.Enum):
    """Origin of an annotated image, as read from the ``source`` field of its annotation."""

    generated = "generated"  # listed first: member iteration order is preserved
    extracted = "extracted"
class ChartType(enum.Enum):
    """Kind of chart an annotation describes.

    Member names equal their values. ``to_token_str`` serialises a member by its
    name, while parsing rebuilds it by value, so the two must stay identical.
    """

    dot = "dot"
    horizontal_bar = "horizontal_bar"
    vertical_bar = "vertical_bar"
    line = "line"
    scatter = "scatter"
@dataclasses.dataclass
class PlotBoundingBox:
    """Axis-aligned rectangle given by its corner ``(x0, y0)`` and its size.

    Built from the ``plot_bb`` entry of an annotation via keyword arguments, so
    the ``@dataclass``-generated ``__init__`` is required.
    """

    height: int
    width: int
    x0: int
    y0: int

    def get_bounds(self):
        """Return the rectangle outline as a closed polyline.

        Returns:
            ``(xs, ys)``: two lists of five coordinates each. They visit the four
            corners and repeat the first one so that the path is closed.
        """
        x_start, x_end = self.x0, self.x0 + self.width
        y_start, y_end = self.y0, self.y0 + self.height
        xs = [x_start, x_end, x_end, x_start, x_start]
        ys = [y_start, y_start, y_end, y_end, y_start]
        return xs, ys
@dataclasses.dataclass
class DataPoint:
    """One ``(x, y)`` entry of an annotation's data series.

    After ``Annotation.__post_init__`` runs, each coordinate is a float on a
    numerical axis and a string on any other axis.
    """

    # ``float or str`` evaluates to plain ``float``; spell the union out explicitly.
    x: Union[float, str]
    y: Union[float, str]
class TextRole(enum.Enum):
    """Role of a text element inside a chart annotation."""

    axis_title = "axis_title"
    chart_title = "chart_title"
    legend_label = "legend_label"
    tick_grouping = "tick_grouping"
    tick_label = "tick_label"
    other = "other"  # catch-all role
@dataclasses.dataclass
class Polygon:
    """Quadrilateral given by four vertices ``(x0, y0)`` .. ``(x3, y3)``.

    Built from the ``polygon`` entry of a text annotation via keyword arguments,
    so the ``@dataclass``-generated ``__init__`` is required.
    """

    x0: int
    x1: int
    x2: int
    x3: int
    y0: int
    y1: int
    y2: int
    y3: int

    def get_bounds(self):
        """Return the outline as a closed polyline ``(xs, ys)``.

        Vertices are listed in order 0..3 followed by vertex 0 again, so each
        list has five entries.
        """
        xs = [self.x0, self.x1, self.x2, self.x3, self.x0]
        ys = [self.y0, self.y1, self.y2, self.y3, self.y0]
        return xs, ys
@dataclasses.dataclass
class Text:
    """A text element of an annotation, with its location and role.

    ``polygon`` and ``role`` arrive as raw annotation values: a dict of vertex
    coordinates and a role value. ``__post_init__``, which only runs because of
    ``@dataclass``, converts them to ``Polygon`` and ``TextRole``.
    """

    id: int  # field name mirrors the annotation key, hence the builtin shadowing
    polygon: Polygon
    role: TextRole
    text: str

    def __post_init__(self):
        self.polygon = Polygon(**self.polygon)
        self.role = TextRole(self.role)
class ValuesType(enum.Enum):
    """How an axis' values are interpreted.

    Numerical values are parsed as floats; categorical ones are kept as strings
    (see ``preprocess_value``).
    """

    categorical = "categorical"
    numerical = "numerical"
@dataclasses.dataclass
class Tick:
    """An axis tick: its annotation id and its ``(x, y)`` position.

    ``Axis.__post_init__`` constructs it by keyword, so the
    ``@dataclass``-generated ``__init__`` is required.
    """

    id: int  # field name mirrors the annotation key
    x: int
    y: int
class TickType(enum.Enum):
    """Kind of ticks an axis annotation uses (the raw ``tick_type`` value)."""

    markers = "markers"
    separators = "separators"
@dataclasses.dataclass
class Axis:
    """One chart axis: how its values are typed and where its ticks are.

    The raw annotation supplies:
      * ``values_type`` and ``tick_type`` as their enum values;
      * ``ticks`` as dicts of the form ``{"id": ..., "tick_pt": {"x": ..., "y": ...}}``.

    ``__post_init__``, which only runs because of ``@dataclass``, converts them
    to ``ValuesType``, ``TickType`` and ``Tick`` objects.
    """

    values_type: ValuesType
    tick_type: TickType
    ticks: list[Tick]

    def __post_init__(self):
        self.values_type = ValuesType(self.values_type)
        self.tick_type = TickType(self.tick_type)
        self.ticks = [
            Tick(id=raw["id"], x=raw["tick_pt"]["x"], y=raw["tick_pt"]["y"])
            for raw in self.ticks
        ]

    def get_bounds(self):
        """Return the bounding box of all tick positions as a closed polyline.

        Returns:
            ``(xs, ys)`` with five entries each: the four corners of the smallest
            axis-aligned box containing every tick, with the first corner repeated.

        Raises:
            ValueError: if the axis has no ticks, since the box is then undefined.
        """
        if not self.ticks:
            raise ValueError("Cannot compute the bounds of an axis without ticks.")
        tick_xs = [tick.x for tick in self.ticks]
        tick_ys = [tick.y for tick in self.ticks]
        min_x, max_x = min(tick_xs), max(tick_xs)
        min_y, max_y = min(tick_ys), max(tick_ys)
        xs = [min_x, max_x, max_x, min_x, min_x]
        ys = [min_y, min_y, max_y, max_y, min_y]
        return xs, ys
def convert_dashes_to_underscores_in_key_names(dictionary):
    """Return a shallow copy of *dictionary* with ``-`` replaced by ``_`` in top-level keys.

    For example ``"x-axis"`` becomes ``"x_axis"``. Values, including nested dicts,
    are left untouched. If two keys collide after renaming, the one that comes
    later in iteration order wins.
    """
    renamed = {}
    for key, value in dictionary.items():
        renamed[key.replace("-", "_")] = value
    return renamed
@dataclasses.dataclass
class Axes:
    """The x and y axes of a chart annotation.

    Both fields arrive as raw dicts whose keys may contain dashes.
    ``__post_init__``, which only runs because of ``@dataclass``, normalises
    those keys and converts each dict to an ``Axis``.
    """

    x_axis: Axis
    y_axis: Axis

    def __post_init__(self):
        self.x_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.x_axis))
        self.y_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.y_axis))
def preprocess_numerical_value(value):
    """Convert a raw numerical annotation value to ``float``, mapping NaN to ``0.0``.

    Args:
        value: anything ``float()`` accepts (number or numeric string).

    Returns:
        A ``float`` in every case, so callers see a single result type.

    Raises:
        ValueError, TypeError: propagated from ``float()`` when *value* cannot be converted.
    """
    number = float(value)
    # ``0.0`` rather than ``0`` keeps the return type float for NaN inputs too.
    return 0.0 if np.isnan(number) else number
def preprocess_value(value, value_type: ValuesType):
    """Normalise one raw data-series value according to its axis type.

    Numerical axes go through ``preprocess_numerical_value``; every other axis
    type is coerced with ``str``.
    """
    if value_type != ValuesType.numerical:
        return str(value)
    return preprocess_numerical_value(value)
@dataclasses.dataclass
class Annotation:
    """Parsed contents of one image's annotation JSON.

    Build it with ``Annotation.from_dict_with_dashes`` (from a raw dict) or with
    ``Annotation.from_image_index``. ``__post_init__``, which only runs because
    of ``@dataclass``, does two things:
      * converts each nested raw value to its typed counterpart;
      * normalises the data-series values according to each axis' ``ValuesType``.
    """

    source: Source
    chart_type: ChartType
    plot_bb: PlotBoundingBox
    text: list[Text]
    axes: Axes
    data_series: list[DataPoint]

    def __post_init__(self):
        self.source = Source(self.source)
        self.chart_type = ChartType(self.chart_type)
        self.plot_bb = PlotBoundingBox(**self.plot_bb)
        self.text = [Text(**kw) for kw in self.text]
        self.axes = Axes(**convert_dashes_to_underscores_in_key_names(self.axes))
        self.data_series = [DataPoint(**kw) for kw in self.data_series]
        # Axis types are only known once ``axes`` is parsed, so values are
        # normalised in a second pass.
        x_values_type = self.axes.x_axis.values_type
        y_values_type = self.axes.y_axis.values_type
        for point in self.data_series:
            point.x = preprocess_value(point.x, x_values_type)
            point.y = preprocess_value(point.y, y_values_type)

    @staticmethod
    def from_dict_with_dashes(kwargs) -> "Annotation":
        """Build an ``Annotation`` from a raw annotation dict.

        The dict's top-level keys may use dashes instead of underscores.
        """
        return Annotation(**convert_dashes_to_underscores_in_key_names(kwargs))

    @staticmethod
    def from_image_index(image_index: int) -> "Annotation":
        """Load and parse the annotation of the training image at position *image_index*.

        The position refers to ``load_train_image_ids()``.
        """
        image_id = load_train_image_ids()[image_index]
        return Annotation.from_dict_with_dashes(load_image_annotation(image_id))

    def get_text_by_role(self, text_role: TextRole) -> list[Text]:
        """Return the text elements whose role equals *text_role*, in their original order."""
        return [t for t in self.text if t.role == text_role]
@dataclasses.dataclass
class AnnotatedImage:
    """A training image together with its parsed annotation.

    ``id`` is the image file name without its ``.jpg`` extension.
    """

    id: str  # field name kept for compatibility; shadows the builtin only as an attribute
    image: np.ndarray
    annotation: Annotation

    @staticmethod
    def from_image_id(image_id: str) -> "AnnotatedImage":
        """Load the image file and the annotation JSON of training image *image_id*."""
        return AnnotatedImage(
            id=image_id,
            image=load_image(image_id),
            annotation=Annotation.from_dict_with_dashes(load_image_annotation(image_id)),
        )

    @staticmethod
    def from_image_index(image_index: int) -> "AnnotatedImage":
        """Load the training image at position *image_index*.

        The position refers to ``load_train_image_ids()``.
        """
        return AnnotatedImage.from_image_id(load_train_image_ids()[image_index])
def generate_annotated_images():
    """Yield an ``AnnotatedImage`` per training image id, showing a progress bar.

    Images are yielded in ``load_train_image_ids()`` order. Each one is loaded
    only when requested, and this generator keeps no reference to items it has
    already yielded.
    """
    # ``import tqdm`` alone does not load the ``autonotebook`` submodule, so
    # import it explicitly before using ``tqdm.autonotebook``.
    import tqdm.autonotebook

    image_ids = load_train_image_ids()
    for image_id in tqdm.autonotebook.tqdm(
        image_ids, desc="Iterating over annotated images"
    ):
        yield AnnotatedImage.from_image_id(image_id)
@functools.lru_cache(maxsize=None)
def _list_image_ids(images_dir: str) -> tuple[str, ...]:
    """Return the ids (``.jpg`` file names without extension) in *images_dir*, sorted.

    Cached because ``load_train_image_ids`` is called for every dataset item.
    """
    return tuple(
        sorted(
            name.removesuffix(".jpg")
            for name in os.listdir(images_dir)
            if name.endswith(".jpg")
        )
    )


def load_train_image_ids() -> list[str]:
    """Return the ids of the training images, sorted by name.

    The position of an id in this list is used as its dataset index. The
    train/val split stores those indices via
    ``load_pickle_or_build_object_and_save``. ``os.listdir`` returns entries in
    arbitrary order, so the ids are sorted to make positions reproducible.

    NOTE: a split file written while ids were unsorted refers to a different
    order and should be regenerated.

    In debug mode (``CONFIG.debug``) only the first 1000 ids are returned.
    """
    train_image_ids = list(_list_image_ids("data/train/images"))  # copy: the cache is shared
    return train_image_ids[: 1000 if CONFIG.debug else None]
def load_test_image_ids() -> list[str]:
    """Return the ids (``.jpg`` file names without extension) of the test images.

    Ids are sorted by name so the result does not depend on the file system's
    listing order. Files other than ``.jpg`` are ignored.
    """
    return sorted(
        name.removesuffix(".jpg")
        for name in os.listdir("data/test/images")
        if name.endswith(".jpg")
    )
def load_image_annotation(image_id: str) -> dict:
    """Read and return the raw JSON annotation of training image *image_id*.

    The file is closed deterministically and decoded as UTF-8 (the JSON
    standard encoding) rather than with the platform's locale encoding.
    """
    with open(
        f"data/train/annotations/{image_id}.json", encoding="utf-8"
    ) as annotation_file:
        return json.load(annotation_file)
def load_image(image_id: str) -> np.ndarray:
    """Decode and return training image *image_id* (``data/train/images/<id>.jpg``).

    The file handle is closed once decoding has finished.
    """
    with open(f"data/train/images/{image_id}.jpg", "rb") as image_file:
        return imageio.v3.imread(image_file)
@dataclasses.dataclass
class DataItem:
    """One dataset sample: image tensor, target string and source index.

    ``image`` must have three dimensions, read as ``(channel, height, width)``,
    with exactly 3 channels. ``__post_init__`` checks this; it only runs because
    of ``@dataclass``.
    """

    image: torch.FloatTensor
    target_string: str
    data_index: int  # position in load_train_image_ids()

    def __post_init__(self):
        shape = einops.parse_shape(self.image, "channel height width")
        assert shape["channel"] == 3, "Image is expected to have 3 channels."
def split_train_indices_by_source():
    """Partition training-image positions by the ``source`` of their annotation.

    Only the annotation JSON is read. The source is all that is needed, so
    decoding every image (as iterating ``AnnotatedImage`` objects would) is
    avoided.

    Returns:
        ``(extracted_image_indices, generated_image_indices)``: two lists of
        positions in ``load_train_image_ids()``, each in increasing order.

    Raises:
        ValueError: if an annotation has an unknown ``source`` value.
    """
    extracted_image_indices = []
    generated_image_indices = []
    for index, image_id in enumerate(load_train_image_ids()):
        source = Source(load_image_annotation(image_id)["source"])
        if source == Source.extracted:
            extracted_image_indices.append(index)
        else:
            generated_image_indices.append(index)
    return extracted_image_indices, generated_image_indices
def get_train_val_split_indices(val_fraction=0.1, seed=42):
    """Split training-image positions into train and validation index arrays.

    Extracted images are preferred for validation. A seeded permutation of them
    fills the validation set first, and generated images top it up to
    ``int(n_images * val_fraction)`` entries. Every remaining image, extracted
    or generated, goes to training.

    Args:
        val_fraction: fraction of all training images placed in validation.
        seed: seed of a local ``np.random.RandomState``. It yields the same
            permutations as seeding numpy's global generator, but leaves the
            global generator untouched.

    Returns:
        ``(train_indices, val_indices)``: disjoint integer arrays of positions in
        ``load_train_image_ids()`` that together cover every position.
    """
    n_images = len(load_train_image_ids())
    val_size = int(n_images * val_fraction)
    extracted_image_indices, generated_image_indices = split_train_indices_by_source()

    rng = np.random.RandomState(seed)
    # Explicit integer dtype: permuting an empty Python list returns float64,
    # which would turn the concatenated indices into floats.
    extracted = rng.permutation(np.asarray(extracted_image_indices, dtype=np.int64))
    generated = rng.permutation(np.asarray(generated_image_indices, dtype=np.int64))

    n_extracted_in_val = min(val_size, len(extracted))
    n_generated_in_val = val_size - n_extracted_in_val
    val_indices = np.concatenate(
        [extracted[:n_extracted_in_val], generated[:n_generated_in_val]]
    )
    # Extracted images beyond ``val_size`` still belong to the training set.
    train_indices = np.concatenate(
        [extracted[n_extracted_in_val:], generated[n_generated_in_val:]]
    )

    assert len(set(train_indices) | set(val_indices)) == n_images
    assert len(val_indices) == val_size
    assert len(set(train_indices) & set(val_indices)) == 0
    return train_indices, val_indices
def to_token_str(value: Union[str, enum.Enum]) -> str:
    """Return *value* formatted as a special token ``"<...>"``.

    Enum members are represented by their ``name``. A string that already looks
    like a token is returned unchanged; that means it starts with ``<``, ends
    with ``>`` and has no newline in between. Anything else is wrapped in angle
    brackets.
    """
    # ``str or enum.Enum`` in an annotation evaluates to plain ``str``; use Union.
    string = value.name if isinstance(value, enum.Enum) else value
    if re.fullmatch("<.*>", string):
        return string
    return f"<{string}>"
def get_extra_tokens() -> types.SimpleNamespace:
    """Collect every special token used in target strings into one namespace.

    Attributes: ``benetech_prompt``, ``benetech_prompt_end``, ``x_start``,
    ``y_start`` and ``value_separator``, plus one attribute per ``ChartType``
    member name and per ``ValuesType`` member name. Each holds its ``<...>``
    token string.
    """
    fixed_tokens = {
        "benetech_prompt": "benetech_prompt",
        "benetech_prompt_end": "/benetech_prompt",
        "x_start": "x_start",
        "y_start": "y_start",
        "value_separator": ";",
    }
    token_ns = types.SimpleNamespace(
        **{attribute: to_token_str(raw) for attribute, raw in fixed_tokens.items()}
    )
    for member in (*ChartType, *ValuesType):
        setattr(token_ns, member.name, to_token_str(member))
    return token_ns
def convert_number_to_scientific_string(value: Union[int, float]) -> str:
    """Format *value* in scientific notation.

    The number of digits after the decimal point is
    ``CONFIG.float_scientific_notation_string_precision``; for example, with
    precision 2, ``1500`` becomes ``"1.50e+03"``.
    """
    # ``int or float`` in an annotation evaluates to plain ``int``; use Union.
    precision = CONFIG.float_scientific_notation_string_precision
    return f"{value:.{precision}e}"
def convert_axis_data_to_string(
    axis_data: list[Union[str, float]], values_type: ValuesType
) -> str:
    """Serialise one axis' values into a single string.

    Numerical values are written in scientific notation via
    ``convert_number_to_scientific_string``. Other values are used as-is and
    must already be strings, because they are passed straight to ``str.join``.
    Values are joined with the ``value_separator`` token from
    ``get_extra_tokens()``. A value that itself contains the separator token
    cannot be recovered by ``convert_string_to_axis_data``.
    """
    if values_type == ValuesType.numerical:
        formatted_axis_data = [convert_number_to_scientific_string(v) for v in axis_data]
    else:
        formatted_axis_data = list(axis_data)
    return get_extra_tokens().value_separator.join(formatted_axis_data)
def convert_string_to_axis_data(string, values_type: ValuesType):
    """Parse a string produced by ``convert_axis_data_to_string`` back into values.

    The string is split on the ``value_separator`` token. On numerical axes each
    piece has all spaces removed and is parsed with ``float``; on other axes the
    pieces stay strings.

    An empty string is the serialisation of an axis with no values, so it
    yields ``[]``. Splitting it would give ``[""]``, which fails ``float()`` on
    numerical axes and adds a spurious empty category otherwise.

    Raises:
        ValueError: if a numerical piece is not a valid float.
    """
    if string == "":
        return []
    data = string.split(get_extra_tokens().value_separator)
    if values_type == ValuesType.numerical:
        return [float(piece.replace(" ", "")) for piece in data]
    return data
@dataclasses.dataclass
class BenetechOutput:
    """Model target or prediction for one chart: chart type plus x and y series.

    ``to_string`` serialises it into one token string laid out as
    ``<benetech_prompt><CHART><x_start><X_TYPE>X_DATA<y_start><Y_TYPE>Y_DATA</benetech_prompt>``.
    ``from_string`` parses it back. ``@dataclass`` is required:
    ``get_string_pattern`` calls ``dataclasses.fields`` on this class.
    """

    chart_type: ChartType
    x_values_type: ValuesType
    y_values_type: ValuesType
    # ``str or float`` would evaluate to plain ``str``; use Union.
    x_data: list[Union[str, float]]
    y_data: list[Union[str, float]]

    def __post_init__(self):
        self.chart_type = ChartType(self.chart_type)
        self.x_values_type = ValuesType(self.x_values_type)
        self.y_values_type = ValuesType(self.y_values_type)
        assert isinstance(self.x_data, list)
        assert isinstance(self.y_data, list)

    def get_main_characteristics(self):
        """Return ``(chart_type, x_values_type, y_values_type, len(x_data), len(y_data))``."""
        return (
            self.chart_type,
            self.x_values_type,
            self.y_values_type,
            len(self.x_data),
            len(self.y_data),
        )

    @staticmethod
    def from_annotation(annotation: Annotation) -> "BenetechOutput":
        """Build the ground-truth output from a parsed ``Annotation``."""
        return BenetechOutput(
            chart_type=annotation.chart_type,
            x_values_type=annotation.axes.x_axis.values_type,
            y_values_type=annotation.axes.y_axis.values_type,
            x_data=[dp.x for dp in annotation.data_series],
            y_data=[dp.y for dp in annotation.data_series],
        )

    def to_string(self) -> str:
        """Serialise this output; numerical series use scientific notation."""
        return self.format_strings(
            chart_type=self.chart_type,
            x_values_type=self.x_values_type,
            y_values_type=self.y_values_type,
            x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),
            y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),
        )

    @staticmethod
    def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data) -> str:
        """Assemble the serialised layout from its five parts.

        ``chart_type`` and the two values types go through ``to_token_str``:
        enum members become ``<name>``, and strings are wrapped unless they are
        already ``<...>``. ``x_data`` and ``y_data`` are inserted verbatim.
        Static because it is also used, with regex placeholders, by
        ``get_string_pattern``.
        """
        chart_type = to_token_str(chart_type)
        x_values_type = to_token_str(x_values_type)
        y_values_type = to_token_str(y_values_type)
        token_ns = get_extra_tokens()
        return (
            f"{token_ns.benetech_prompt}{chart_type}"
            f"{token_ns.x_start}{x_values_type}{x_data}"
            f"{token_ns.y_start}{y_values_type}{y_data}"
            f"{token_ns.benetech_prompt_end}"
        )

    @staticmethod
    def get_string_pattern() -> str:
        """Return a regex matching a serialised output, with one named group per field.

        The placeholders are not tokens, so ``format_strings`` wraps the three
        type groups in ``<...>``. Those groups therefore capture bare member
        names.
        """
        field_names = [field.name for field in dataclasses.fields(BenetechOutput)]
        return BenetechOutput.format_strings(
            **{field_name: f"(?P<{field_name}>.*?)" for field_name in field_names}
        )

    @staticmethod
    def does_string_match_expected_pattern(string) -> bool:
        """Return whether ``from_string`` can parse *string* without error."""
        try:
            BenetechOutput.from_string(string)
        except Exception:  # any parse failure means "no match"; never swallow KeyboardInterrupt
            return False
        return True

    @staticmethod
    def from_string(string) -> "BenetechOutput":
        """Parse a string in the ``to_string`` layout back into an output.

        Raises:
            ValueError: if *string* does not have the expected layout, names an
                unknown chart or values type, or contains an unparsable number.
        """
        fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)
        if fullmatch is None:
            raise ValueError(
                f"String does not match the BenetechOutput pattern: {string!r}"
            )
        benetech_kwargs = fullmatch.groupdict()
        benetech_kwargs["chart_type"] = ChartType(benetech_kwargs["chart_type"])
        benetech_kwargs["x_values_type"] = ValuesType(benetech_kwargs["x_values_type"])
        benetech_kwargs["y_values_type"] = ValuesType(benetech_kwargs["y_values_type"])
        benetech_kwargs["x_data"] = convert_string_to_axis_data(
            benetech_kwargs["x_data"], benetech_kwargs["x_values_type"]
        )
        benetech_kwargs["y_data"] = convert_string_to_axis_data(
            benetech_kwargs["y_data"], benetech_kwargs["y_values_type"]
        )
        return BenetechOutput(**benetech_kwargs)
def get_annotation_ground_truth_str(annotation: Annotation):
    """Return the serialised target string for *annotation*.

    Delegates to ``BenetechOutput.from_annotation`` so the mapping from an
    annotation to an output is defined in one place instead of being duplicated.
    """
    return BenetechOutput.from_annotation(annotation).to_string()
def get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:
    """Return the target string of the training image at position *image_index*.

    The position refers to ``load_train_image_ids()``.
    """
    annotation = Annotation.from_image_index(image_index)
    return get_annotation_ground_truth_str(annotation)
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a subset of the training images.

    ``indices`` holds positions in ``load_train_image_ids()``. Item ``i`` loads
    the image at ``indices[i]`` and pairs it with its ground-truth target string.
    """

    def __init__(self, indices: list[int]):
        super().__init__()
        self.indices = indices
        # Turns the decoded ndarray image into a float tensor.
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx: int) -> DataItem:
        data_index = self.indices[idx]
        annotated_image = AnnotatedImage.from_image_index(data_index)
        return DataItem(
            image=self.to_tensor(annotated_image.image),
            target_string=get_annotation_ground_truth_str(annotated_image.annotation),
            data_index=data_index,
        )
def get_train_val_datasets():
    """Return ``(train_dataset, val_dataset)`` for the configured train/val split.

    The split indices come from ``load_pickle_or_build_object_and_save`` with
    path ``CONFIG.train_val_indices_path``. Judging by the helper's name, it
    loads a cached pickle or builds the split with
    ``get_train_val_split_indices`` and saves it.
    """

    def build_split():
        return get_train_val_split_indices(CONFIG.val_fraction, CONFIG.seed)

    train_indices, val_indices = load_pickle_or_build_object_and_save(
        CONFIG.train_val_indices_path, build_split
    )
    return Dataset(train_indices), Dataset(val_indices)
def get_train_dataset():
    """Return the training half of ``get_train_val_datasets()``."""
    train_dataset, _ = get_train_val_datasets()
    return train_dataset
def get_val_dataset():
    """Return the validation half of ``get_train_val_datasets()``."""
    _, val_dataset = get_train_val_datasets()
    return val_dataset
@dataclasses.dataclass
class Batch:
    """A collated batch of images, label tensors and their dataset positions.

    When ``CONFIG.debug`` is set, ``__post_init__`` checks shape consistency;
    it only runs because of ``@dataclass``. Images must be
    ``(batch, channel, height, width)`` and labels ``(batch, label)``, and both
    must agree with ``len(data_indices)`` on the batch size.
    """

    images: torch.FloatTensor
    labels: torch.IntTensor
    data_indices: list[int]

    def __post_init__(self):
        if CONFIG.debug:
            images_shape = einops.parse_shape(self.images, "batch channel height width")
            labels_shape = einops.parse_shape(self.labels, "batch label")
            assert images_shape["batch"] == labels_shape["batch"]
            assert len(self.data_indices) == images_shape["batch"]
class Split(enum.Enum):
    """Which dataset ``build_dataloader`` serves: training or validation."""

    train = "train"
    val = "val"
# Signature of a collate function. It receives the list of samples plus the
# split being loaded (``build_dataloader`` passes the split as the keyword
# argument ``split``) and returns a ``Batch``.
BatchCollateFunction = Callable[
    [list[DataItem], Split],
    Batch,
]
def build_dataloader(split: Split, batch_collate_function: BatchCollateFunction):
    """Create a ``DataLoader`` for *split*.

    ``Split.train`` gets the shuffled training dataset; any other split gets
    the unshuffled validation dataset. Batches are collated by
    *batch_collate_function*, which receives the split as the keyword argument
    ``split``. Batch size and worker count come from ``CONFIG``.
    """
    is_train = split == Split.train
    dataset = get_train_dataset() if is_train else get_val_dataset()
    collate_fn = functools.partial(batch_collate_function, split=split)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CONFIG.batch_size,
        shuffle=is_train,
        num_workers=CONFIG.num_workers,
        collate_fn=collate_fn,
    )