File size: 1,719 Bytes
6eb1d7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from dataclasses import dataclass
from typing import Union
import torch
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Predictor output that contains embedding and coarse segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    Here D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
         K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS

    Both tensors are indexed along their leading (instance) dimension N.
    Indexing an instance of this class always returns an output that keeps
    that leading dimension, even when a single instance is selected.
    """

    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items. A scalar integer
                index (Python int, NumPy integer, or 0-dim integer tensor)
                selects one instance; the result then has a leading
                dimension of size 1.
        Returns:
            A new DensePoseEmbeddingPredictorOutput holding the selected
            instances of both `coarse_segm` and `embedding`.
        """
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=self._select(self.coarse_segm, item),
            embedding=self._select(self.embedding, item),
        )

    @staticmethod
    def _select(
        tensor: torch.Tensor, item: Union[int, slice, torch.BoolTensor]
    ) -> torch.Tensor:
        """
        Index `tensor` with `item`, restoring the leading dimension if the
        index removed it.
        """
        selected = tensor[item]
        # Scalar integer indices drop the indexed dimension. Re-insert it so a
        # single-instance selection yields [1, ...] rather than [...].
        # Comparing ranks, rather than testing `isinstance(item, int)`, also
        # covers NumPy integers and 0-dim integer tensors, which torch treats
        # as scalar indices too.
        if selected.dim() < tensor.dim():
            selected = selected.unsqueeze(0)
        return selected

    def to(self, device: torch.device) -> "DensePoseEmbeddingPredictorOutput":
        """
        Transfers all tensors to the given device
        """
        coarse_segm = self.coarse_segm.to(device)
        embedding = self.embedding.to(device)
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=coarse_segm, embedding=embedding
        )
|