|
import argparse
import inspect

import torch
from fvcore.nn import FlopCountAnalysis
from ultralytics import YOLO  # noqa: F401 -- not referenced directly; presumably kept so ultralytics classes are importable when unpickling
|
|
|
|
|
def build_parser():
    """Return the command-line parser for this MAC-counting script."""
    parser = argparse.ArgumentParser(
        description='Report the compute cost (GMACs) of one forward pass of a checkpointed model.',
    )
    parser.add_argument('model', type=str, help='Path to the .pt checkpoint to analyse.')
    parser.add_argument(
        '--imgsz',
        default=928,
        type=int,
        help='Side length in pixels of the square dummy input image.',
    )
    return parser


def load_model(path):
    """Load the checkpoint at *path* and return its network as float32 in eval mode.

    *path* may hold either a bare ``torch.nn.Module`` or a checkpoint dict.
    For a dict, ``'model'`` is used when present and not ``None``; otherwise
    ``'ema'`` is used (some training checkpoints presumably store the network
    only under ``'ema'`` -- TODO confirm for the ultralytics version in use).

    Raises:
        KeyError: if a checkpoint dict has no network under ``'model'`` or ``'ema'``.
    """
    load_kwargs = {'map_location': 'cpu'}
    # torch >= 2.6 defaults to weights_only=True, which cannot unpickle a full
    # nn.Module; torch < 1.13 does not accept the keyword at all.
    # NOTE(security): weights_only=False executes pickle code -- only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    ckpt = torch.load(path, **load_kwargs)

    if isinstance(ckpt, torch.nn.Module):
        model = ckpt
    else:
        model = ckpt.get('model')
        if model is None:
            model = ckpt.get('ema')
        if model is None:
            raise KeyError(f"checkpoint {path!r} has no 'model' or 'ema' entry")

    # Cast to float32 (stored weights may be half precision) and switch to
    # eval mode before tracing.
    return model.float().eval()


def count_gmacs(model, imgsz):
    """Return the fvcore operation count of one forward pass, in billions.

    The model is traced on a random ``(1, 3, imgsz, imgsz)`` input. fvcore
    counts one fused multiply-add as a single "flop", so the total is
    reported as GMACs.
    """
    # No autograd graph is needed just to count operations.
    with torch.no_grad():
        fca = FlopCountAnalysis(
            model=model,
            inputs=torch.rand(1, 3, imgsz, imgsz),
        )
        fca.unsupported_ops_warnings(False)
        fca.uncalled_modules_warnings(False)
        return fca.total() * 1e-9


def main(argv=None):
    """Parse *argv* (``None`` means ``sys.argv[1:]``) and print the model's GMACs."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.imgsz <= 0:
        parser.error(f'--imgsz must be a positive integer, got {args.imgsz}')

    model = load_model(args.model)
    print(f"{count_gmacs(model, args.imgsz):.2f} GMACS ")


if __name__ == '__main__':
    main()
|
|