lixinhao commited on
Commit
af77f1b
·
verified ·
1 Parent(s): be4ed9b

Update vision_tower_builder.py

Browse files
Files changed (1) hide show
  1. vision_tower_builder.py +3 -6
vision_tower_builder.py CHANGED
@@ -2,9 +2,6 @@ from typing import Optional, Tuple, Union, Dict
2
  from dataclasses import dataclass
3
  from functools import partial, reduce
4
  from PIL import Image
5
- import torch
6
- import torch.utils.checkpoint
7
- from torch import nn
8
  import os
9
  from transformers.image_processing_utils import BatchFeature, get_size_dict
10
  from transformers.image_transforms import (
@@ -29,7 +26,7 @@ try:
29
  from flash_attn import flash_attn_qkvpacked_func
30
  except:
31
  print("You need to install flash_attn")
32
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
33
 
34
 
35
 
@@ -516,7 +513,7 @@ def build_vit(config, pt_type='origin'):
516
  drop_path_rate=0.,
517
  num_frames=config.num_frames,
518
  tubelet_size=1,
519
- use_checkpoint=True,
520
  checkpoint_num=24,
521
  return_index=config.return_idx,
522
  with_ln=True, # merge vision_layernorm in it
@@ -619,4 +616,4 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
619
  raise NotImplementedError
620
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
621
 
622
- raise ValueError(f"Unknown vision tower: {vision_tower}")
 
2
  from dataclasses import dataclass
3
  from functools import partial, reduce
4
  from PIL import Image
 
 
 
5
  import os
6
  from transformers.image_processing_utils import BatchFeature, get_size_dict
7
  from transformers.image_transforms import (
 
26
  from flash_attn import flash_attn_qkvpacked_func
27
  except:
28
  print("You need to install flash_attn")
29
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
30
 
31
 
32
 
 
513
  drop_path_rate=0.,
514
  num_frames=config.num_frames,
515
  tubelet_size=1,
516
+ use_checkpoint=False,
517
  checkpoint_num=24,
518
  return_index=config.return_idx,
519
  with_ln=True, # merge vision_layernorm in it
 
616
  raise NotImplementedError
617
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
618
 
619
+ raise ValueError(f"Unknown vision tower: {vision_tower}")