jacklangerman commited on
Commit
8813b45
·
1 Parent(s): 3910cfe

first commit

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. hoho/__init__.py +2 -0
  3. hoho/hoho.py +260 -0
  4. hoho/vis.py +2 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .DS_Store
2
+ __pycache__
3
+ hoho.egg-info/
hoho/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .hoho import *
2
+ import vis
hoho/hoho.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Dict
6
+
7
+ from PIL import ImageFile
8
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
9
+
10
+ LOCAL_DATADIR = None
11
+
12
+ def setup(local_dir='./data/usm-training-data/data'):
13
+
14
+ # If we are in the test environment, we need to link the data directory to the correct location
15
+ tmp_datadir = Path('/tmp/data/data')
16
+ local_test_datadir = Path('./data/usm-test-data-x/data')
17
+ local_val_datadir = Path(local_dir)
18
+
19
+ os.system('pwd')
20
+ os.system('ls -lahtr .')
21
+
22
+ if tmp_datadir.exists() and not local_test_datadir.exists():
23
+ global LOCAL_DATADIR
24
+ LOCAL_DATADIR = local_test_datadir
25
+ # shutil.move(datadir, './usm-test-data-x/data')
26
+ print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
27
+ LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
28
+ LOCAL_DATADIR.symlink_to(tmp_datadir)
29
+ else:
30
+ LOCAL_DATADIR = local_val_datadir
31
+ print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
32
+
33
+ # os.system("ls -lahtr")
34
+
35
+ assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
36
+ return LOCAL_DATADIR
37
+
38
+
39
+
40
+
41
+ import importlib
42
+ from pathlib import Path
43
+ import subprocess
44
+
45
+ def download_package(package_name, path_to_save='packages'):
46
+ """
47
+ Downloads a package using pip and saves it to a specified directory.
48
+
49
+ Parameters:
50
+ package_name (str): The name of the package to download.
51
+ path_to_save (str): The path to the directory where the package will be saved.
52
+ """
53
+ try:
54
+ # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
55
+ subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
56
+ "-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
57
+ "--platform", "manylinux1_x86_64", # Specify the platform
58
+ "--python-version", "38", # Specify the Python version
59
+ "--only-binary=:all:"]) # Download only binary packages
60
+ print(f'Package "{package_name}" downloaded successfully')
61
+ except subprocess.CalledProcessError as e:
62
+ print(f'Failed to downloaded package "{package_name}". Error: {e}')
63
+
64
+
65
+ def install_package_from_local_file(package_name, folder='packages'):
66
+ """
67
+ Installs a package from a local .whl file or a directory containing .whl files using pip.
68
+
69
+ Parameters:
70
+ path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
71
+ """
72
+ try:
73
+ pth = str(Path(folder) / package_name)
74
+ subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
75
+ "--no-index", # Do not use package index
76
+ "--find-links", pth, # Look for packages in the specified directory or at the file
77
+ package_name]) # Specify the package to install
78
+ print(f"Package installed successfully from {pth}")
79
+ except subprocess.CalledProcessError as e:
80
+ print(f"Failed to install package from {pth}. Error: {e}")
81
+
82
+
83
+ def importt(module_name, as_name=None):
84
+ """
85
+ Imports a module and returns it.
86
+
87
+ Parameters:
88
+ module_name (str): The name of the module to import.
89
+ as_name (str): The name to use for the imported module. If None, the original module name will be used.
90
+
91
+ Returns:
92
+ The imported module.
93
+ """
94
+ for _ in range(2):
95
+ try:
96
+ if as_name is None:
97
+ print(f'imported {module_name}')
98
+ return importlib.import_module(module_name)
99
+ else:
100
+ print(f'imported {module_name} as {as_name}')
101
+ return importlib.import_module(module_name, as_name)
102
+ except ModuleNotFoundError as e:
103
+ install_package_from_local_file(module_name)
104
+ print(f"Failed to import module {module_name}. Error: {e}")
105
+
106
+
107
+ def prepare_submission():
108
+ # Download packages from requirements.txt
109
+ if Path('requirements.txt').exists():
110
+ print('downloading packages from requirements.txt')
111
+ Path('packages').mkdir(exist_ok=True)
112
+ with open('requirements.txt') as f:
113
+ packages = f.readlines()
114
+ for p in packages:
115
+ download_package(p.strip())
116
+
117
+ print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
118
+
119
+
120
+ def Rt_to_eye_target(im, K, R, t):
121
+ height = im.height
122
+ focal_length = K[0,0]
123
+ fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
124
+
125
+ x_axis, y_axis, z_axis = R
126
+
127
+ eye = -(R.T @ t).squeeze()
128
+ z_axis = z_axis.squeeze()
129
+ target = eye + z_axis
130
+ up = -y_axis
131
+
132
+ return eye, target, up, fov
133
+
134
+
135
+ ########## general utilities ##########
136
+ import contextlib
137
+ import tempfile
138
+ from pathlib import Path
139
+
140
+ @contextlib.contextmanager
141
+ def working_directory(path):
142
+ """Changes working directory and returns to previous on exit."""
143
+ prev_cwd = Path.cwd()
144
+ os.chdir(path)
145
+ try:
146
+ yield
147
+ finally:
148
+ os.chdir(prev_cwd)
149
+
150
+ @contextlib.contextmanager
151
+ def temp_working_directory():
152
+ with tempfile.TemporaryDirectory(dir='.') as D:
153
+ with working_directory(D):
154
+ yield
155
+
156
+
157
+ ############# Dataset #############
158
+ def proc(row, split='train'):
159
+ # column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
160
+ # column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
161
+ # cols = column_names_train if split == 'train' else column_names_test
162
+ out = {}
163
+ for k, v in row.items():
164
+ colname = k.split('.')[0]
165
+ if colname in {'ade20k', 'depthcm', 'gestalt'}:
166
+ if colname in out:
167
+ out[colname].append(v)
168
+ else:
169
+ out[colname] = [v]
170
+ elif colname in {'wireframe', 'mesh'}:
171
+ # out.update({a: b.tolist() for a,b in v.items()})
172
+ out.update({a: b for a,b in v.items()})
173
+ elif colname in 'kr':
174
+ out[colname.upper()] = v
175
+ else:
176
+ out[colname] = v
177
+
178
+ return Sample(out)
179
+
180
+
181
+ class Sample(Dict):
182
+ def __repr__(self):
183
+ return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
184
+
185
+
186
+
187
+ def get_params():
188
+ exmaple_param_dict = {
189
+ "competition_id": "usm3d/S23DR",
190
+ "competition_type": "script",
191
+ "metric": "custom",
192
+ "token": "hf_**********************************",
193
+ "team_id": "local-test-team_id",
194
+ "submission_id": "local-test-submission_id",
195
+ "submission_id_col": "__key__",
196
+ "submission_cols": [
197
+ "__key__",
198
+ "wf_edges",
199
+ "wf_vertices",
200
+ "edge_semantics"
201
+ ],
202
+ "submission_rows": 180,
203
+ "output_path": ".",
204
+ "submission_repo": "<THE HF MODEL ID of THIS REPO",
205
+ "time_limit": 7200,
206
+ "dataset": "usm3d/usm-test-data-x",
207
+ "submission_filenames": [
208
+ "submission.parquet"
209
+ ]
210
+ }
211
+
212
+ param_path = Path('params.json')
213
+
214
+ if not param_path.exists():
215
+ print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
216
+ params = exmaple_param_dict
217
+ else:
218
+ print('found params.json (this means we are probably in the test env). Using params from file.')
219
+ with param_path.open() as f:
220
+ params = json.load(f)
221
+ print(params)
222
+ return params
223
+
224
+
225
+
226
+ import webdataset as wds
227
+ import numpy as np
228
+
229
+ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
230
+ if LOCAL_DATADIR is None:
231
+ raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
232
+
233
+ local_dir = Path(LOCAL_DATADIR)
234
+ if split != 'all':
235
+ local_dir = local_dir / split
236
+
237
+ paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
238
+
239
+ dataset = wds.WebDataset(paths)
240
+ if decode is not None:
241
+ dataset = dataset.decode(decode)
242
+ else:
243
+ dataset = dataset.decode()
244
+
245
+ dataset = dataset.map(proc)
246
+
247
+ if dataset_type == 'webdataset':
248
+ return dataset
249
+
250
+ if dataset_type == 'hf':
251
+ import datasets
252
+ from datasets import Features, Value, Sequence, Image, Array2D
253
+
254
+ if split == 'train':
255
+ return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
256
+ elif split == 'val':
257
+ return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
258
+
259
+
260
+
hoho/vis.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ import trimesh
2
+