JeffreyXiang commited on
Commit
690b53e
·
1 Parent(s): bd46f72
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
3
  from gradio_litmodel3d import LitModel3D
4
 
5
  import os
 
6
  from typing import *
7
  import torch
8
  import numpy as np
@@ -131,7 +132,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
131
  str: The path to the extracted GLB file.
132
  """
133
  gs, mesh, model_id = unpack_state(state)
134
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size)
135
  glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
136
  glb.export(glb_path)
137
  return glb_path, glb_path
@@ -161,12 +162,12 @@ with gr.Blocks() as demo:
161
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
162
  gr.Markdown("Stage 1: Sparse Structure Generation")
163
  with gr.Row():
164
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=5.0, step=0.1)
165
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
166
  gr.Markdown("Stage 2: Structured Latent Generation")
167
  with gr.Row():
168
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=5.0, step=0.1)
169
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
170
 
171
  generate_btn = gr.Button("Generate")
172
 
 
3
  from gradio_litmodel3d import LitModel3D
4
 
5
  import os
6
+ os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
9
  import numpy as np
 
132
  str: The path to the extracted GLB file.
133
  """
134
  gs, mesh, model_id = unpack_state(state)
135
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
136
  glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
137
  glb.export(glb_path)
138
  return glb_path, glb_path
 
162
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
163
  gr.Markdown("Stage 1: Sparse Structure Generation")
164
  with gr.Row():
165
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
166
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
167
  gr.Markdown("Stage 2: Structured Latent Generation")
168
  with gr.Row():
169
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
170
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
171
 
172
  generate_btn = gr.Button("Generate")
173
 
trellis/modules/sparse/__init__.py CHANGED
@@ -24,6 +24,8 @@ def __from_env():
24
  if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
  ATTN = env_sparse_attn
26
 
 
 
27
 
28
  __from_env()
29
 
 
24
  if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
  ATTN = env_sparse_attn
26
 
27
+ print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
+
29
 
30
  __from_env()
31
 
trellis/modules/sparse/conv/__init__.py CHANGED
@@ -1,6 +1,21 @@
1
  from .. import BACKEND
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  if BACKEND == 'torchsparse':
4
  from .conv_torchsparse import *
5
  elif BACKEND == 'spconv':
6
- from .conv_spconv import *
 
1
  from .. import BACKEND
2
 
3
+
4
+ SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
5
+
6
+ def __from_env():
7
+ import os
8
+
9
+ global SPCONV_ALGO
10
+ env_spconv_algo = os.environ.get('SPCONV_ALGO')
11
+ if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
12
+ SPCONV_ALGO = env_spconv_algo
13
+ print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
14
+
15
+
16
+ __from_env()
17
+
18
  if BACKEND == 'torchsparse':
19
  from .conv_torchsparse import *
20
  elif BACKEND == 'spconv':
21
+ from .conv_spconv import *
trellis/modules/sparse/conv/conv_spconv.py CHANGED
@@ -2,16 +2,22 @@ import torch
2
  import torch.nn as nn
3
  from .. import SparseTensor
4
  from .. import DEBUG
 
5
 
6
  class SparseConv3d(nn.Module):
7
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
8
  super(SparseConv3d, self).__init__()
9
  if 'spconv' not in globals():
10
  import spconv.pytorch as spconv
 
 
 
 
 
11
  if stride == 1 and (padding is None):
12
- self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key)
13
  else:
14
- self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key)
15
  self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
16
  self.padding = padding
17
 
 
2
  import torch.nn as nn
3
  from .. import SparseTensor
4
  from .. import DEBUG
5
+ from . import SPCONV_ALGO
6
 
7
  class SparseConv3d(nn.Module):
8
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
9
  super(SparseConv3d, self).__init__()
10
  if 'spconv' not in globals():
11
  import spconv.pytorch as spconv
12
+ algo = None
13
+ if SPCONV_ALGO == 'native':
14
+ algo = spconv.ConvAlgo.Native
15
+ elif SPCONV_ALGO == 'implicit_gemm':
16
+ algo = spconv.ConvAlgo.MaskImplicitGemm
17
  if stride == 1 and (padding is None):
18
+ self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
19
  else:
20
+ self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
21
  self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
22
  self.padding = padding
23
 
trellis/utils/postprocessing_utils.py CHANGED
@@ -448,7 +448,7 @@ def to_glb(
448
  observations, masks, extrinsics, intrinsics,
449
  texture_size=texture_size, mode='opt',
450
  lambda_tv=0.01,
451
- verbose=True
452
  )
453
  texture = Image.fromarray(texture)
454
 
 
448
  observations, masks, extrinsics, intrinsics,
449
  texture_size=texture_size, mode='opt',
450
  lambda_tv=0.01,
451
+ verbose=verbose
452
  )
453
  texture = Image.fromarray(texture)
454