onuralpszr commited on
Commit
5645efe
1 Parent(s): d552355

feat: ✨ For segmentation methods are added

Browse files

Signed-off-by: Onuralp SEZER <[email protected]>

.gitignore ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
app.py CHANGED
@@ -6,7 +6,8 @@ import numpy as np
6
  from PIL import Image
7
  import gradio as gr
8
  import spaces
9
- from helpers.utils import create_directory, delete_directory, generate_unique_name
 
10
  import os
11
 
12
  BOX_ANNOTATOR = sv.BoxAnnotator()
@@ -14,10 +15,12 @@ LABEL_ANNOTATOR = sv.LabelAnnotator()
14
  MASK_ANNOTATOR = sv.MaskAnnotator()
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  VIDEO_TARGET_DIRECTORY = "tmp"
 
17
 
 
18
 
19
  INTRO_TEXT = """
20
- ## PaliGemma 2 Detection with Supervision - Demo
21
 
22
  <div style="display: flex; gap: 10px;">
23
  <a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
@@ -60,6 +63,14 @@ def parse_class_names(prompt):
60
  classes_text = prompt[7:].strip()
61
  return [cls.strip() for cls in classes_text.split(';') if cls.strip()]
62
 
 
 
 
 
 
 
 
 
63
  @spaces.GPU
64
  def paligemma_detection(input_image, input_text, max_new_tokens):
65
  model_inputs = processor(text=input_text,
@@ -110,10 +121,58 @@ def annotate_image(result, resolution_wh, prompt, cv_image):
110
 
111
  def process_image(input_image, input_text, max_new_tokens):
112
  cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
113
- result = paligemma_detection(input_image, input_text, max_new_tokens)
114
- annotated_image = annotate_image(result,
115
- (input_image.width, input_image.height),
116
- input_text, cv_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  return annotated_image, result
118
 
119
 
@@ -188,13 +247,13 @@ def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(
188
  with gr.Blocks() as app:
189
  gr.Markdown(INTRO_TEXT)
190
 
191
- with gr.Tab("Image Detection"):
192
  with gr.Row():
193
  with gr.Column():
194
  input_image = gr.Image(type="pil", label="Input Image")
195
  input_text = gr.Textbox(
196
  lines=2,
197
- placeholder="Enter prompt in format like this: detect person;dog;building",
198
  label="Enter detection prompt"
199
  )
200
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
@@ -213,7 +272,7 @@ with gr.Blocks() as app:
213
  input_video = gr.Video(label="Input Video")
214
  input_text = gr.Textbox(
215
  lines=2,
216
- placeholder="Enter prompt in format like this: detect person;dog;building",
217
  label="Enter detection prompt"
218
  )
219
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
 
6
  from PIL import Image
7
  import gradio as gr
8
  import spaces
9
+ from helpers.file_utils import create_directory, delete_directory, generate_unique_name
10
+ from helpers.segment_utils import parse_segmentation, extract_objs
11
  import os
12
 
13
  BOX_ANNOTATOR = sv.BoxAnnotator()
 
15
  MASK_ANNOTATOR = sv.MaskAnnotator()
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  VIDEO_TARGET_DIRECTORY = "tmp"
18
+ VAE_MODEL = "vae-oid.npz"
19
 
20
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
21
 
22
  INTRO_TEXT = """
23
+ ## PaliGemma 2 Detection/Segmentation with Supervision - Demo
24
 
25
  <div style="display: flex; gap: 10px;">
26
  <a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
 
63
  classes_text = prompt[7:].strip()
64
  return [cls.strip() for cls in classes_text.split(';') if cls.strip()]
65
 
66
+ def parse_prompt_type(prompt):
67
+ """Determine if the prompt is for detection or segmentation."""
68
+ if prompt.lower().startswith('detect '):
69
+ return 'detection', prompt[7:].strip()
70
+ elif prompt.lower().startswith('segment '):
71
+ return 'segmentation', prompt[8:].strip()
72
+ return None, prompt
73
+
74
  @spaces.GPU
75
  def paligemma_detection(input_image, input_text, max_new_tokens):
76
  model_inputs = processor(text=input_text,
 
121
 
122
  def process_image(input_image, input_text, max_new_tokens):
123
  cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
124
+ prompt_type, cleaned_prompt = parse_prompt_type(input_text)
125
+
126
+ if prompt_type == 'detection':
127
+ # Existing detection logic
128
+ result = paligemma_detection(input_image, input_text, max_new_tokens)
129
+ class_names = [cls.strip() for cls in cleaned_prompt.split(';') if cls.strip()]
130
+
131
+ detections = sv.Detections.from_lmm(
132
+ sv.LMM.PALIGEMMA,
133
+ result,
134
+ resolution_wh=(input_image.width, input_image.height),
135
+ classes=class_names
136
+ )
137
+
138
+ annotated_image = BOX_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
139
+ annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
140
+ annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
141
+
142
+ elif prompt_type == 'segmentation':
143
+ # Use parse_segmentation for segmentation tasks
144
+ result = paligemma_detection(input_image, input_text, max_new_tokens)
145
+ input_image, annotations = parse_segmentation(input_image, result)
146
+
147
+ # Create annotated image
148
+ annotated_image = cv_image.copy()
149
+ for mask, label in annotations:
150
+ if isinstance(mask, np.ndarray): # If it's a segmentation mask
151
+ # Create colored mask
152
+ color_idx = hash(label) % len(COLORS)
153
+ color = tuple(int(COLORS[color_idx].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
154
+ colored_mask = np.zeros_like(cv_image)
155
+ colored_mask[mask > 0] = color
156
+
157
+ # Blend mask with image
158
+ alpha = 0.5
159
+ annotated_image = cv2.addWeighted(annotated_image, 1, colored_mask, alpha, 0)
160
+
161
+ # Add label where mask starts
162
+ y_coords, x_coords = np.where(mask > 0)
163
+ if len(y_coords) > 0 and len(x_coords) > 0:
164
+ label_y = y_coords.min()
165
+ label_x = x_coords.min()
166
+ cv2.putText(annotated_image, label, (label_x, label_y-10),
167
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
168
+ else:
169
+ gr.Warning("Invalid prompt format. Please use 'detect' or 'segment' followed by class names")
170
+ return input_image, "Invalid prompt format"
171
+
172
+ # Convert back to RGB for display
173
+ annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
174
+ annotated_image = Image.fromarray(annotated_image)
175
+
176
  return annotated_image, result
177
 
178
 
 
247
  with gr.Blocks() as app:
248
  gr.Markdown(INTRO_TEXT)
249
 
250
+ with gr.Tab("Image Detection/Segmentation"):
251
  with gr.Row():
252
  with gr.Column():
253
  input_image = gr.Image(type="pil", label="Input Image")
254
  input_text = gr.Textbox(
255
  lines=2,
256
+ placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
257
  label="Enter detection prompt"
258
  )
259
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
 
272
  input_video = gr.Video(label="Input Video")
273
  input_text = gr.Textbox(
274
  lines=2,
275
+ placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
276
  label="Enter detection prompt"
277
  )
278
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
helpers/{utils.py → file_utils.py} RENAMED
File without changes
helpers/segment_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax.linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import re
5
+ import numpy as np
6
+ import functools
7
+ from PIL import Image
8
+
9
+ ### Postprocessing Utils for Segmentation Tokens
10
+ ### Segmentation tokens are passed to another VAE which decodes them to a mask
11
+
12
+ _MODEL_PATH = 'vae-oid.npz'
13
+
14
+ _SEGMENT_DETECT_RE = re.compile(
15
+ r'(.*?)' +
16
+ r'<loc(\d{4})>' * 4 + r'\s*' +
17
+ '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
18
+ r'\s*([^;<>]+)? ?(?:; )?',
19
+ )
20
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
21
+
22
+
23
+ def parse_segmentation(input_image,inference_output):
24
+ objs = extract_objs(inference_output.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
25
+ labels = set(obj.get('name') for obj in objs if obj.get('name'))
26
+ color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
27
+ highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
28
+ annotated_img = (
29
+ input_image,
30
+ [
31
+ (
32
+ obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
33
+ obj['name'] or '',
34
+ )
35
+ for obj in objs
36
+ if 'mask' in obj or 'xyxy' in obj
37
+ ],
38
+ )
39
+ has_annotations = bool(annotated_img[1])
40
+ return annotated_img
41
+
42
+
43
+ def _get_params(checkpoint):
44
+ """Converts PyTorch checkpoint to Flax params."""
45
+
46
+ def transp(kernel):
47
+ return np.transpose(kernel, (2, 3, 1, 0))
48
+
49
+ def conv(name):
50
+ return {
51
+ 'bias': checkpoint[name + '.bias'],
52
+ 'kernel': transp(checkpoint[name + '.weight']),
53
+ }
54
+
55
+ def resblock(name):
56
+ return {
57
+ 'Conv_0': conv(name + '.0'),
58
+ 'Conv_1': conv(name + '.2'),
59
+ 'Conv_2': conv(name + '.4'),
60
+ }
61
+
62
+ return {
63
+ '_embeddings': checkpoint['_vq_vae._embedding'],
64
+ 'Conv_0': conv('decoder.0'),
65
+ 'ResBlock_0': resblock('decoder.2.net'),
66
+ 'ResBlock_1': resblock('decoder.3.net'),
67
+ 'ConvTranspose_0': conv('decoder.4'),
68
+ 'ConvTranspose_1': conv('decoder.6'),
69
+ 'ConvTranspose_2': conv('decoder.8'),
70
+ 'ConvTranspose_3': conv('decoder.10'),
71
+ 'Conv_1': conv('decoder.12'),
72
+ }
73
+
74
+
75
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
76
+ batch_size, num_tokens = codebook_indices.shape
77
+ assert num_tokens == 16, codebook_indices.shape
78
+ unused_num_embeddings, embedding_dim = embeddings.shape
79
+
80
+ encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
81
+ encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
82
+ return encodings
83
+
84
+
85
+ @functools.cache
86
+ def _get_reconstruct_masks():
87
+ """Reconstructs masks from codebook indices.
88
+ Returns:
89
+ A function that expects indices shaped `[B, 16]` of dtype int32, each
90
+ ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
91
+ `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
92
+ """
93
+
94
+ class ResBlock(nn.Module):
95
+ features: int
96
+
97
+ @nn.compact
98
+ def __call__(self, x):
99
+ original_x = x
100
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
101
+ x = nn.relu(x)
102
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
103
+ x = nn.relu(x)
104
+ x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
105
+ return x + original_x
106
+
107
+ class Decoder(nn.Module):
108
+ """Upscales quantized vectors to mask."""
109
+
110
+ @nn.compact
111
+ def __call__(self, x):
112
+ num_res_blocks = 2
113
+ dim = 128
114
+ num_upsample_layers = 4
115
+
116
+ x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
117
+ x = nn.relu(x)
118
+
119
+ for _ in range(num_res_blocks):
120
+ x = ResBlock(features=dim)(x)
121
+
122
+ for _ in range(num_upsample_layers):
123
+ x = nn.ConvTranspose(
124
+ features=dim,
125
+ kernel_size=(4, 4),
126
+ strides=(2, 2),
127
+ padding=2,
128
+ transpose_kernel=True,
129
+ )(x)
130
+ x = nn.relu(x)
131
+ dim //= 2
132
+
133
+ x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
134
+
135
+ return x
136
+
137
+ def reconstruct_masks(codebook_indices):
138
+ quantized = _quantized_values_from_codebook_indices(
139
+ codebook_indices, params['_embeddings']
140
+ )
141
+ return Decoder().apply({'params': params}, quantized)
142
+
143
+ with open(_MODEL_PATH, 'rb') as f:
144
+ params = _get_params(dict(np.load(f)))
145
+
146
+ return jax.jit(reconstruct_masks, backend='cpu')
147
+ def extract_objs(text, width, height, unique_labels=False):
148
+ """Returns objs for a string with "<loc>" and "<seg>" tokens."""
149
+ objs = []
150
+ seen = set()
151
+ while text:
152
+ m = _SEGMENT_DETECT_RE.match(text)
153
+ if not m:
154
+ break
155
+ print("m", m)
156
+ gs = list(m.groups())
157
+ before = gs.pop(0)
158
+ name = gs.pop()
159
+ y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
160
+
161
+ y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
162
+ seg_indices = gs[4:20]
163
+ if seg_indices[0] is None:
164
+ mask = None
165
+ else:
166
+ seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
167
+ m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
168
+ m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
169
+ m64 = Image.fromarray((m64 * 255).astype('uint8'))
170
+ mask = np.zeros([height, width])
171
+ if y2 > y1 and x2 > x1:
172
+ mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
173
+
174
+ content = m.group()
175
+ if before:
176
+ objs.append(dict(content=before))
177
+ content = content[len(before):]
178
+ while unique_labels and name in seen:
179
+ name = (name or '') + "'"
180
+ seen.add(name)
181
+ objs.append(dict(
182
+ content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
183
+ text = text[len(before) + len(content):]
184
+
185
+ if text:
186
+ objs.append(dict(content=text))
187
+
188
+ return objs
189
+
190
+ #########
requirements.txt CHANGED
@@ -3,4 +3,6 @@ transformers==4.47.0
3
  requests
4
  tqdm
5
  spaces
6
- torch
 
 
 
3
  requests
4
  tqdm
5
  spaces
6
+ torch
7
+ jax
8
+ flax