diff --git a/backend/memory_management.py b/backend/memory_management.py
index 5f0c8312d..061f2b433 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -82,6 +82,18 @@ def is_intel_xpu():
             return True
     return False
 
+def is_npu():
+    import importlib.util
+    if importlib.util.find_spec("torch_npu") is None:
+        return False
+    import torch_npu
+
+    try:
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch_npu.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
 
 def get_torch_device():
     global directml_enabled
@@ -96,6 +108,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
 
@@ -105,7 +119,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type in ['cpu', 'mps']):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
@@ -117,6 +131,11 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif is_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_total_torch = mem_reserved
+            mem_total = torch.npu.get_device_properties(dev).total_memory
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -129,7 +148,6 @@ def get_total_memory(dev=None, torch_total_too=False):
     else:
         return mem_total
 
-
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@@ -265,6 +283,8 @@ def get_torch_device_name(device):
             except:
                 allocator_backend = ""
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
+        elif is_npu():
+            return "{} {}".format(device, torch.npu.get_device_name(device))
         else:
             return "{}".format(device.type)
     elif is_intel_xpu():
@@ -915,6 +935,8 @@ def device_supports_non_blocking(device):
         return False
     if directml_enabled:
         return False
+    if is_npu():
+        return False
     return True
 
 
@@ -942,6 +964,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
             device_supports_cast = True
         elif is_intel_xpu():
             device_supports_cast = True
+        elif is_npu():
+            device_supports_cast = True
 
     non_blocking = device_should_use_non_blocking(device)
 
@@ -965,6 +989,8 @@ def xformers_enabled():
         return False
     if directml_enabled:
         return False
+    if is_npu():
+        return False
     return XFORMERS_IS_AVAILABLE
 
 
@@ -989,6 +1015,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_intel_xpu():
             return True
+        if is_npu():
+            return True
     return False
 
 
@@ -1024,6 +1052,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif is_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_npu = torch.npu.get_device_properties(dev).total_memory - mem_reserved
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -1099,6 +1134,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True
 
+    if is_npu():
+        return True
+
     props = torch.cuda.get_device_properties("cuda")
     if props.major >= 8:
         return True
@@ -1153,6 +1191,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_intel_xpu():
         return True
+
+    if is_npu():
+        return True
 
     if device is None:
         device = torch.device("cuda")
@@ -1201,6 +1242,11 @@ def soft_empty_cache(force=False):
         if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
+    elif is_npu():
+        from modules import npu_specific
+        torch.npu.set_device(0)
+        npu_specific.torch_npu_gc()
+
     signal_empty_cache = False
     return
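
Reviewer note: the detection and reporting paths above can be sanity-checked without the rest of the backend. The sketch below mirrors is_npu() and the torch.npu.* calls used in these hunks; its one assumption is torch_npu's usual behavior of registering the torch.npu namespace as an import side effect, and it just prints a fallback message on machines without the package.

    import importlib.util

    import torch


    def probe_npu() -> bool:
        # Same logic as is_npu() above: package missing -> False,
        # package present but no visible device -> RuntimeError -> False.
        if importlib.util.find_spec("torch_npu") is None:
            return False
        import torch_npu  # registers the torch.npu namespace as a side effect

        try:
            torch_npu.npu.device_count()  # raises RuntimeError if no NPU is found
            return torch.npu.is_available()
        except RuntimeError:
            return False


    if __name__ == "__main__":
        if probe_npu():
            dev = torch.device("npu", torch.npu.current_device())
            total = torch.npu.get_device_properties(dev).total_memory
            print("{} : {:0.0f} MB total".format(dev, total / (1024 * 1024)))
        else:
            print("no Ascend NPU detected; CUDA/XPU/CPU behavior is unchanged")
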
diff --git a/backend/stream.py b/backend/stream.py
index f3fcd7bc6..d471872bf 100644
--- a/backend/stream.py
+++ b/backend/stream.py
@@ -8,6 +8,9 @@ def stream_context():
     if torch.xpu.is_available():
         return torch.xpu.stream
 
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        return torch.npu.stream
+
     return None
 
 
@@ -28,6 +31,13 @@ def get_current_stream():
                 torch.zeros((1, 1)).to(device, torch.float32)
             stream.synchronize()
             return stream
+        if hasattr(torch, "npu") and torch.npu.is_available():
+            device = torch.device("npu")
+            stream = torch.npu.current_stream(device)
+            with torch.npu.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
     except:
         return None
 
@@ -48,6 +58,13 @@ def get_new_stream():
                 torch.zeros((1, 1)).to(device, torch.float32)
             stream.synchronize()
             return stream
+        if hasattr(torch, "npu") and torch.npu.is_available():
+            device = torch.device("npu")
+            stream = torch.npu.Stream(device)
+            with torch.npu.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
     except:
         return None
 
diff --git a/modules/devices.py b/modules/devices.py
index f8daafc0e..4ccf2b7a2 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -40,7 +40,8 @@ def torch_gc():
 
 
 def torch_npu_set_device():
-    return
+    if memory_management.is_npu():
+        torch.npu.set_device(0)
 
 
 def enable_tf32():
diff --git a/modules/initialize.py b/modules/initialize.py
index e0605d480..f9bd16c1c 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -133,4 +133,7 @@ def initialize_rest(*, reload_script_modules=False):
     extra_networks.register_default_extra_networks()
     startup_timer.record("initialize extra networks")
 
+    from modules import devices
+    devices.torch_npu_set_device()
+
     return
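
Reviewer note: on NPU builds the three stream helpers keep their existing contract, namely stream_context() returns a context-manager factory and get_current_stream()/get_new_stream() return a stream that has already been warmed up and synchronized. A minimal usage sketch, assuming torch_npu is installed so backend.stream resolves the NPU branch:

    import torch

    from backend import stream

    # Mirror the guard the patch itself uses before assuming the NPU branch.
    if hasattr(torch, "npu") and torch.npu.is_available():
        ctx = stream.stream_context()   # torch.npu.stream on this build
        side = stream.get_new_stream()  # fresh stream, warmed up and synchronized
        with ctx(side):                 # queue work on the side stream
            x = torch.zeros((64, 64)).to(torch.device("npu"), torch.float32)
        side.synchronize()              # block until the transfer has landed

The hasattr(torch, "npu") guard matters: the npu attribute only exists once torch_npu has been imported, which is why these branches are guarded while the xpu ones are not.
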
diff --git a/modules/npu_specific.py b/modules/npu_specific.py
index 66ba3102c..001b8cb83 100644
--- a/modules/npu_specific.py
+++ b/modules/npu_specific.py
@@ -1,31 +1,30 @@
-# import importlib
-# import torch
-#
-# from modules import shared
-#
-#
-# def check_for_npu():
-#     if importlib.util.find_spec("torch_npu") is None:
-#         return False
-#     import torch_npu
-#
-#     try:
-#         # Will raise a RuntimeError if no NPU is found
-#         _ = torch_npu.npu.device_count()
-#         return torch.npu.is_available()
-#     except RuntimeError:
-#         return False
-#
-#
-# def get_npu_device_string():
-#     if shared.cmd_opts.device_id is not None:
-#         return f"npu:{shared.cmd_opts.device_id}"
-#     return "npu:0"
-#
-#
-# def torch_npu_gc():
-#     with torch.npu.device(get_npu_device_string()):
-#         torch.npu.empty_cache()
-#
-#
-# has_npu = check_for_npu()
+import importlib.util
+import torch
+
+
+def check_for_npu():
+    if importlib.util.find_spec("torch_npu") is None:
+        return False
+    import torch_npu
+
+    try:
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch_npu.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
+
+
+def get_npu_device_string():
+    from modules import shared
+    if shared.cmd_opts.device_id is not None:
+        return f"npu:{shared.cmd_opts.device_id}"
+    return "npu:0"
+
+
+def torch_npu_gc():
+    with torch.npu.device(get_npu_device_string()):
+        torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()
diff --git a/requirements_npu.txt b/requirements_npu.txt
new file mode 100644
index 000000000..5e6a43646
--- /dev/null
+++ b/requirements_npu.txt
@@ -0,0 +1,4 @@
+cloudpickle
+decorator
+synr==0.5.0
+tornado
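
Reviewer note: modules/npu_specific.py is the previously commented-out scaffold, now live. check_for_npu() is the same probe as is_npu() in backend/memory_management.py, get_npu_device_string() honors --device-id, and torch_npu_gc() scopes empty_cache() to that device; moving the modules.shared import inside get_npu_device_string() presumably avoids an import-order problem at startup. Downstream callers such as soft_empty_cache() only need the module surface, roughly:

    from modules import npu_specific

    if npu_specific.has_npu:  # evaluated once, at import time
        print("using", npu_specific.get_npu_device_string())  # "npu:<id>" or "npu:0"
        npu_specific.torch_npu_gc()  # empty_cache() inside that device's context

requirements_npu.txt lists the extra Python dependencies for the Ascend toolchain; they would be installed on top of the base requirements with a plain pip install -r requirements_npu.txt, while torch_npu itself is not listed here and is expected to come from the user's CANN setup.
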