Jofthomas HF staff commited on
Commit
02b1331
·
1 Parent(s): 8a3d535

replace get_default_device with PartialState

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. text_to_image.py +5 -3
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- transformers>=4.29.0
2
  diffusers
3
  accelerate
4
  torch
 
1
+ transformers>=4.35.2
2
  diffusers
3
  accelerate
4
  torch
text_to_image.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers.tools.base import Tool, get_default_device
2
  from transformers.utils import is_accelerate_available
3
  import torch
4
 
@@ -18,7 +18,9 @@ class TextToImageTool(Tool):
18
  outputs = ['image']
19
 
20
  def __init__(self, device=None, **hub_kwargs) -> None:
21
- if not is_accelerate_available():
 
 
22
  raise ImportError("Accelerate should be installed in order to use tools.")
23
 
24
  super().__init__()
@@ -29,7 +31,7 @@ class TextToImageTool(Tool):
29
 
30
  def setup(self):
31
  if self.device is None:
32
- self.device = get_default_device()
33
 
34
  self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
35
  self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
 
1
+ from transformers.tools.base import Tool
2
  from transformers.utils import is_accelerate_available
3
  import torch
4
 
 
18
  outputs = ['image']
19
 
20
  def __init__(self, device=None, **hub_kwargs) -> None:
21
+ if is_accelerate_available():
22
+ from accelerate import PartialState
23
+ else:
24
  raise ImportError("Accelerate should be installed in order to use tools.")
25
 
26
  super().__init__()
 
31
 
32
  def setup(self):
33
  if self.device is None:
34
+ self.device = PartialState().default_device
35
 
36
  self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
37
  self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)