Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/triton/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def get_dram_gbps(device=None):
import torch

from .runtime import driver
if not device:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device can be 0

device = torch.cuda.current_device()
if device is None:
device = driver.active.get_device_interface().current_device()
mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
Expand Down
Loading