{
"cells": [
{
"cell_type": "markdown",
"id": "6c8f222f",
"metadata": {},
"source": [
"## [Description](#Description_)\n",
"## [Todo](#Todo_)\n",
"## [Research](#Research_)\n",
"## [Setup](#Setup_)\n",
"### - [Requirements](#Requirements_)\n",
"### - [Imports](#Imports_)\n",
"### - [Globals](#Globals_)\n",
"### - [Utils](#Utils_)\n",
"## [Data](#Data_)\n",
"### - [Annotation structure](#Annotation_structure_)\n",
"### - [Data exploration](#Data_exploration_)\n",
"### - [Data splits](#Data_splits_)\n",
"### - [Expected model output format](#Expected_model_output_format_)\n",
"### - [Metrics](#Metrics_)\n",
"### - [Dataset](#Dataset_)\n",
"## [Model](#Model_)\n",
"### - [Add task specific tokens](#Add_task_specific_tokens_)\n",
"### - [Add dataset specific tokens](#Add_dataset_specific_tokens_)\n",
"### - [Predicting](#Predicting_)\n",
"### - [Dataloader](#Dataloader_)\n",
"### - [Lightning module](#Lightning_module_)\n",
"### - [Callbacks](#Callbacks_)\n",
"## [Training](#Training_)\n",
"## [Results](#Results_)\n",
"### - [Gradio interface](#Gradio_interface_)"
]
},
{
"cell_type": "markdown",
"id": "3b03d3cc",
"metadata": {},
"source": [
"## Description "
]
},
{
"cell_type": "markdown",
"id": "64776daa",
"metadata": {},
"source": [
"Trying my hand at this kaggle challenge:\n",
"\n",
"https://www.kaggle.com/competitions/benetech-making-graphs-accessible"
]
},
{
"cell_type": "markdown",
"id": "82bdf04d",
"metadata": {},
"source": [
"## Todo "
]
},
{
"cell_type": "markdown",
"id": "965d48df",
"metadata": {},
"source": [
"- Add wandb logs: metrics, images, text\n",
"- Create separate training script\n",
"- Train\n",
"- Get familiar with transformers library: main classes, how to work with config\n",
"- Do more research, check out notebooks in kaggle\n",
"- Check out dataset https://chartinfo.github.io/toolsanddata.html\n",
"- Try segmentation -> classification -> parsing pipeline\n",
"- Make predicting faster, check out https://pytorch.org/serve/"
]
},
{
"cell_type": "markdown",
"id": "49fece33",
"metadata": {},
"source": [
"## Research "
]
},
{
"cell_type": "markdown",
"id": "0940fdc8",
"metadata": {},
"source": [
"[Donut](https://arxiv.org/pdf/2111.15664.pdf) - document understanding transformer without the intermediate optical character recognition step.\n",
"[Example notebook one](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb),\n",
"[example notebook two](https://www.kaggle.com/code/nbroad/donut-train-benetech)."
]
},
{
"cell_type": "markdown",
"id": "d9064993",
"metadata": {},
"source": [
"## Setup "
]
},
{
"cell_type": "markdown",
"id": "47af4f6b",
"metadata": {},
"source": [
"### Imports "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8ccdc3b0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-27T13:03:49.524541Z",
"start_time": "2023-04-27T13:03:29.372899Z"
}
},
"outputs": [
],
"source": [
"%load_ext nb_black\n",
"%matplotlib inline\n",
"\n",
"\n",
"import collections\n",
"import dataclasses\n",
"import datasets\n",
"import einops\n",
"import enum\n",
"import gradio\n",
"import glob\n",
"import IPython\n",
"import imageio\n",
"import json\n",
"import functools\n",
"import matplotlib.animation\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import PIL\n",
"import pandas as pd\n",
"import pickle\n",
"import pprint\n",
"import pytorch_lightning as pl\n",
"import rapidfuzz\n",
"import re\n",
"import reprlib\n",
"import sklearn.metrics\n",
"import torch\n",
"import torchvision\n",
"import tqdm.autonotebook\n",
"import transformers\n",
"import types\n",
"from typing import Callable, Literal\n",
"import wandb"
]
},
{
"cell_type": "markdown",
"id": "2b711a53",
"metadata": {},
"source": [
"### Requirements"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad8e0f9f",
"metadata": {},
"outputs": [],
"source": [
"def pip_freeze_requirements():\n",
" !pip freeze > requirements.txt\n",
" \n",
"#pip_freeze_requirements()"
]
},
{
"cell_type": "markdown",
"id": "77b39d61",
"metadata": {},
"source": [
"### Globals "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "db1722f2",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.754713Z",
"start_time": "2023-04-18T15:47:30.740063Z"
}
},
"outputs": [
],
"source": [
"COMPETITION = \"benetech-making-graphs-accessible\"\n",
"DEBUG: bool = False\n",
"DATA = types.SimpleNamespace()\n",
"TOKEN = types.SimpleNamespace()\n",
"CONFIG = types.SimpleNamespace()\n",
"MODEL = types.SimpleNamespace()\n",
"TRAINING = types.SimpleNamespace()"
]
},
{
"cell_type": "markdown",
"id": "52ea33de",
"metadata": {},
"source": [
"### Markdown"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c2aefef2",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.801463Z",
"start_time": "2023-04-18T15:47:30.758086Z"
}
},
"outputs": [
],
"source": [
"def make_new_markdown_section_with_link(section, header=\"##\", do_print=True):\n",
" section_id = section.replace(\" \", \"_\") + \"_\"\n",
" section_link = f\"{header} [{section}](#{section_id})\"\n",
" section_header = f\"{header} {section} \"\n",
" if do_print:\n",
" print(section_link + \"\\n\" + section_header)\n",
" return section_link, section_header\n",
"\n",
"\n",
"def make_several_sections(\n",
" section_names=(\n",
" \"Description\",\n",
" \"Imports\",\n",
" \"Globals\",\n",
" \"Setup\",\n",
" \"Data\",\n",
" \"Data exploration\",\n",
" \"Model\",\n",
" \"Training\",\n",
" \"Results\",\n",
" )\n",
"):\n",
" links, headers = zip(\n",
" *[\n",
" make_new_markdown_section_with_link(sn, do_print=False)\n",
" for sn in section_names\n",
" ]\n",
" )\n",
" print(\"\\n\".join(links + (\"\",) + headers))\n"
]
},
{
"cell_type": "markdown",
"id": "bf4ed747",
"metadata": {},
"source": [
"### Terminal"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1e7c72a6",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.847062Z",
"start_time": "2023-04-18T15:47:30.804015Z"
}
},
"outputs": [
],
"source": [
"def mkdir(path, error_if_exists=False):\n",
" !mkdir {\"-p\" if not error_if_exists else \"\"} {path}\n",
"\n",
"\n",
"def unzip(zip_path, save_path=None, delete_zip=False):\n",
" !unzip {zip_path} {\"-d \"+ save_path if save_path else \"\"}\n",
" if delete_zip:\n",
" for path in glob.glob(zip_path):\n",
" if path.endswith(\".zip\"):\n",
" !trash {path}\n",
"\n",
"\n",
"def unzip_to_data_and_delete():\n",
" unzip(\"data/*\", \"data\", delete_zip=True)"
]
},
{
"cell_type": "markdown",
"id": "0fb17c9d",
"metadata": {},
"source": [
"### Kaggle"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "aae473b0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.868185Z",
"start_time": "2023-04-18T15:47:30.851313Z"
}
},
"outputs": [
],
"source": [
"def kaggle_competitions_search(search_term):\n",
" !kaggle competitions list -s {search_term}\n",
"\n",
"\n",
"def kaggle_competitions_files(competition):\n",
" !kaggle competitions files {competition}\n",
"\n",
"\n",
"def kaggle_competitions_download(competition, save_path=\"data\", filename=None):\n",
" mkdir(save_path)\n",
" !kaggle competitions download -p {save_path} {\"-f \" + filename if filename else \"\"} {competition}\n",
"\n",
"\n",
"def kaggle_competitions_submit(competition, filename, message=\"submit\"):\n",
" !kaggle competitions submit -f {filename} -m {message} {competition}\n",
"\n",
"\n",
"def kaggle_competitions_submissions(competition):\n",
" !kaggle competitions submissions {competition}"
]
},
{
"cell_type": "markdown",
"id": "0fdfe95e",
"metadata": {},
"source": [
"### Gpu server"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f5ba27be",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.967413Z",
"start_time": "2023-04-18T15:47:30.909020Z"
}
},
"outputs": [
],
"source": [
"def get_shad_server_username_and_telegram_id_pairs(\n",
" copy_pasted_table: str or None = None,\n",
") -> list[str, str]:\n",
" table_url = \"https://docs.google.com/spreadsheets/u/1/d/e/2PACX-1vRNGT6OeI7zKVFzYPoqmTPh1jCfeVjRLSvFziVgRleyFTOHi1GU39ERo_UixTGcgydG7QcurnSmHgSW/pubhtml?gid=1404550339&single=true\"\n",
"\n",
" if copy_pasted_table is not None:\n",
" table = copy_pasted_table\n",
" else:\n",
" home = os.environ[\"HOME\"]\n",
" table = open(f\"{home}/shad_server_username_to_telegram.txt\").read()\n",
"\n",
" shad_server_username_and_telegram_id_pairs = []\n",
" for row in table.splitlines():\n",
" if row.count(\"\\t\") == 0:\n",
" continue\n",
" cols = row.split(\"\\t\")\n",
" shad_server_username = cols[-2]\n",
" telegram_id = cols[-1]\n",
" shad_server_username_and_telegram_id_pairs.append(\n",
" (shad_server_username, telegram_id)\n",
" )\n",
"\n",
" return shad_server_username_and_telegram_id_pairs\n",
"\n",
"\n",
"def get_nvidia_smi_pid_column():\n",
" nvidia_smi_pid_column = !nvidia-smi | awk '{print $5}'\n",
" return nvidia_smi_pid_column\n",
"\n",
"\n",
"def get_pid_username(pid: int) -> str:\n",
" username = !ps -o uname= -p {pid}\n",
" return username[0]\n",
"\n",
"\n",
"def get_usernames_using_gpu() -> list[str]:\n",
" nvidia_smi_pid_column = get_nvidia_smi_pid_column()\n",
" pids_using_gpu = []\n",
" for row in nvidia_smi_pid_column[::-1]:\n",
" if row == \"PID\":\n",
" break\n",
" try:\n",
" pid = int(row)\n",
" except ValueError:\n",
" continue\n",
" pids_using_gpu.append(int(pid))\n",
"\n",
" usernames_using_gpu = [get_pid_username(pid) for pid in pids_using_gpu]\n",
" usernames_using_gpu = list(set(usernames_using_gpu))\n",
" return usernames_using_gpu\n",
"\n",
"\n",
"def print_telegram_usernames_using_gpu(table: str or None = None):\n",
" server_to_telegram = dict(get_shad_server_username_and_telegram_id_pairs(table))\n",
" usernames_using_gpu = get_usernames_using_gpu()\n",
"\n",
" telegram_usernames_using_gpu = []\n",
" server_usernames_with_unknown_telegram_id = []\n",
" for username in usernames_using_gpu:\n",
" if username in server_to_telegram:\n",
" telegram_usernames_using_gpu.append(server_to_telegram[username])\n",
" else:\n",
" server_usernames_with_unknown_telegram_id.append(username)\n",
"\n",
" print(\"Telegram id of users using gpu:\")\n",
" print(\"\\n\".join(telegram_usernames_using_gpu))\n",
"\n",
" if server_usernames_with_unknown_telegram_id:\n",
" print(\"Telegram id is unknown for users:\")\n",
" print(\"\\n\".join(server_usernames_with_unknown_telegram_id))"
]
},
{
"cell_type": "markdown",
"id": "a5626f18",
"metadata": {},
"source": [
"### Environment variables "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e496647d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.899947Z",
"start_time": "2023-04-18T15:47:30.872176Z"
}
},
"outputs": [
],
"source": [
"def set_tokenizers_parallelism(enable: bool):\n",
" os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\" if enable else \"false\"\n",
"\n",
"\n",
"def set_torch_device_order_pci_bus():\n",
" os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
"\n",
"\n",
"set_tokenizers_parallelism(False)\n",
"set_torch_device_order_pci_bus()"
]
},
{
"cell_type": "markdown",
"id": "202c992a",
"metadata": {},
"source": [
"### Utils "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "7a52ce27",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-25T19:29:59.305379Z",
"start_time": "2023-04-25T19:29:59.169804Z"
}
},
"outputs": [
],
"source": [
"def path_to_dict(path, print_only_last_dirname=False):\n",
" dirpath, dirnames, filenames = next(os.walk(path))\n",
" path_contents = filenames\n",
"\n",
" for dirname in dirnames:\n",
" full_dirname = os.path.join(path, dirname)\n",
" path_contents.append(path_to_dict(full_dirname, print_only_last_dirname=True))\n",
"\n",
" if print_only_last_dirname:\n",
" path = os.path.split(path)[-1]\n",
"\n",
" return {path: path_contents}\n",
"\n",
"\n",
"def pprint_path_contents(path):\n",
" path_dict = path_to_dict(path)\n",
" short_path_repr = reprlib.repr(path_dict)\n",
" short_path_dict = eval(short_path_repr)\n",
" string = pprint.pformat(short_path_dict).replace(\"Ellipsis\", \"...\")\n",
" print(string)\n",
"\n",
"\n",
"def load_pickle_or_build_object_and_save(\n",
" pickle_path: str, build_object: Callable[[], \"T\"]\n",
") -> \"T\":\n",
" if not os.path.exists(pickle_path):\n",
" pickle.dump(build_object(), open(pickle_path, \"wb\"))\n",
" else:\n",
" print(f\"Reusing object {pickle_path}.\")\n",
" return pickle.load(open(pickle_path, \"rb\"))"
]
},
{
"cell_type": "markdown",
"id": "cdf2b470",
"metadata": {},
"source": [
"## Data "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "098e77ae",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:30.981098Z",
"start_time": "2023-04-18T15:47:30.971522Z"
}
},
"outputs": [
],
"source": [
"if not os.path.exists(\"data\"):\n",
"    kaggle_competitions_download(COMPETITION)\n",
"    unzip_to_data_and_delete()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1c7232a4",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:31.219671Z",
"start_time": "2023-04-18T15:47:31.028004Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'data': ['sample_submission.csv',\n",
"          {'train': [{'images': ['52ecbd029a07.jpg',\n",
"                                 'fd7e3f0e4d43.jpg',\n",
"                                 'f0122da6cbe1.jpg',\n",
"                                 '2a186a0fa1ae.jpg',\n",
"                                 '6559c7a7d153.jpg',\n",
"                                 '5fd880333d07.jpg',\n",
"                                 ...]},\n",
"                     {'annotations': ['0f4f52fc3f4b.json',\n",
"                                      '35f0ec146509.json',\n",
"                                      '2e374a37e404.json',\n",
"                                      '96578b79c571.json',\n",
"                                      'dfbd6e21c301.json',\n",
"                                      '0893be463049.json',\n",
"                                      ...]}]},\n",
"          {'test': [{'images': ['000b92c3b098.jpg',\n",
"                                '01b45b831589.jpg',\n",
"                                '00dcf883a459.jpg',\n",
"                                '007a18eb4e09.jpg',\n",
"                                '00f5404753cf.jpg']}]}]}\n"
]
}
],
"source": [
"pprint_path_contents(\"data\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c0a85e8a",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-27T13:04:21.517594Z",
"start_time": "2023-04-27T13:04:21.491793Z"
}
},
"outputs": [
],
"source": [
"@functools.cache\n",
"def load_train_image_ids() -> list[str]:\n",
"    train_image_ids = [i.replace(\".jpg\", \"\") for i in os.listdir(\"data/train/images\")]\n",
"    # In DEBUG mode only the first 1000 listed images are used.\n",
"    return train_image_ids[: 1000 if DEBUG else None]\n",
"\n",
"\n",
"@functools.cache\n",
"def load_test_image_ids() -> list[str]:\n",
"    return [i.replace(\".jpg\", \"\") for i in os.listdir(\"data/test/images\")]\n",
"\n",
"\n",
"@functools.cache\n",
"def load_image_annotation(image_id: str) -> dict:\n",
"    # Use a context manager so the file handle is closed after reading.\n",
"    with open(f\"data/train/annotations/{image_id}.json\") as file:\n",
"        return json.load(file)\n",
"\n",
"\n",
"def load_image(image_id: str) -> np.ndarray:\n",
"    with open(f\"data/train/images/{image_id}.jpg\", \"rb\") as file:\n",
"        return imageio.v3.imread(file)"
]
},
{
"cell_type": "markdown",
"id": "e6e7d333",
"metadata": {},
"source": [
"### Annotation structure "
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1e98517b",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:31.349287Z",
"start_time": "2023-04-18T15:47:31.250789Z"
}
},
"outputs": [
],
"source": [
"class Source(enum.Enum):\n",
"    generated = \"generated\"\n",
"    extracted = \"extracted\"\n",
"\n",
"\n",
"class ChartType(enum.Enum):\n",
"    dot = \"dot\"\n",
"    horizontal_bar = \"horizontal_bar\"\n",
"    vertical_bar = \"vertical_bar\"\n",
"    line = \"line\"\n",
"    scatter = \"scatter\"\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class PlotBoundingBox:\n",
"    height: int\n",
"    width: int\n",
"    x0: int\n",
"    y0: int\n",
"\n",
"    def get_bounds(self):\n",
"        # Closed rectangle outline: the first corner is repeated at the end.\n",
"        xs = [self.x0, self.x0 + self.width, self.x0 + self.width, self.x0, self.x0]\n",
"        ys = [self.y0, self.y0, self.y0 + self.height, self.y0 + self.height, self.y0]\n",
"        return xs, ys\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class DataPoint:\n",
"    # `float or str` evaluates to plain `float`; the string annotation records the intended union.\n",
"    x: \"float | str\"\n",
"    y: \"float | str\"\n",
"\n",
"\n",
"class TextRole(enum.Enum):\n",
"    axis_title = \"axis_title\"\n",
"    chart_title = \"chart_title\"\n",
"    legend_label = \"legend_label\"\n",
"    tick_grouping = \"tick_grouping\"\n",
"    tick_label = \"tick_label\"\n",
"    other = \"other\"\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Polygon:\n",
"    x0: int\n",
"    x1: int\n",
"    x2: int\n",
"    x3: int\n",
"    y0: int\n",
"    y1: int\n",
"    y2: int\n",
"    y3: int\n",
"\n",
"    def get_bounds(self):\n",
"        xs = [\n",
"            self.x0,\n",
"            self.x1,\n",
"            self.x2,\n",
"            self.x3,\n",
"            self.x0,\n",
"        ]\n",
"        ys = [\n",
"            self.y0,\n",
"            self.y1,\n",
"            self.y2,\n",
"            self.y3,\n",
"            self.y0,\n",
"        ]\n",
"        return xs, ys\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Text:\n",
"    id: int\n",
"    polygon: Polygon\n",
"    role: TextRole\n",
"    text: str\n",
"\n",
"    def __post_init__(self):\n",
"        # Fields arrive as raw JSON values and are converted to typed objects here.\n",
"        self.polygon = Polygon(**self.polygon)\n",
"        self.role = TextRole(self.role)\n",
"\n",
"\n",
"class ValuesType(enum.Enum):\n",
"    categorical = \"categorical\"\n",
"    numerical = \"numerical\"\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Tick:\n",
"    id: int\n",
"    x: int\n",
"    y: int\n",
"\n",
"\n",
"class TickType(enum.Enum):\n",
"    markers = \"markers\"\n",
"    separators = \"separators\"\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Axis:\n",
"    values_type: ValuesType\n",
"    tick_type: TickType\n",
"    ticks: list[Tick]\n",
"\n",
"    def __post_init__(self):\n",
"        self.values_type = ValuesType(self.values_type)\n",
"        self.tick_type = TickType(self.tick_type)\n",
"        self.ticks = [\n",
"            Tick(id=kw[\"id\"], x=kw[\"tick_pt\"][\"x\"], y=kw[\"tick_pt\"][\"y\"])\n",
"            for kw in self.ticks\n",
"        ]\n",
"\n",
"    def get_bounds(self):\n",
"        min_x = min(tick.x for tick in self.ticks)\n",
"        max_x = max(tick.x for tick in self.ticks)\n",
"        min_y = min(tick.y for tick in self.ticks)\n",
"        max_y = max(tick.y for tick in self.ticks)\n",
"        xs = [min_x, max_x, max_x, min_x, min_x]\n",
"        ys = [min_y, min_y, max_y, max_y, min_y]\n",
"        return xs, ys\n",
"\n",
"\n",
"def convert_dashes_to_underscores_in_key_names(dictionary):\n",
"    return {k.replace(\"-\", \"_\"): v for k, v in dictionary.items()}\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Axes:\n",
"    x_axis: Axis\n",
"    y_axis: Axis\n",
"\n",
"    def __post_init__(self):\n",
"        self.x_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.x_axis))\n",
"        self.y_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.y_axis))\n",
"\n",
"\n",
"def preprocess_numerical_value(value):\n",
"    # NaN values are replaced with 0.\n",
"    value = float(value)\n",
"    value = 0 if np.isnan(value) else value\n",
"    return value\n",
"\n",
"\n",
"def preprocess_value(value, value_type: ValuesType):\n",
"    if value_type == ValuesType.numerical:\n",
"        return preprocess_numerical_value(value)\n",
"    else:\n",
"        return str(value)\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class Annotation:\n",
"    source: Source\n",
"    chart_type: ChartType\n",
"    plot_bb: PlotBoundingBox\n",
"    text: list[Text]\n",
"    axes: Axes\n",
"    data_series: list[DataPoint]\n",
"\n",
"    def __post_init__(self):\n",
"        self.source = Source(self.source)\n",
"        self.chart_type = ChartType(self.chart_type)\n",
"        self.plot_bb = PlotBoundingBox(**self.plot_bb)\n",
"        self.text = [Text(**kw) for kw in self.text]\n",
"        self.axes = Axes(**convert_dashes_to_underscores_in_key_names(self.axes))\n",
"        self.data_series = [DataPoint(**kw) for kw in self.data_series]\n",
"\n",
"        # Cast each data point to float or str according to its axis values type.\n",
"        for i in range(len(self.data_series)):\n",
"            self.data_series[i].x = preprocess_value(\n",
"                self.data_series[i].x, self.axes.x_axis.values_type\n",
"            )\n",
"            self.data_series[i].y = preprocess_value(\n",
"                self.data_series[i].y, self.axes.y_axis.values_type\n",
"            )\n",
"\n",
"    @staticmethod\n",
"    def from_dict_with_dashes(kwargs):\n",
"        return Annotation(**convert_dashes_to_underscores_in_key_names(kwargs))\n",
"\n",
"    @staticmethod\n",
"    def from_image_index(image_index: int):\n",
"        image_id = load_train_image_ids()[image_index]\n",
"        return Annotation.from_dict_with_dashes(load_image_annotation(image_id))\n",
"\n",
"    def get_text_by_role(self, text_role: TextRole) -> list[Text]:\n",
"        return [t for t in self.text if t.role == text_role]\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class AnnotatedImage:\n",
"    id: str\n",
"    image: np.ndarray\n",
"    annotation: Annotation\n",
"\n",
"    @staticmethod\n",
"    def from_image_id(image_id: str):\n",
"        return AnnotatedImage(\n",
"            id=image_id,\n",
"            image=load_image(image_id),\n",
"            annotation=Annotation.from_dict_with_dashes(\n",
"                load_image_annotation(image_id)\n",
"            ),\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def from_image_index(image_index: int):\n",
"        return AnnotatedImage.from_image_id(load_train_image_ids()[image_index])\n",
"\n",
"\n",
"def generate_annotated_images():\n",
"    for image_id in tqdm.autonotebook.tqdm(\n",
"        load_train_image_ids(), \"Iterating over annotated images\"\n",
"    ):\n",
"        yield AnnotatedImage.from_image_id(image_id)"
]
},
{
"cell_type": "markdown",
"id": "dad819b2",
"metadata": {},
"source": [
"### Data exploration "
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f165119d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:31.364012Z",
"start_time": "2023-04-18T15:47:31.352168Z"
}
},
"outputs": [
],
"source": [
"def are_there_nan_values_in_axis_data():\n",
"    # Annotation.__post_init__ already maps NaN to 0 on numerical axes,\n",
"    # so this checks the preprocessed values, not the raw JSON.\n",
"    for annotated_image in generate_annotated_images():\n",
"        for datapoint in annotated_image.annotation.data_series:\n",
"            for value in [datapoint.x, datapoint.y]:\n",
"                if not isinstance(value, str) and np.isnan(value):\n",
"                    return True\n",
"    return False"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3ff0494b",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:31.396949Z",
"start_time": "2023-04-18T15:47:31.376901Z"
}
},
"outputs": [
],
"source": [
"if DEBUG:\n",
"    print(are_there_nan_values_in_axis_data())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "21b4baa0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:31.426840Z",
"start_time": "2023-04-18T15:47:31.399796Z"
}
},
"outputs": [
],
"source": [
"def get_image(image_index: int) -> np.ndarray:\n",
"    return load_image(load_train_image_ids()[image_index])\n",
"\n",
"\n",
"def build_random_image_animation(n_images=100, fps=1, figsize=(6, 4)):\n",
"    image_indices = np.random.permutation(len(load_train_image_ids()))[:n_images]\n",
"    first_image = get_image(image_indices[0])\n",
"\n",
"    fig, ax = plt.subplots(figsize=figsize)\n",
"    frame = plt.imshow(first_image)\n",
"    plt.axis(\"off\")\n",
"    plt.close()\n",
"\n",
"    def animate(frame_index):\n",
"        image_index = image_indices[frame_index]\n",
"        image = get_image(image_index)\n",
"        frame.set_data(image)\n",
"\n",
"    return matplotlib.animation.FuncAnimation(\n",
"        fig=fig,\n",
"        func=animate,\n",
"        frames=len(image_indices),\n",
"        interval=int(1000 / fps),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0d592d35",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:38.818101Z",
"start_time": "2023-04-18T15:47:31.431284Z"
}
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"IPython.display.HTML(build_random_image_animation().to_html5_video())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "edf90004",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:38.868611Z",
"start_time": "2023-04-18T15:47:38.832024Z"
}
},
"outputs": [
],
"source": [
"def visualize_image_stats(figsize=(12, 8)):\n",
"    image_shapes = [ai.image.shape for ai in generate_annotated_images()]\n",
"\n",
"    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=figsize)\n",
"\n",
"    height, width, channel = zip(*image_shapes)\n",
"\n",
"    IPython.display.display(\n",
"        pd.DataFrame(dict(width=width, height=height, channel=channel)).describe()\n",
"    )\n",
"\n",
"    plt.sca(axes[0][0])\n",
"    plt.title(\"Image shapes\")\n",
"    plt.xlabel(\"Width\")\n",
"    plt.ylabel(\"Height\")\n",
"    plt.scatter(\n",
"        width,\n",
"        height,\n",
"        marker=\".\",\n",
"        alpha=0.3,\n",
"    )\n",
"    plt.grid()\n",
"\n",
"    plt.sca(axes[0][1])\n",
"    plt.title(\"Width\")\n",
"    plt.hist(width, bins=50)\n",
"    plt.grid()\n",
"\n",
"    plt.sca(axes[1][0])\n",
"    plt.title(\"Height\")\n",
"    plt.hist(height, bins=50)\n",
"    plt.grid()\n",
"\n",
"    plt.sca(axes[1][1])\n",
"    plt.axis(\"off\")\n",
"\n",
"    plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f385dc34",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:38.879630Z",
"start_time": "2023-04-18T15:47:38.875047Z"
}
},
"outputs": [
],
"source": [
"if DEBUG:\n",
" visualize_image_stats()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c068b2ac",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:38.900221Z",
"start_time": "2023-04-18T15:47:38.881375Z"
}
   },
   "outputs": [],
"source": [
"CONFIG.image_width = 720\n",
"CONFIG.image_height = 512"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "24f7f000",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:38.943528Z",
"start_time": "2023-04-18T15:47:38.902282Z"
}
   },
   "outputs": [],
   "source": [
    "def plot_image_with_annotations(image_id: str, show_categorical_data=True):\n",
    "    annotated_image = AnnotatedImage.from_image_id(image_id)\n",
    "    annotation = annotated_image.annotation\n",
    "    image = annotated_image.image\n",
    "    plt.subplots(figsize=(8, 6))\n",
    "    plt.imshow(image)\n",
    "\n",
    "    if show_categorical_data:\n",
    "        IPython.display.display(\n",
    "            pd.Series(\n",
    "                dict(\n",
    "                    source=annotation.source.value,\n",
    "                    chart_type=annotation.chart_type.value,\n",
    "                    x_values_type=annotation.axes.x_axis.values_type.value,\n",
    "                    y_values_type=annotation.axes.y_axis.values_type.value,\n",
    "                    x_tick_type=annotation.axes.x_axis.tick_type.value,\n",
    "                    y_tick_type=annotation.axes.y_axis.tick_type.value,\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "\n",
    "    plt.plot(*annotation.plot_bb.get_bounds(), c=\"red\", label=\"bounding_box\")\n",
    "\n",
    "    plt.scatter(\n",
    "        *list(zip(*[[tick.x, tick.y] for tick in annotation.axes.x_axis.ticks])),\n",
    "        label=\"x_ticks\"\n",
    "    )\n",
    "    plt.scatter(\n",
    "        *list(zip(*[[tick.x, tick.y] for tick in annotation.axes.y_axis.ticks])),\n",
    "        label=\"y_ticks\"\n",
    "    )\n",
    "\n",
    "    text_role_colors = dict(zip(TextRole, plt.cm.Accent.colors))\n",
    "    seen_roles = set()\n",
    "    for i, text in enumerate(annotation.text):\n",
    "        xs = [\n",
    "            text.polygon.x0,\n",
    "            text.polygon.x1,\n",
    "            text.polygon.x2,\n",
    "            text.polygon.x3,\n",
    "            text.polygon.x0,\n",
    "        ]\n",
    "        ys = [\n",
    "            text.polygon.y0,\n",
    "            text.polygon.y1,\n",
    "            text.polygon.y2,\n",
    "            text.polygon.y3,\n",
    "            text.polygon.y0,\n",
    "        ]\n",
    "        plt.plot(\n",
    "            xs,\n",
    "            ys,\n",
    "            c=text_role_colors[text.role],\n",
    "            label=text.role.value if text.role not in seen_roles else None,\n",
    "        )\n",
    "        seen_roles.add(text.role)\n",
    "\n",
    "    plt.legend(bbox_to_anchor=(1.04, 1), loc=\"upper left\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a54cc20e",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.579100Z",
"start_time": "2023-04-18T15:47:38.949939Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"source generated\n",
"chart_type line\n",
"x_values_type categorical\n",
"y_values_type numerical\n",
"x_tick_type markers\n",
"y_tick_type markers\n",
"dtype: object"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
    }
],
"source": [
"plot_image_with_annotations(np.random.choice(load_train_image_ids()))"
]
},
{
"cell_type": "markdown",
"id": "88ae66a0",
"metadata": {},
"source": [
"### Data splits "
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7b2e2e49",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.606668Z",
"start_time": "2023-04-18T15:47:39.581401Z"
}
   },
   "outputs": [],
   "source": [
    "def split_train_indices_by_source():\n",
    "    extracted_image_indices = []\n",
    "    generated_image_indices = []\n",
    "    for i, annotated_image in enumerate(generate_annotated_images()):\n",
    "        if annotated_image.annotation.source == Source.extracted:\n",
    "            extracted_image_indices.append(i)\n",
    "        else:\n",
    "            generated_image_indices.append(i)\n",
    "    return extracted_image_indices, generated_image_indices\n",
    "\n",
    "\n",
    "def get_train_val_split_indices(val_fraction=0.1, seed=42):\n",
    "    # Seed from the argument so that passing a different seed changes the split.\n",
    "    np.random.seed(seed)\n",
    "    val_size = int(len(load_train_image_ids()) * val_fraction)\n",
    "\n",
    "    extracted_image_indices, generated_image_indices = split_train_indices_by_source()\n",
    "    extracted_image_indices = np.random.permutation(extracted_image_indices)\n",
    "    generated_image_indices = np.random.permutation(generated_image_indices)\n",
    "\n",
    "    # Validation takes extracted images first; generated images fill any remainder.\n",
    "    val_indices = extracted_image_indices[:val_size]\n",
    "    n_generated_images_in_val = val_size - len(val_indices)\n",
    "    val_indices = np.concatenate(\n",
    "        [val_indices, generated_image_indices[:n_generated_images_in_val]]\n",
    "    )\n",
    "\n",
    "    train_indices = generated_image_indices[n_generated_images_in_val:]\n",
    "\n",
    "    assert len(set(train_indices) | set(val_indices)) == len(load_train_image_ids())\n",
    "    assert len(val_indices) == val_size\n",
    "    assert len(set(train_indices) & set(val_indices)) == 0\n",
    "\n",
    "    return train_indices, val_indices"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "5ae948ff",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.648756Z",
"start_time": "2023-04-18T15:47:39.608585Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reusing split indices.\n"
]
    }
],
"source": [
"CONFIG.val_fraction = 0.1\n",
"CONFIG.seed = 42\n",
"CONFIG.train_val_indices_path = \"data/train_val_indices.pickle\"\n",
"\n",
"DATA.train_indices, DATA.val_indices = load_pickle_or_build_object_and_save(\n",
" CONFIG.train_val_indices_path,\n",
" lambda : get_train_val_split_indices(CONFIG.val_fraction, CONFIG.seed)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2a8711a2",
"metadata": {},
"source": [
"### Expected model output format "
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "52e5fc7e",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.678486Z",
"start_time": "2023-04-18T15:47:39.650465Z"
}
},
"outputs": [
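    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>data_series</th>\n",
       "      <th>chart_type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>000b92c3b098_x</td>\n",
       "      <td>abc;def</td>\n",
       "      <td>vertical_bar</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000b92c3b098_y</td>\n",
       "      <td>0.0;1.0</td>\n",
       "      <td>vertical_bar</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>007a18eb4e09_x</td>\n",
       "      <td>abc;def</td>\n",
       "      <td>vertical_bar</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>007a18eb4e09_y</td>\n",
       "      <td>0.0;1.0</td>\n",
       "      <td>vertical_bar</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],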
"text/plain": [
" id data_series chart_type\n",
"0 000b92c3b098_x abc;def vertical_bar\n",
"1 000b92c3b098_y 0.0;1.0 vertical_bar\n",
"2 007a18eb4e09_x abc;def vertical_bar\n",
"3 007a18eb4e09_y 0.0;1.0 vertical_bar"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
    }
],
"source": [
"pd.read_csv(\"data/sample_submission.csv\").head(4)"
]
},
  {
   "cell_type": "markdown",
   "id": "4be2fa0d",
   "metadata": {},
   "source": [
    "The Benetech competition requires predicting each chart's type and the data series along its x and y axes, so I define dedicated tokens for them here and add them to the transformer later."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "6d209989",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.708025Z",
"start_time": "2023-04-18T15:47:39.680130Z"
}
   },
   "outputs": [],
   "source": [
    "def to_token_str(value: str or enum.Enum):\n",
    "    string = value.name if isinstance(value, enum.Enum) else value\n",
    "    if re.fullmatch(\"<.*>\", string):\n",
    "        return string\n",
    "    else:\n",
    "        return f\"<{string}>\"\n",
"\n",
"\n",
"TOKEN.benetech_prompt = to_token_str(\"benetech_prompt\")\n",
"TOKEN.benetech_prompt_end = to_token_str(\"/benetech_prompt\")\n",
"\n",
"for chart_type in ChartType:\n",
" setattr(TOKEN, chart_type.name, to_token_str(chart_type))\n",
"\n",
"for values_type in ValuesType:\n",
" setattr(TOKEN, values_type.name, to_token_str(values_type))\n",
"\n",
"TOKEN.x_start = to_token_str(\"x_start\")\n",
"TOKEN.y_start = to_token_str(\"y_start\")\n",
"TOKEN.value_separator = to_token_str(\";\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "6a100c8e",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.743966Z",
"start_time": "2023-04-18T15:47:39.722826Z"
}
   },
   "outputs": [],
   "source": [
    "CONFIG.float_scientific_notation_string_precision = 5\n",
    "\n",
    "\n",
    "def convert_number_to_scientific_string(value: int or float) -> str:\n",
    "    return f\"{value:.{CONFIG.float_scientific_notation_string_precision}e}\"\n",
    "\n",
    "\n",
    "def convert_axis_data_to_string(\n",
    "    axis_data: list[str or float], values_type: ValuesType\n",
    ") -> str:\n",
    "    formatted_axis_data = []\n",
    "    for value in axis_data:\n",
    "        if values_type == ValuesType.numerical:\n",
    "            value = convert_number_to_scientific_string(value)\n",
    "        formatted_axis_data.append(value)\n",
    "    return TOKEN.value_separator.join(formatted_axis_data)\n",
    "\n",
    "\n",
    "def convert_string_to_axis_data(string, values_type: ValuesType):\n",
    "    data = string.split(TOKEN.value_separator)\n",
    "    if values_type == ValuesType.numerical:\n",
    "        data = [float(i) for i in data]\n",
    "    return data\n",
    "\n",
    "\n",
    "def compute_numeric_data_loss_due_to_string_conversion():\n",
    "    squared_error = 0\n",
    "    n_numeric_values = 0\n",
    "    for annotated_image in generate_annotated_images():\n",
    "        annotation = annotated_image.annotation\n",
    "        for axis, data in zip(\n",
    "            [annotation.axes.x_axis, annotation.axes.y_axis],\n",
    "            [\n",
    "                [dp.x for dp in annotation.data_series],\n",
    "                [dp.y for dp in annotation.data_series],\n",
    "            ],\n",
    "        ):\n",
    "            if axis.values_type == ValuesType.numerical:\n",
    "                string = convert_axis_data_to_string(data, ValuesType.numerical)\n",
    "                reconverted_data = convert_string_to_axis_data(\n",
    "                    string, ValuesType.numerical\n",
    "                )\n",
    "                squared_error += (\n",
    "                    (np.array(data) - np.array(reconverted_data)) ** 2\n",
    "                ).sum()\n",
    "                n_numeric_values += len(data)\n",
    "\n",
    "    # Mean squared error: summed squared error divided by the number of values.\n",
    "    mse = squared_error / n_numeric_values\n",
    "    return mse"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e5ae33b0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.782581Z",
"start_time": "2023-04-18T15:47:39.750579Z"
}
   },
   "outputs": [],
"source": [
"if DEBUG:\n",
" print(compute_numeric_data_loss_due_to_string_conversion())"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "46dff28d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.858163Z",
"start_time": "2023-04-18T15:47:39.785386Z"
}
},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" setTimeout(function() {\n",
" var nbb_cell_id = 28;\n",
" var nbb_unformatted_code = \"@dataclasses.dataclass\\nclass BenetechOutput:\\n chart_type: ChartType\\n x_values_type: ValuesType\\n y_values_type: ValuesType\\n x_data: list[str or float]\\n y_data: list[str or float]\\n\\n def __post_init__(self):\\n self.chart_type = ChartType(self.chart_type)\\n self.x_values_type = ValuesType(self.x_values_type)\\n self.y_values_type = ValuesType(self.y_values_type)\\n assert isinstance(self.x_data, list)\\n assert isinstance(self.y_data, list)\\n\\n def get_main_characteristics(self):\\n return (\\n self.chart_type,\\n self.x_values_type,\\n self.y_values_type,\\n len(self.x_data),\\n len(self.y_data),\\n )\\n\\n @staticmethod\\n def from_annotation(annotation: Annotation):\\n return BenetechOutput(\\n chart_type=annotation.chart_type,\\n x_values_type=annotation.axes.x_axis.values_type,\\n y_values_type=annotation.axes.y_axis.values_type,\\n x_data=[dp.x for dp in annotation.data_series],\\n y_data=[dp.y for dp in annotation.data_series],\\n )\\n\\n def to_string(self):\\n return self.format_strings(\\n chart_type=self.chart_type,\\n x_values_type=self.x_values_type,\\n y_values_type=self.y_values_type,\\n x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),\\n y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),\\n )\\n\\n @staticmethod\\n def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data):\\n chart_type = to_token_str(chart_type)\\n x_values_type = to_token_str(x_values_type)\\n y_values_type = to_token_str(y_values_type)\\n return (\\n f\\\"{TOKEN.benetech_prompt}{chart_type}\\\"\\n f\\\"{TOKEN.x_start}{x_values_type}{x_data}\\\"\\n f\\\"{TOKEN.y_start}{y_values_type}{y_data}\\\"\\n f\\\"{TOKEN.benetech_prompt_end}\\\"\\n )\\n\\n @staticmethod\\n def get_string_pattern():\\n field_names = [field.name for field in dataclasses.fields(BenetechOutput)]\\n pattern = BenetechOutput.format_strings(\\n **{field_name: f\\\"(?P<{field_name}>.*?)\\\" for field_name in 
field_names}\\n )\\n return pattern\\n\\n @staticmethod\\n def does_string_match_expected_pattern(string):\\n return bool(re.fullmatch(BenetechOutput.get_string_pattern(), string))\\n\\n @staticmethod\\n def from_string(string):\\n fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)\\n benetech_kwargs = fullmatch.groupdict()\\n benetech_kwargs[\\\"chart_type\\\"] = ChartType(benetech_kwargs[\\\"chart_type\\\"])\\n benetech_kwargs[\\\"x_values_type\\\"] = ValuesType(benetech_kwargs[\\\"x_values_type\\\"])\\n benetech_kwargs[\\\"y_values_type\\\"] = ValuesType(benetech_kwargs[\\\"y_values_type\\\"])\\n benetech_kwargs[\\\"x_data\\\"] = convert_string_to_axis_data(\\n benetech_kwargs[\\\"x_data\\\"], benetech_kwargs[\\\"x_values_type\\\"]\\n )\\n benetech_kwargs[\\\"y_data\\\"] = convert_string_to_axis_data(\\n benetech_kwargs[\\\"y_data\\\"], benetech_kwargs[\\\"y_values_type\\\"]\\n )\\n return BenetechOutput(**benetech_kwargs)\\n\\n\\ndef get_annotation_ground_truth_str(annotation: Annotation):\\n benetech_output = BenetechOutput(\\n chart_type=annotation.chart_type,\\n x_values_type=annotation.axes.x_axis.values_type,\\n x_data=[dp.x for dp in annotation.data_series],\\n y_values_type=annotation.axes.y_axis.values_type,\\n y_data=[dp.y for dp in annotation.data_series],\\n )\\n return benetech_output.to_string()\\n\\n\\ndef get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:\\n return get_annotation_ground_truth_str(Annotation.from_image_index(0))\";\n",
" var nbb_formatted_code = \"@dataclasses.dataclass\\nclass BenetechOutput:\\n chart_type: ChartType\\n x_values_type: ValuesType\\n y_values_type: ValuesType\\n x_data: list[str or float]\\n y_data: list[str or float]\\n\\n def __post_init__(self):\\n self.chart_type = ChartType(self.chart_type)\\n self.x_values_type = ValuesType(self.x_values_type)\\n self.y_values_type = ValuesType(self.y_values_type)\\n assert isinstance(self.x_data, list)\\n assert isinstance(self.y_data, list)\\n\\n def get_main_characteristics(self):\\n return (\\n self.chart_type,\\n self.x_values_type,\\n self.y_values_type,\\n len(self.x_data),\\n len(self.y_data),\\n )\\n\\n @staticmethod\\n def from_annotation(annotation: Annotation):\\n return BenetechOutput(\\n chart_type=annotation.chart_type,\\n x_values_type=annotation.axes.x_axis.values_type,\\n y_values_type=annotation.axes.y_axis.values_type,\\n x_data=[dp.x for dp in annotation.data_series],\\n y_data=[dp.y for dp in annotation.data_series],\\n )\\n\\n def to_string(self):\\n return self.format_strings(\\n chart_type=self.chart_type,\\n x_values_type=self.x_values_type,\\n y_values_type=self.y_values_type,\\n x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),\\n y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),\\n )\\n\\n @staticmethod\\n def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data):\\n chart_type = to_token_str(chart_type)\\n x_values_type = to_token_str(x_values_type)\\n y_values_type = to_token_str(y_values_type)\\n return (\\n f\\\"{TOKEN.benetech_prompt}{chart_type}\\\"\\n f\\\"{TOKEN.x_start}{x_values_type}{x_data}\\\"\\n f\\\"{TOKEN.y_start}{y_values_type}{y_data}\\\"\\n f\\\"{TOKEN.benetech_prompt_end}\\\"\\n )\\n\\n @staticmethod\\n def get_string_pattern():\\n field_names = [field.name for field in dataclasses.fields(BenetechOutput)]\\n pattern = BenetechOutput.format_strings(\\n **{field_name: f\\\"(?P<{field_name}>.*?)\\\" for field_name in 
field_names}\\n )\\n return pattern\\n\\n @staticmethod\\n def does_string_match_expected_pattern(string):\\n return bool(re.fullmatch(BenetechOutput.get_string_pattern(), string))\\n\\n @staticmethod\\n def from_string(string):\\n fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)\\n benetech_kwargs = fullmatch.groupdict()\\n benetech_kwargs[\\\"chart_type\\\"] = ChartType(benetech_kwargs[\\\"chart_type\\\"])\\n benetech_kwargs[\\\"x_values_type\\\"] = ValuesType(benetech_kwargs[\\\"x_values_type\\\"])\\n benetech_kwargs[\\\"y_values_type\\\"] = ValuesType(benetech_kwargs[\\\"y_values_type\\\"])\\n benetech_kwargs[\\\"x_data\\\"] = convert_string_to_axis_data(\\n benetech_kwargs[\\\"x_data\\\"], benetech_kwargs[\\\"x_values_type\\\"]\\n )\\n benetech_kwargs[\\\"y_data\\\"] = convert_string_to_axis_data(\\n benetech_kwargs[\\\"y_data\\\"], benetech_kwargs[\\\"y_values_type\\\"]\\n )\\n return BenetechOutput(**benetech_kwargs)\\n\\n\\ndef get_annotation_ground_truth_str(annotation: Annotation):\\n benetech_output = BenetechOutput(\\n chart_type=annotation.chart_type,\\n x_values_type=annotation.axes.x_axis.values_type,\\n x_data=[dp.x for dp in annotation.data_series],\\n y_values_type=annotation.axes.y_axis.values_type,\\n y_data=[dp.y for dp in annotation.data_series],\\n )\\n return benetech_output.to_string()\\n\\n\\ndef get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:\\n return get_annotation_ground_truth_str(Annotation.from_image_index(0))\";\n",
" var nbb_cells = Jupyter.notebook.get_cells();\n",
" for (var i = 0; i < nbb_cells.length; ++i) {\n",
" if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n",
" if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n",
" nbb_cells[i].set_text(nbb_formatted_code);\n",
" }\n",
" break;\n",
" }\n",
" }\n",
" }, 500);\n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@dataclasses.dataclass\n",
"class BenetechOutput:\n",
"    chart_type: ChartType\n",
"    x_values_type: ValuesType\n",
"    y_values_type: ValuesType\n",
"    x_data: list[str | float]\n",
"    y_data: list[str | float]\n",
"\n",
"    def __post_init__(self):\n",
"        self.chart_type = ChartType(self.chart_type)\n",
"        self.x_values_type = ValuesType(self.x_values_type)\n",
"        self.y_values_type = ValuesType(self.y_values_type)\n",
"        assert isinstance(self.x_data, list)\n",
"        assert isinstance(self.y_data, list)\n",
"\n",
"    def get_main_characteristics(self):\n",
"        return (\n",
"            self.chart_type,\n",
"            self.x_values_type,\n",
"            self.y_values_type,\n",
"            len(self.x_data),\n",
"            len(self.y_data),\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def from_annotation(annotation: Annotation):\n",
"        return BenetechOutput(\n",
"            chart_type=annotation.chart_type,\n",
"            x_values_type=annotation.axes.x_axis.values_type,\n",
"            y_values_type=annotation.axes.y_axis.values_type,\n",
"            x_data=[dp.x for dp in annotation.data_series],\n",
"            y_data=[dp.y for dp in annotation.data_series],\n",
"        )\n",
"\n",
"    def to_string(self):\n",
"        return self.format_strings(\n",
"            chart_type=self.chart_type,\n",
"            x_values_type=self.x_values_type,\n",
"            y_values_type=self.y_values_type,\n",
"            x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),\n",
"            y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data):\n",
"        chart_type = to_token_str(chart_type)\n",
"        x_values_type = to_token_str(x_values_type)\n",
"        y_values_type = to_token_str(y_values_type)\n",
"        return (\n",
"            f\"{TOKEN.benetech_prompt}{chart_type}\"\n",
"            f\"{TOKEN.x_start}{x_values_type}{x_data}\"\n",
"            f\"{TOKEN.y_start}{y_values_type}{y_data}\"\n",
"            f\"{TOKEN.benetech_prompt_end}\"\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def get_string_pattern():\n",
"        # Reuse the output template, with one named regex group per dataclass field.\n",
"        field_names = [field.name for field in dataclasses.fields(BenetechOutput)]\n",
"        pattern = BenetechOutput.format_strings(\n",
"            **{field_name: f\"(?P<{field_name}>.*?)\" for field_name in field_names}\n",
"        )\n",
"        return pattern\n",
"\n",
"    @staticmethod\n",
"    def does_string_match_expected_pattern(string):\n",
"        return bool(re.fullmatch(BenetechOutput.get_string_pattern(), string))\n",
"\n",
"    @staticmethod\n",
"    def from_string(string):\n",
"        fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)\n",
"        if fullmatch is None:\n",
"            raise ValueError(f\"String does not match the expected pattern: {string!r}\")\n",
"        benetech_kwargs = fullmatch.groupdict()\n",
"        benetech_kwargs[\"chart_type\"] = ChartType(benetech_kwargs[\"chart_type\"])\n",
"        benetech_kwargs[\"x_values_type\"] = ValuesType(benetech_kwargs[\"x_values_type\"])\n",
"        benetech_kwargs[\"y_values_type\"] = ValuesType(benetech_kwargs[\"y_values_type\"])\n",
"        benetech_kwargs[\"x_data\"] = convert_string_to_axis_data(\n",
"            benetech_kwargs[\"x_data\"], benetech_kwargs[\"x_values_type\"]\n",
"        )\n",
"        benetech_kwargs[\"y_data\"] = convert_string_to_axis_data(\n",
"            benetech_kwargs[\"y_data\"], benetech_kwargs[\"y_values_type\"]\n",
"        )\n",
"        return BenetechOutput(**benetech_kwargs)\n",
"\n",
"\n",
"def get_annotation_ground_truth_str(annotation: Annotation) -> str:\n",
"    return BenetechOutput.from_annotation(annotation).to_string()\n",
"\n",
"\n",
"def get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:\n",
"    return get_annotation_ground_truth_str(Annotation.from_image_index(image_index))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "8342617b",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.881898Z",
"start_time": "2023-04-18T15:47:39.861073Z"
}
},
"outputs": [
],
"source": [
"if DEBUG:\n",
"    print(BenetechOutput.get_string_pattern(), \"\\n\")\n",
"    print(\n",
"        get_annotation_ground_truth_str(AnnotatedImage.from_image_index(0).annotation),\n",
"        \"\\n\",\n",
"    )\n",
"    pprint.pprint(\n",
"        BenetechOutput.from_string(get_annotation_ground_truth_str_from_image_index(0))\n",
"    )\n",
"    pprint.pprint(BenetechOutput.from_annotation(Annotation.from_image_index(0)))"
]
},
{
"cell_type": "markdown",
"id": "3368ace9",
"metadata": {},
"source": [
"### Metrics "
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3901ad2f",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.926904Z",
"start_time": "2023-04-18T15:47:39.883983Z"
}
},
"outputs": [
],
"source": [
"def normalized_rmse(expected: list[float], predicted: list[float]) -> float:\n",
"    # sqrt(1 - R^2) equals RMSE(expected, predicted) / RMSE(expected, mean(expected)).\n",
"    return (1 - sklearn.metrics.r2_score(expected, predicted)) ** 0.5\n",
"\n",
"\n",
"def normalized_levenshtein_distance(expected: list[str], predicted: list[str]) -> float:\n",
"    # Total edit distance divided by the total length of the expected strings.\n",
"    total_distance = 0\n",
"    for e, p in zip(expected, predicted):\n",
"        total_distance += rapidfuzz.distance.Levenshtein.distance(e, p)\n",
"    total_length = np.sum([len(e) for e in expected])\n",
"    return total_distance / total_length\n",
"\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"\n",
"def positive_loss_to_score(x):\n",
"    # Maps a loss in [0, inf) to a score in (0, 1]; a loss of 0 gives a score of 1.\n",
"    return 2 * sigmoid(-x)\n",
"\n",
"\n",
"def score_axis_values(values_type, expected, predicted):\n",
"    if values_type == ValuesType.numerical:\n",
"        loss = normalized_rmse(expected, predicted)\n",
"    else:\n",
"        loss = normalized_levenshtein_distance(expected, predicted)\n",
"    return positive_loss_to_score(loss)\n",
"\n",
"\n",
"def benetech_score(expected: BenetechOutput, predicted: BenetechOutput) -> float:\n",
"    # A mismatch in chart type, value types, or series length scores zero.\n",
"    if expected.get_main_characteristics() != predicted.get_main_characteristics():\n",
"        return 0.0\n",
"    x_score = score_axis_values(\n",
"        expected.x_values_type, expected.x_data, predicted.x_data\n",
"    )\n",
"    y_score = score_axis_values(\n",
"        expected.y_values_type, expected.y_data, predicted.y_data\n",
"    )\n",
"    return (x_score + y_score) / 2\n",
"\n",
"\n",
"def benetech_score_string_prediction(\n",
"    expected_data_index: int, predicted_string: str\n",
") -> float:\n",
"    if not BenetechOutput.does_string_match_expected_pattern(predicted_string):\n",
"        return 0.0\n",
"    expected_annotation = Annotation.from_image_index(expected_data_index)\n",
"    expected_output = BenetechOutput.from_annotation(expected_annotation)\n",
"    predicted_output = BenetechOutput.from_string(predicted_string)\n",
"    return benetech_score(expected_output, predicted_output)"
]
},
{
"cell_type": "markdown",
"id": "83bcf99d",
"metadata": {},
"source": [
"### Dataset "
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "e532ac55",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:39.990566Z",
"start_time": "2023-04-18T15:47:39.933447Z"
}
},
"outputs": [
],
"source": [
"@dataclasses.dataclass\n",
"class DataItem:\n",
"    image: torch.FloatTensor\n",
"    target_string: str\n",
"    data_index: int\n",
"\n",
"    def __post_init__(self):\n",
"        shape = einops.parse_shape(self.image, \"channel height width\")\n",
"        assert shape[\"channel\"] == 3, \"Image is expected to have 3 channels.\"\n",
"\n",
"\n",
"class Dataset(torch.utils.data.Dataset):\n",
"    def __init__(self, split: Literal[\"train\", \"val\", \"complete\"]):\n",
"        super().__init__()\n",
"        match split:\n",
"            case \"train\":\n",
"                self.indices = DATA.train_indices\n",
"            case \"val\":\n",
"                self.indices = DATA.val_indices\n",
"            case \"complete\":\n",
"                self.indices = np.arange(len(load_train_image_ids()))\n",
"            case _:\n",
"                raise ValueError(f\"Unknown split {split}.\")\n",
"        self.to_tensor = torchvision.transforms.ToTensor()\n",
"\n",
"    def __len__(self):\n",
"        return len(self.indices)\n",
"\n",
"    def __getitem__(self, idx: int) -> DataItem:\n",
"        data_index = self.indices[idx]\n",
"\n",
"        annotated_image = AnnotatedImage.from_image_index(data_index)\n",
"\n",
"        image = annotated_image.image\n",
"        image = self.to_tensor(image)\n",
"\n",
"        target_string = get_annotation_ground_truth_str(annotated_image.annotation)\n",
"\n",
"        return DataItem(image=image, target_string=target_string, data_index=data_index)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "0ccf561f",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:40.023555Z",
"start_time": "2023-04-18T15:47:39.992916Z"
}
},
"outputs": [
],
"source": [
"DATA.train_dataset = Dataset(\"train\")\n",
"DATA.val_dataset = Dataset(\"val\")\n",
"DATA.complete_dataset = Dataset(\"complete\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "773d4fcc",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:40.136084Z",
"start_time": "2023-04-18T15:47:40.031292Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16<;>17<;>18<;>19<;>20<;>21<;>22<;>23<;>24<;>25<;>26<;>27<;>28<;>29<;>303.79953e+02<;>4.12642e+02<;>3.82075e+02<;>3.69340e+02<;>2.86557e+02<;>2.65330e+02<;>2.35613e+02<;>2.56840e+02<;>1.99528e+02<;>1.95283e+02<;>2.12264e+02<;>1.88915e+02<;>1.91038e+02<;>1.94434e+02<;>2.18632e+02\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dkkoshman/YSDA/python3.10/lib/python3.10/site-packages/torchvision/transforms/functional.py:152: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
" img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(DATA.train_dataset[0].target_string)\n",
"torchvision.transforms.functional.to_pil_image(DATA.train_dataset[0].image)"
]
},
{
"cell_type": "markdown",
"id": "ec80e30c",
"metadata": {},
"source": [
"## Model "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b44db7e4",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-19T11:35:49.248446Z",
"start_time": "2023-04-19T11:35:49.165590Z"
}
},
"outputs": [
],
"source": [
"transformers.processing_utils.ProcessorMixin?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "954300a4",
"metadata": {},
"outputs": [],
"source": [
"transformers.VisionEncoderDecoderModel.to"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7ce37bda",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-19T11:39:30.070268Z",
"start_time": "2023-04-19T11:39:22.652334Z"
}
},
"outputs": [
],
"source": [
"de = transformers.VisionEncoderDecoderModel.from_pretrained(\n",
" \"naver-clova-ix/donut-base\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d7dfcf78",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-19T11:38:51.404917Z",
"start_time": "2023-04-19T11:38:50.578616Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n"
]
}
],
"source": [
"donut_processor = transformers.DonutProcessor.from_pretrained(\n",
" \"naver-clova-ix/donut-base\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "5257aba3",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:51.142992Z",
"start_time": "2023-04-18T15:47:40.137637Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n"
]
}
],
"source": [
"CONFIG.pretrained_model_name = \"naver-clova-ix/donut-base\"\n",
"CONFIG.encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(\n",
"    CONFIG.pretrained_model_name\n",
")\n",
"CONFIG.encoder_decoder_config.encoder.image_size = (\n",
"    CONFIG.image_width,\n",
"    CONFIG.image_height,\n",
")\n",
"\n",
"MODEL.donut_processor = transformers.DonutProcessor.from_pretrained(\n",
"    CONFIG.pretrained_model_name\n",
")\n",
"MODEL.donut_processor.image_processor.size = dict(\n",
"    width=CONFIG.image_width, height=CONFIG.image_height\n",
")\n",
"MODEL.donut_processor.image_processor.do_align_long_axis = False\n",
"MODEL.tokenizer = MODEL.donut_processor.tokenizer\n",
"MODEL.encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(\n",
"    CONFIG.pretrained_model_name, config=CONFIG.encoder_decoder_config\n",
")\n",
"\n",
"# NOTE: convert_tokens_to_ids returns the unknown-token id for tokens that are not\n",
"# yet in the vocabulary, so the TOKEN.* strings must already have been added to the\n",
"# tokenizer for the ids below to be meaningful.\n",
"CONFIG.encoder_decoder_config.pad_token_id = MODEL.tokenizer.pad_token_id\n",
"CONFIG.encoder_decoder_config.decoder_start_token_id = (\n",
"    MODEL.tokenizer.convert_tokens_to_ids(TOKEN.benetech_prompt)\n",
")\n",
"CONFIG.encoder_decoder_config.bos_token_id = (\n",
"    CONFIG.encoder_decoder_config.decoder_start_token_id\n",
")\n",
"CONFIG.encoder_decoder_config.eos_token_id = MODEL.tokenizer.convert_tokens_to_ids(\n",
"    TOKEN.benetech_prompt_end\n",
")\n",
"MODEL.tokenizer.eos_token_id = CONFIG.encoder_decoder_config.eos_token_id"
]
},
{
"cell_type": "markdown",
"id": "d40f590d",
"metadata": {},
"source": [
"### Add task specific tokens "
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "42516577",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:51.159825Z",
"start_time": "2023-04-18T15:47:51.144998Z"
}
},
"outputs": [
],
"source": [
"def add_unknown_tokens_to_tokenizer(unknown_tokens: list[str]):\n",
"    # Guard against re-adding tokens the tokenizer already knows.\n",
"    assert set(unknown_tokens).isdisjoint(\n",
"        MODEL.tokenizer.vocab.keys()\n",
"    ), \"Tokens are not unknown.\"\n",
"\n",
"    MODEL.tokenizer.add_tokens(unknown_tokens)\n",
"    # Resize the decoder embedding matrix to match the enlarged vocabulary.\n",
"    MODEL.encoder_decoder.decoder.resize_token_embeddings(len(MODEL.tokenizer))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "81a93859",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.651571Z",
"start_time": "2023-04-18T15:47:51.162085Z"
}
},
"outputs": [],
"source": [
"add_unknown_tokens_to_tokenizer(list(TOKEN.__dict__.values()))"
]
},
{
"cell_type": "markdown",
"id": "8070590a",
"metadata": {},
"source": [
"### Add dataset specific tokens "
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "fe319b38",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.681837Z",
"start_time": "2023-04-18T15:47:52.654564Z"
}
},
"outputs": [],
"source": [
"def find_unknown_tokens_for_tokenizer() -> collections.Counter:\n",
"    \"\"\"Count ground-truth tokens that the tokenizer maps to its unknown-token id.\"\"\"\n",
"    unknown_tokens_counter = collections.Counter()\n",
"\n",
"    for annotated_image in generate_annotated_images():\n",
"        ground_truth = get_annotation_ground_truth_str(annotated_image.annotation)\n",
"\n",
"        input_ids = MODEL.tokenizer(ground_truth).input_ids\n",
"        tokens = MODEL.tokenizer.tokenize(ground_truth, add_special_tokens=True)\n",
"\n",
"        for token_id, token in zip(input_ids, tokens, strict=True):\n",
"            if token_id == MODEL.tokenizer.unk_token_id:\n",
"                unknown_tokens_counter.update([token])\n",
"\n",
"    return unknown_tokens_counter"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "91a5cc71",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.708844Z",
"start_time": "2023-04-18T15:47:52.687009Z"
}
},
"outputs": [],
"source": [
"if DEBUG:\n",
"    print(find_unknown_tokens_for_tokenizer())"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "02efe707",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.806582Z",
"start_time": "2023-04-18T15:47:52.714235Z"
}
},
"outputs": [],
"source": [
"CONFIG.unknown_tokens_for_tokenizer_path = \"data/unknown_tokens_for_tokenizer.pickle\"\n",
"\n",
"add_unknown_tokens_to_tokenizer(\n",
"    load_pickle_or_build_object_and_save(\n",
"        CONFIG.unknown_tokens_for_tokenizer_path,\n",
"        lambda: list(find_unknown_tokens_for_tokenizer().keys()),\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "2fa909a1",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.827973Z",
"start_time": "2023-04-18T15:47:52.817963Z"
}
},
"outputs": [],
"source": [
"def compute_target_tokens_length_distribution() -> list[int]:\n",
"    token_lengths = []\n",
"    for data_item in tqdm.autonotebook.tqdm(\n",
"        DATA.complete_dataset, desc=\"Encoding target strings\"\n",
"    ):\n",
"        encoding = MODEL.tokenizer(data_item.target_string)\n",
"        token_lengths.append(len(encoding.input_ids))\n",
"    return token_lengths\n",
"\n",
"\n",
"def visualize_target_tokens_length_distribution():\n",
"    token_lengths = compute_target_tokens_length_distribution()\n",
"    plt.hist(token_lengths, bins=50)\n",
"    plt.title(\"Token length\")\n",
"    # Summary statistics (count, mean, quantiles) to help choose the decoder max_length.\n",
"    summary = pd.Series(token_lengths, name=\"Token length\").to_frame().describe()\n",
"    IPython.display.display(summary)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "76eb6a64",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.870437Z",
"start_time": "2023-04-18T15:47:52.837124Z"
}
},
"outputs": [],
"source": [
"if DEBUG:\n",
"    visualize_target_tokens_length_distribution()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "b8a7f491",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.886816Z",
"start_time": "2023-04-18T15:47:52.873931Z"
}
},
"outputs": [],
"source": [
"CONFIG.encoder_decoder_config.decoder.max_length = 512"
]
},
{
"cell_type": "markdown",
"id": "c688a4a9",
"metadata": {},
"source": [
"### Predicting "
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "36672135",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.919334Z",
"start_time": "2023-04-18T15:47:52.888629Z"
}
},
"outputs": [],
"source": [
"def generate_token_strings(images: torch.Tensor, skip_special_tokens=True) -> list[str]:\n",
"    # In DEBUG mode generation is cut to 10 tokens to keep runs fast.\n",
"    decoder_output = MODEL.encoder_decoder.generate(\n",
"        images,\n",
"        max_length=10 if DEBUG else CONFIG.encoder_decoder_config.decoder.max_length,\n",
"        eos_token_id=MODEL.tokenizer.eos_token_id,\n",
"        return_dict_in_generate=True,\n",
"    )\n",
"    return MODEL.tokenizer.batch_decode(\n",
"        decoder_output.sequences, skip_special_tokens=skip_special_tokens\n",
"    )\n",
"\n",
"\n",
"def predict_string(image) -> str:\n",
"    image = MODEL.donut_processor(\n",
"        image, random_padding=False, return_tensors=\"pt\"\n",
"    ).pixel_values\n",
"    string = generate_token_strings(image)[0]\n",
"    return string\n",
"\n",
"\n",
"def predict_benetech_output(image):\n",
"    string = predict_string(image)\n",
"    assert BenetechOutput.does_string_match_expected_pattern(string)\n",
"    return BenetechOutput.from_string(string)"
]
},
{
"cell_type": "markdown",
"id": "2a090da9",
"metadata": {},
"source": [
"### Dataloader "
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "8637a86a",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:52.982298Z",
"start_time": "2023-04-18T15:47:52.921598Z"
}
},
"outputs": [],
"source": [
"@dataclasses.dataclass\n",
"class Batch:\n",
"    images: torch.FloatTensor\n",
"    # Token ids come back from the tokenizer as int64, hence LongTensor.\n",
"    labels: torch.LongTensor\n",
"    data_indices: list[int]\n",
"\n",
"    def __post_init__(self):\n",
"        if DEBUG:\n",
"            images_shape = einops.parse_shape(self.images, \"batch channel height width\")\n",
"            labels_shape = einops.parse_shape(self.labels, \"batch label\")\n",
"            assert images_shape[\"batch\"] == labels_shape[\"batch\"]\n",
"            assert len(self.data_indices) == images_shape[\"batch\"]\n",
"\n",
"\n",
"def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(\n",
"    token_ids,\n",
"):\n",
"    # Positions labeled -100 are ignored by the cross-entropy loss.\n",
"    token_ids[token_ids == MODEL.tokenizer.pad_token_id] = -100\n",
"    return token_ids\n",
"\n",
"\n",
"def collate_function(batch: list[DataItem], split: Literal[\"train\", \"val\"]) -> Batch:\n",
"    images = [di.image for di in batch]\n",
"    images = MODEL.donut_processor(\n",
"        images, random_padding=split == \"train\", return_tensors=\"pt\"\n",
"    ).pixel_values\n",
"\n",
"    target_token_ids = MODEL.tokenizer(\n",
"        [di.target_string for di in batch],\n",
"        add_special_tokens=False,\n",
"        max_length=CONFIG.encoder_decoder_config.decoder.max_length,\n",
"        padding=\"max_length\",\n",
"        truncation=True,\n",
"        return_tensors=\"pt\",\n",
"    ).input_ids\n",
"    labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(\n",
"        target_token_ids\n",
"    )\n",
"\n",
"    data_indices = [di.data_index for di in batch]\n",
"\n",
"    return Batch(images=images, labels=labels, data_indices=data_indices)\n",
"\n",
"\n",
"CONFIG.batch_size = 2 if DEBUG else 2\n",
"CONFIG.num_workers = 4\n",
"\n",
"\n",
"def build_dataloader(split: Literal[\"train\", \"val\"]):\n",
"    return torch.utils.data.DataLoader(\n",
"        DATA.train_dataset if split == \"train\" else DATA.val_dataset,\n",
"        batch_size=CONFIG.batch_size,\n",
"        shuffle=split == \"train\",\n",
"        num_workers=CONFIG.num_workers,\n",
"        collate_fn=functools.partial(collate_function, split=split),\n",
"    )\n",
"\n",
"\n",
"DATA.train_dataloader = build_dataloader(\"train\")\n",
"DATA.val_dataloader = build_dataloader(\"val\")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "bf389ff2",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:53.034897Z",
"start_time": "2023-04-18T15:47:52.984707Z"
}
},
"outputs": [],
"source": [
"def test_dataloaders():\n",
"    # Smoke test: iterate both dataloaders once so collation errors surface early.\n",
"    for batch in tqdm.autonotebook.tqdm(\n",
"        DATA.val_dataloader, \"Iterating over val dataloader\"\n",
"    ):\n",
"        pass\n",
"    for batch in tqdm.autonotebook.tqdm(\n",
"        DATA.train_dataloader, \"Iterating over train dataloader\"\n",
"    ):\n",
"        pass"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "0eb3fed2",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:53.076744Z",
"start_time": "2023-04-18T15:47:53.037941Z"
}
},
"outputs": [],
"source": [
"if DEBUG:\n",
"    test_dataloaders()"
]
},
{
"cell_type": "markdown",
"id": "08146c41",
"metadata": {},
"source": [
"### Lightning module "
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "323bb5da",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:53.121327Z",
"start_time": "2023-04-18T15:47:53.078769Z"
}
},
"outputs": [],
"source": [
"CONFIG.learning_rate = 3e-5\n",
"\n",
"\n",
"class LightningModule(pl.LightningModule):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.model = MODEL.encoder_decoder\n",
"\n",
"    def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:\n",
"        outputs = self.model(pixel_values=batch.images, labels=batch.labels)\n",
"        loss = outputs.loss\n",
"        self.log(\"train_loss\", loss)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0):\n",
"        outputs = self.model(pixel_values=batch.images, labels=batch.labels)\n",
"        loss = outputs.loss\n",
"        self.log(\"val_loss\", loss)\n",
"\n",
"    def configure_optimizers(self) -> torch.optim.Optimizer:\n",
"        optimizer = torch.optim.Adam(self.parameters(), lr=CONFIG.learning_rate)\n",
"        return optimizer\n",
"\n",
"\n",
"MODEL.lightning_module = LightningModule()"
]
},
{
"cell_type": "markdown",
"id": "b375ad12",
"metadata": {},
"source": [
"### Callbacks "
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "441e54bb",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:53.157826Z",
"start_time": "2023-04-18T15:47:53.125547Z"
}
},
"outputs": [],
"source": [
"class MetricsCallback(pl.callbacks.Callback):\n",
"    def on_validation_batch_start(\n",
"        self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0\n",
"    ):\n",
"        predicted_strings = generate_token_strings(images=batch.images)\n",
"\n",
"        for expected_data_index, predicted_string in zip(\n",
"            batch.data_indices, predicted_strings, strict=True\n",
"        ):\n",
"            benetech_score = benetech_score_string_prediction(\n",
"                expected_data_index=expected_data_index,\n",
"                predicted_string=predicted_string,\n",
"            )\n",
"            wandb.log(dict(benetech_score=benetech_score))\n",
"\n",
"        ground_truth_strings = [\n",
"            get_annotation_ground_truth_str_from_image_index(i)\n",
"            for i in batch.data_indices\n",
"        ]\n",
"        # Load the id list once instead of once per batch item.\n",
"        image_ids = load_train_image_ids()\n",
"        string_ids = [image_ids[i] for i in batch.data_indices]\n",
"        strings_dataframe = pd.DataFrame(\n",
"            dict(\n",
"                string_ids=string_ids,\n",
"                ground_truth=ground_truth_strings,\n",
"                predicted=predicted_strings,\n",
"            )\n",
"        )\n",
"        wandb.log(dict(strings=wandb.Table(dataframe=strings_dataframe)))\n",
"\n",
"\n",
"class TransformersCheckpointIO(pl.plugins.CheckpointIO):\n",
"    def save_checkpoint(self, checkpoint, path, storage_options=None):\n",
"        # Save in Hugging Face format; the Lightning checkpoint dict is not written.\n",
"        MODEL.donut_processor.save_pretrained(path)\n",
"        MODEL.encoder_decoder.save_pretrained(path)\n",
"\n",
"    def load_checkpoint(self, path, storage_options=None):\n",
"        pass\n",
"\n",
"    def remove_checkpoint(self, path):\n",
"        pass"
]
},
{
"cell_type": "markdown",
"id": "7ef3f395",
"metadata": {},
"source": [
"## Training "
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "3d12b673",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T15:47:57.593057Z",
"start_time": "2023-04-18T15:47:53.160392Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdkoshman\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.14.2"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in training/wandb/run-20230418_154756-56t9l4jj
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run young-forest-7 to Weights & Biases (docs)
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at https://wandb.ai/dkoshman/MakingGraphsAccessible"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at https://wandb.ai/dkoshman/MakingGraphsAccessible/runs/56t9l4jj"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"TRAINING.accelerator = \"cpu\" if DEBUG else \"gpu\"\n",
"TRAINING.devices = \"auto\" if TRAINING.accelerator == \"cpu\" else [3]\n",
"TRAINING.directory = \"training\"\n",
"TRAINING.save_top_k_checkpoints = 3\n",
"TRAINING.wandb_project_name = \"MakingGraphsAccessible\"\n",
"TRAINING.limit_train_batches = 2 if DEBUG else None\n",
"TRAINING.limit_val_batches = 2 if DEBUG else 0.1\n",
"\n",
"TRAINING.model_checkpoint = pl.callbacks.ModelCheckpoint(\n",
" dirpath=TRAINING.directory,\n",
" monitor=\"val_loss\",\n",
" save_top_k=TRAINING.save_top_k_checkpoints,\n",
")\n",
"\n",
"TRAINING.logger = pl.loggers.WandbLogger(\n",
" project=TRAINING.wandb_project_name, save_dir=TRAINING.directory\n",
")\n",
"\n",
"TRAINING.trainer = pl.Trainer(\n",
" accelerator=TRAINING.accelerator,\n",
" devices=TRAINING.devices,\n",
" plugins=[TransformersCheckpointIO()],\n",
" callbacks=[TRAINING.model_checkpoint, MetricsCallback()],\n",
" logger=TRAINING.logger,\n",
" limit_train_batches=TRAINING.limit_train_batches,\n",
" limit_val_batches=TRAINING.limit_val_batches,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "5c883d58",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T19:50:57.849222Z",
"start_time": "2023-04-18T15:47:57.598224Z"
},
"collapsed": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dkkoshman/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:70: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.\n",
" rank_zero_warn(\n",
"You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
"/home/dkkoshman/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /home/dkkoshman/YSDA/machine_learning/transformers/MakingGraphsAccessible/training exists and is not empty.\n",
" rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------------------\n",
"0 | model | VisionEncoderDecoderModel | 201 M \n",
"----------------------------------------------------\n",
"201 M Trainable params\n",
"0 Non-trainable params\n",
"201 M Total params\n",
"807.457 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dkkoshman/YSDA/python3.10/lib/python3.10/site-packages/transformers/generation/utils.py:1186: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
" warnings.warn(\n",
"/home/dkkoshman/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12adc4c9f9eb4cd095ac4dce87c500ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9bf5111de98f4c3893eb5ac0597331e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "ValueError",
"evalue": "could not convert string to float: ' 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[50], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mTRAINING\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mMODEL\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mDATA\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mDATA\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:520\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 518\u001b[0m model \u001b[38;5;241m=\u001b[39m _maybe_unwrap_optimized(model)\n\u001b[1;32m 519\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 520\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 521\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 522\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:559\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_connector\u001b[38;5;241m.\u001b[39mattach_data(\n\u001b[1;32m 550\u001b[0m model, train_dataloaders\u001b[38;5;241m=\u001b[39mtrain_dataloaders, val_dataloaders\u001b[38;5;241m=\u001b[39mval_dataloaders, datamodule\u001b[38;5;241m=\u001b[39mdatamodule\n\u001b[1;32m 551\u001b[0m )\n\u001b[1;32m 553\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 554\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 555\u001b[0m ckpt_path,\n\u001b[1;32m 556\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 557\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 558\u001b[0m )\n\u001b[0;32m--> 559\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:935\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 932\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 933\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 934\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 935\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 937\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 938\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 939\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 940\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:978\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 977\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m--> 978\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:201\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 201\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:354\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher\u001b[38;5;241m.\u001b[39msetup(combined_loader)\n\u001b[1;32m 353\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 354\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:134\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madvance(data_fetcher)\n\u001b[0;32m--> 134\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_advance_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:248\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.on_advance_end\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_check_val:\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mvalidating \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 248\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;66;03m# update plateau LR scheduler after metrics are logged\u001b[39;00m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:174\u001b[0m, in \u001b[0;36m_no_grad_context.._decorator\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 172\u001b[0m context_manager \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mno_grad\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context_manager():\n\u001b[0;32m--> 174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mloop_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:115\u001b[0m, in \u001b[0;36m_EvaluationLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 113\u001b[0m previous_dataloader_idx \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[1;32m 114\u001b[0m \u001b[38;5;66;03m# run step hooks\u001b[39;00m\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader_idx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:369\u001b[0m, in \u001b[0;36m_EvaluationLoop._evaluation_step\u001b[0;34m(self, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 366\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_logger_connector\u001b[38;5;241m.\u001b[39mon_batch_start(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mstep_kwargs)\n\u001b[1;32m 368\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_test_batch_start\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_validation_batch_start\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 369\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_callback_hooks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstep_kwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 370\u001b[0m call\u001b[38;5;241m.\u001b[39m_call_lightning_module_hook(trainer, hook_name, \u001b[38;5;241m*\u001b[39mstep_kwargs\u001b[38;5;241m.\u001b[39mvalues())\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n",
"File \u001b[0;32m~/YSDA/python3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:190\u001b[0m, in \u001b[0;36m_call_callback_hooks\u001b[0;34m(trainer, hook_name, monitoring_callbacks, *args, **kwargs)\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m callable(fn):\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Callback]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcallback\u001b[38;5;241m.\u001b[39mstate_key\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 190\u001b[0m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pl_module:\n\u001b[1;32m 193\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 194\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"Cell \u001b[0;32mIn[48], line 10\u001b[0m, in \u001b[0;36mMetricsCallback.on_validation_batch_start\u001b[0;34m(self, trainer, pl_module, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 5\u001b[0m predicted_strings \u001b[38;5;241m=\u001b[39m generate_token_strings(images\u001b[38;5;241m=\u001b[39mbatch\u001b[38;5;241m.\u001b[39mimages)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m expected_data_index, predicted_string \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\n\u001b[1;32m 8\u001b[0m batch\u001b[38;5;241m.\u001b[39mdata_indices, predicted_strings, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 9\u001b[0m ):\n\u001b[0;32m---> 10\u001b[0m benetech_score \u001b[38;5;241m=\u001b[39m \u001b[43mbenetech_score_string_prediction\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mexpected_data_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexpected_data_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mpredicted_string\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpredicted_string\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m wandb\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;28mdict\u001b[39m(benetech_score\u001b[38;5;241m=\u001b[39mbenetech_score))\n\u001b[1;32m 16\u001b[0m ground_truth_strings \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 17\u001b[0m get_annotation_ground_truth_str_from_image_index(i)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mdata_indices\n\u001b[1;32m 19\u001b[0m ]\n",
"Cell \u001b[0;32mIn[30], line 46\u001b[0m, in \u001b[0;36mbenetech_score_string_prediction\u001b[0;34m(expected_data_index, predicted_string)\u001b[0m\n\u001b[1;32m 44\u001b[0m expected_annotation \u001b[38;5;241m=\u001b[39m Annotation\u001b[38;5;241m.\u001b[39mfrom_image_index(expected_data_index)\n\u001b[1;32m 45\u001b[0m expected_output \u001b[38;5;241m=\u001b[39m BenetechOutput\u001b[38;5;241m.\u001b[39mfrom_annotation(expected_annotation)\n\u001b[0;32m---> 46\u001b[0m predicted_output \u001b[38;5;241m=\u001b[39m \u001b[43mBenetechOutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_string\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicted_string\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m benetech_score(expected_output, predicted_output)\n",
"Cell \u001b[0;32mIn[28], line 78\u001b[0m, in \u001b[0;36mBenetechOutput.from_string\u001b[0;34m(string)\u001b[0m\n\u001b[1;32m 74\u001b[0m benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_values_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m ValuesType(benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_values_type\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 75\u001b[0m benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx_data\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m convert_string_to_axis_data(\n\u001b[1;32m 76\u001b[0m benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx_data\u001b[39m\u001b[38;5;124m\"\u001b[39m], benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx_values_type\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 77\u001b[0m )\n\u001b[0;32m---> 78\u001b[0m benetech_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_data\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mconvert_string_to_axis_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43mbenetech_kwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43my_data\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbenetech_kwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43my_values_type\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m BenetechOutput(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mbenetech_kwargs)\n",
"Cell \u001b[0;32mIn[26], line 22\u001b[0m, in \u001b[0;36mconvert_string_to_axis_data\u001b[0;34m(string, values_type)\u001b[0m\n\u001b[1;32m 20\u001b[0m data \u001b[38;5;241m=\u001b[39m string\u001b[38;5;241m.\u001b[39msplit(TOKEN\u001b[38;5;241m.\u001b[39mvalue_separator)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m values_type \u001b[38;5;241m==\u001b[39m ValuesType\u001b[38;5;241m.\u001b[39mnumerical:\n\u001b[0;32m---> 22\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mfloat\u001b[39m(i) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m data]\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
"Cell \u001b[0;32mIn[26], line 22\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 20\u001b[0m data \u001b[38;5;241m=\u001b[39m string\u001b[38;5;241m.\u001b[39msplit(TOKEN\u001b[38;5;241m.\u001b[39mvalue_separator)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m values_type \u001b[38;5;241m==\u001b[39m ValuesType\u001b[38;5;241m.\u001b[39mnumerical:\n\u001b[0;32m---> 22\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m data]\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
"\u001b[0;31mValueError\u001b[0m: could not convert string to float: ' 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01 3.44975e-01'"
]
}
],
"source": [
"TRAINING.trainer.fit(\n",
" model=MODEL.lightning_module,\n",
" train_dataloaders=DATA.train_dataloader,\n",
" val_dataloaders=DATA.val_dataloader,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32541868",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T19:50:57.857936Z",
"start_time": "2023-04-18T19:50:57.857925Z"
}
},
"outputs": [],
"source": [
"TRAINING.trainer.validate(model=MODEL.lightning_module, dataloaders=DATA.val_dataloader)"
]
},
{
"cell_type": "markdown",
"id": "b36b5cf7",
"metadata": {},
"source": [
"## Results "
]
},
{
"cell_type": "markdown",
"id": "509c9eae",
"metadata": {},
"source": [
"### Gradio interface "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b569259",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T19:50:57.859236Z",
"start_time": "2023-04-18T19:50:57.859226Z"
}
},
"outputs": [],
"source": [
"checkpoint_path = \"training/epoch=0-step=2-v1.ckpt\"\n",
"MODEL.donut_processor = MODEL.donut_processor.from_pretrained(checkpoint_path)\n",
"MODEL.encoder_decoder = MODEL.encoder_decoder.from_pretrained(checkpoint_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6eeea089",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T19:50:57.860310Z",
"start_time": "2023-04-18T19:50:57.860301Z"
}
},
"outputs": [],
"source": [
"interface = gradio.Interface(\n",
" fn=predict_string,\n",
" inputs=gradio.Image(type=\"pil\"),\n",
" outputs=gradio.Text(),\n",
" examples=\"examples\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39d1e3d8",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-18T19:50:57.861631Z",
"start_time": "2023-04-18T19:50:57.861618Z"
}
},
"outputs": [],
"source": [
"interface.launch(share=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "80124073",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-27T13:03:57.218129Z",
"start_time": "2023-04-27T13:03:57.048661Z"
}
},
"outputs": [
],
"source": [
"import functools\n",
"\n",
"import gradio\n",
"\n",
"from config import CONFIG\n",
"from model import (\n",
" predict_string,\n",
" build_model,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "575edbbe",
"metadata": {
"ExecuteTime": {
"end_time": "2023-04-27T13:04:07.359118Z",
"start_time": "2023-04-27T13:03:58.074214Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reusing object data/unknown_tokens_for_tokenizer.pickle.\n"
]
}
],
"source": [
"config = CONFIG\n",
"config.pretrained_model_name = \"training/epoch=2-step=163563.ckpt/\"\n",
"model = build_model(config)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}