32 lines
867 B
Python
32 lines
867 B
Python
from contextlib import contextmanager
|
|
from collections import defaultdict
|
|
import torch
|
|
|
|
|
|
class CUDATimer:
    """Accumulate CUDA-event timings for named code regions.

    Wrap a region with ``with timer.record("name"):`` to time it, then call
    ``log()`` to get a one-line summary of total elapsed milliseconds and the
    number of recorded calls for each name.
    """

    def __init__(self):
        # name -> list of (start, end) torch.cuda.Event pairs, one per recorded call.
        self.events = defaultdict(list)

    @contextmanager
    def record(self, name, enabled=True):
        """Time the enclosed block under *name* using a pair of CUDA events.

        Args:
            name: key under which this measurement is accumulated.
            enabled: when false, the block runs untimed and nothing is stored.
        """
        if not enabled:
            yield
            return
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        self.events[name].append((start, end))
        start.record()
        try:
            yield
        finally:
            # Record the end event even if the block raises. Otherwise the
            # stored pair has an unrecorded end event and log() fails on it.
            end.record()

    def log(self):
        """Return a summary of the accumulated timings.

        Calls ``torch.cuda.synchronize()`` first so the recorded events can be
        read. Then returns one ``"<name> <total>ms/<count>times"`` entry per
        name, in insertion order, joined by ``", "``.
        """
        torch.cuda.synchronize()
        return ", ".join(
            self._format_entry(name, pairs) for name, pairs in self.events.items()
        )

    @staticmethod
    def _format_entry(name, pairs):
        """Format the total elapsed ms and call count for one name's event pairs."""
        total = sum(start.elapsed_time(end) for start, end in pairs)
        # Count the calls recorded for *this* name, not the number of names.
        return f"{name} {total:.2f}ms/{len(pairs)}times"
|