jacklangerman commited on
Commit
3910cfe
·
1 Parent(s): 2ed8e4c

first commit

Browse files
Files changed (1) hide show
  1. hoho.py +0 -261
hoho.py DELETED
@@ -1,261 +0,0 @@
1
- import os
2
- import json
3
- import shutil
4
- from pathlib import Path
5
- from typing import Dict
6
-
7
- from PIL import ImageFile
8
- ImageFile.LOAD_TRUNCATED_IMAGES = True
9
-
10
- LOCAL_DATADIR = None
11
-
12
- def setup(local_dir='./data/usm-training-data/data'):
13
-
14
- # If we are in the test environment, we need to link the data directory to the correct location
15
- tmp_datadir = Path('/tmp/data/data')
16
- local_test_datadir = Path('./data/usm-test-data-x/data')
17
- local_val_datadir = Path(local_dir)
18
-
19
- os.system('pwd')
20
- os.system('ls -lahtr .')
21
-
22
- if tmp_datadir.exists() and not local_test_datadir.exists():
23
- global LOCAL_DATADIR
24
- LOCAL_DATADIR = local_test_datadir
25
- # shutil.move(datadir, './usm-test-data-x/data')
26
- print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
27
- LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
28
- LOCAL_DATADIR.symlink_to(tmp_datadir)
29
- else:
30
- LOCAL_DATADIR = local_val_datadir
31
- print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
32
-
33
- # os.system("ls -lahtr")
34
-
35
- assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
36
- return LOCAL_DATADIR
37
-
38
-
39
-
40
-
41
- import importlib
42
- from pathlib import Path
43
- import subprocess
44
-
45
- def download_package(package_name, path_to_save='packages'):
46
- """
47
- Downloads a package using pip and saves it to a specified directory.
48
-
49
- Parameters:
50
- package_name (str): The name of the package to download.
51
- path_to_save (str): The path to the directory where the package will be saved.
52
- """
53
- try:
54
- # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
55
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
56
- "-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
57
- "--platform", "manylinux1_x86_64", # Specify the platform
58
- "--python-version", "38", # Specify the Python version
59
- "--only-binary=:all:"]) # Download only binary packages
60
- print(f'Package "{package_name}" downloaded successfully')
61
- except subprocess.CalledProcessError as e:
62
- print(f'Failed to downloaded package "{package_name}". Error: {e}')
63
-
64
-
65
- def install_package_from_local_file(package_name, folder='packages'):
66
- """
67
- Installs a package from a local .whl file or a directory containing .whl files using pip.
68
-
69
- Parameters:
70
- path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
71
- """
72
- try:
73
- pth = str(Path(folder) / package_name)
74
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
75
- "--no-index", # Do not use package index
76
- "--find-links", pth, # Look for packages in the specified directory or at the file
77
- package_name]) # Specify the package to install
78
- print(f"Package installed successfully from {pth}")
79
- except subprocess.CalledProcessError as e:
80
- print(f"Failed to install package from {pth}. Error: {e}")
81
-
82
-
83
- def importt(module_name, as_name=None):
84
- """
85
- Imports a module and returns it.
86
-
87
- Parameters:
88
- module_name (str): The name of the module to import.
89
- as_name (str): The name to use for the imported module. If None, the original module name will be used.
90
-
91
- Returns:
92
- The imported module.
93
- """
94
- for _ in range(2):
95
- try:
96
- if as_name is None:
97
- print(f'imported {module_name}')
98
- return importlib.import_module(module_name)
99
- else:
100
- print(f'imported {module_name} as {as_name}')
101
- return importlib.import_module(module_name, as_name)
102
- except ModuleNotFoundError as e:
103
- install_package_from_local_file(module_name)
104
- print(f"Failed to import module {module_name}. Error: {e}")
105
-
106
-
107
- def prepare_submission():
108
- # Download packages from requirements.txt
109
- if Path('requirements.txt').exists():
110
- print('downloading packages from requirements.txt')
111
- Path('packages').mkdir(exist_ok=True)
112
- with open('requirements.txt') as f:
113
- packages = f.readlines()
114
- for p in packages:
115
- download_package(p.strip())
116
-
117
-
118
- print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
119
-
120
-
121
- def Rt_to_eye_target(im, K, R, t):
122
- height = im.height
123
- focal_length = K[0,0]
124
- fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
125
-
126
- x_axis, y_axis, z_axis = R
127
-
128
- eye = -(R.T @ t).squeeze()
129
- z_axis = z_axis.squeeze()
130
- target = eye + z_axis
131
- up = -y_axis
132
-
133
- return eye, target, up, fov
134
-
135
-
136
- ########## general utilities ##########
137
- import contextlib
138
- import tempfile
139
- from pathlib import Path
140
-
141
- @contextlib.contextmanager
142
- def working_directory(path):
143
- """Changes working directory and returns to previous on exit."""
144
- prev_cwd = Path.cwd()
145
- os.chdir(path)
146
- try:
147
- yield
148
- finally:
149
- os.chdir(prev_cwd)
150
-
151
- @contextlib.contextmanager
152
- def temp_working_directory():
153
- with tempfile.TemporaryDirectory(dir='.') as D:
154
- with working_directory(D):
155
- yield
156
-
157
-
158
- ############# Dataset #############
159
- def proc(row, split='train'):
160
- # column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
161
- # column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
162
- # cols = column_names_train if split == 'train' else column_names_test
163
- out = {}
164
- for k, v in row.items():
165
- colname = k.split('.')[0]
166
- if colname in {'ade20k', 'depthcm', 'gestalt'}:
167
- if colname in out:
168
- out[colname].append(v)
169
- else:
170
- out[colname] = [v]
171
- elif colname in {'wireframe', 'mesh'}:
172
- # out.update({a: b.tolist() for a,b in v.items()})
173
- out.update({a: b for a,b in v.items()})
174
- elif colname in 'kr':
175
- out[colname.upper()] = v
176
- else:
177
- out[colname] = v
178
-
179
- return Sample(out)
180
-
181
-
182
- class Sample(Dict):
183
- def __repr__(self):
184
- return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
185
-
186
-
187
-
188
- def get_params():
189
- exmaple_param_dict = {
190
- "competition_id": "usm3d/S23DR",
191
- "competition_type": "script",
192
- "metric": "custom",
193
- "token": "hf_**********************************",
194
- "team_id": "local-test-team_id",
195
- "submission_id": "local-test-submission_id",
196
- "submission_id_col": "__key__",
197
- "submission_cols": [
198
- "__key__",
199
- "wf_edges",
200
- "wf_vertices",
201
- "edge_semantics"
202
- ],
203
- "submission_rows": 180,
204
- "output_path": ".",
205
- "submission_repo": "<THE HF MODEL ID of THIS REPO",
206
- "time_limit": 7200,
207
- "dataset": "usm3d/usm-test-data-x",
208
- "submission_filenames": [
209
- "submission.parquet"
210
- ]
211
- }
212
-
213
- param_path = Path('params.json')
214
-
215
- if not param_path.exists():
216
- print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
217
- params = exmaple_param_dict
218
- else:
219
- print('found params.json (this means we are probably in the test env). Using params from file.')
220
- with param_path.open() as f:
221
- params = json.load(f)
222
- print(params)
223
- return params
224
-
225
-
226
-
227
- import webdataset as wds
228
- import numpy as np
229
-
230
- def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
231
- if LOCAL_DATADIR is None:
232
- raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
233
-
234
- local_dir = Path(LOCAL_DATADIR)
235
- if split != 'all':
236
- local_dir = local_dir / split
237
-
238
- paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
239
-
240
- dataset = wds.WebDataset(paths)
241
- if decode is not None:
242
- dataset = dataset.decode(decode)
243
- else:
244
- dataset = dataset.decode()
245
-
246
- dataset = dataset.map(proc)
247
-
248
- if dataset_type == 'webdataset':
249
- return dataset
250
-
251
- if dataset_type == 'hf':
252
- import datasets
253
- from datasets import Features, Value, Sequence, Image, Array2D
254
-
255
- if split == 'train':
256
- return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
257
- elif split == 'val':
258
- return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
259
-
260
-
261
-