50 changes: 48 additions & 2 deletions backend/memory_management.py
@@ -82,6 +82,18 @@ def is_intel_xpu():
            return True
    return False

def is_npu():
    import importlib
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False

def get_torch_device():
    global directml_enabled
@@ -96,6 +108,8 @@ def get_torch_device():
    else:
        if is_intel_xpu():
            return torch.device("xpu", torch.xpu.current_device())
        elif is_npu():
            return torch.device("npu", torch.npu.current_device())
        else:
            return torch.device(torch.cuda.current_device())

@@ -105,7 +119,7 @@ def get_total_memory(dev=None, torch_total_too=False):
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
    if hasattr(dev, 'type') and (dev.type in ['cpu', 'mps']):
        mem_total = psutil.virtual_memory().total
        mem_total_torch = mem_total
    else:
@@ -117,6 +131,11 @@ def get_total_memory(dev=None, torch_total_too=False):
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total_torch = mem_reserved
            mem_total = torch.xpu.get_device_properties(dev).total_memory
        elif is_npu():
            stats = torch.npu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total_torch = mem_reserved
            mem_total = torch.npu.get_device_properties(dev).total_memory
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
@@ -129,7 +148,6 @@ def get_total_memory(dev=None, torch_total_too=False):
    else:
        return mem_total


total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@@ -265,6 +283,8 @@ def get_torch_device_name(device):
            except:
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        elif is_npu():
            return "{} {}".format(device, torch.npu.get_device_name(device))
        else:
            return "{}".format(device.type)
    elif is_intel_xpu():
@@ -915,6 +935,8 @@ def device_supports_non_blocking(device):
        return False
    if directml_enabled:
        return False
    if is_npu():
        return False
    return True


@@ -942,6 +964,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
            device_supports_cast = True
        elif is_intel_xpu():
            device_supports_cast = True
        elif is_npu():
            device_supports_cast = True

    non_blocking = device_should_use_non_blocking(device)

@@ -965,6 +989,8 @@ def xformers_enabled():
        return False
    if directml_enabled:
        return False
    if is_npu():
        return False
    return XFORMERS_IS_AVAILABLE


@@ -989,6 +1015,8 @@ def pytorch_attention_flash_attention():
        return True
    if is_intel_xpu():
        return True
    if is_npu():
        return True
    return False


@@ -1024,6 +1052,13 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_free_torch = mem_reserved - mem_active
            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_total = mem_free_xpu + mem_free_torch
        elif is_npu():
            stats = torch.npu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_torch = mem_reserved - mem_active
            mem_free_npu = torch.npu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_total = mem_free_npu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -1099,6 +1134,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
    if torch.version.hip:
        return True

    if is_npu():
        return True

    props = torch.cuda.get_device_properties("cuda")
    if props.major >= 8:
        return True
@@ -1153,6 +1191,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

    if is_intel_xpu():
        return True

    if is_npu():
        return True

    if device is None:
        device = torch.device("cuda")
@@ -1201,6 +1242,11 @@ def soft_empty_cache(force=False):
        if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif is_npu():
        from modules import npu_specific
        torch.npu.set_device(0)
        npu_specific.torch_npu_gc()

    signal_empty_cache = False
    return

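Illustrative sketch only, not part of this diff: how the NPU-aware helpers above would typically be exercised at startup, assuming torch_npu is installed, an Ascend NPU is visible to the process, and the import path backend.memory_management matches the file being patched.

# Illustrative sketch (not part of this diff); assumes torch_npu is installed
# and an Ascend NPU is available.
from backend import memory_management as mm

device = mm.get_torch_device()                      # torch.device("npu", N) when is_npu() is True
total = mm.get_total_memory(device)                 # reads torch.npu.memory_stats / device properties
free, free_torch = mm.get_free_memory(device, torch_free_too=True)
print(device, total // (1024 * 1024), free // (1024 * 1024))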
17 changes: 17 additions & 0 deletions backend/stream.py
@@ -8,6 +8,9 @@ def stream_context():

    if torch.xpu.is_available():
        return torch.xpu.stream

    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.stream

    return None

@@ -28,6 +31,13 @@ def get_current_stream():
                torch.zeros((1, 1)).to(device, torch.float32)
            stream.synchronize()
            return stream
        if hasattr(torch, "npu") and torch.npu.is_available():
            device = torch.device("npu")
            stream = torch.npu.current_stream(device)
            with torch.npu.stream(stream):
                torch.zeros((1, 1)).to(device, torch.float32)
            stream.synchronize()
            return stream
    except:
        return None

@@ -48,6 +58,13 @@ def get_new_stream():
                torch.zeros((1, 1)).to(device, torch.float32)
            stream.synchronize()
            return stream
        if hasattr(torch, "npu") and torch.npu.is_available():
            device = torch.device("npu")
            stream = torch.npu.Stream(device)
            with torch.npu.stream(stream):
                torch.zeros((1, 1)).to(device, torch.float32)
            stream.synchronize()
            return stream
    except:
        return None

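Illustrative sketch only, not part of this diff: with torch_npu available, the helpers above fall through to the NPU branch. stream_context() returns the backend-specific stream context manager (torch.npu.stream on Ascend), while get_current_stream() and get_new_stream() hand back synchronized stream objects. The backend.stream import path is assumed from the file location.

# Illustrative sketch (not part of this diff); import path assumed from
# the file location backend/stream.py.
from backend import stream as stream_utils

ctx = stream_utils.stream_context()        # torch.npu.stream when an NPU is available, else CUDA/XPU/None
fresh = stream_utils.get_new_stream()      # synchronized, backend-specific stream object

if ctx is not None and fresh is not None:
    with ctx(fresh):                       # queue work on the newly created stream
        pass
    fresh.synchronize()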
3 changes: 2 additions & 1 deletion modules/devices.py
@@ -40,7 +40,8 @@ def torch_gc():


def torch_npu_set_device():
    return
    if memory_management.is_npu():
        torch.npu.set_device(0)


def enable_tf32():
3 changes: 3 additions & 0 deletions modules/initialize.py
@@ -133,4 +133,7 @@ def initialize_rest(*, reload_script_modules=False):
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")

    from modules import devices
    devices.torch_npu_set_device()

    return
61 changes: 30 additions & 31 deletions modules/npu_specific.py
@@ -1,31 +1,30 @@
# import importlib
# import torch
#
# from modules import shared
#
#
# def check_for_npu():
#     if importlib.util.find_spec("torch_npu") is None:
#         return False
#     import torch_npu
#
#     try:
#         # Will raise a RuntimeError if no NPU is found
#         _ = torch_npu.npu.device_count()
#         return torch.npu.is_available()
#     except RuntimeError:
#         return False
#
#
# def get_npu_device_string():
#     if shared.cmd_opts.device_id is not None:
#         return f"npu:{shared.cmd_opts.device_id}"
#     return "npu:0"
#
#
# def torch_npu_gc():
#     with torch.npu.device(get_npu_device_string()):
#         torch.npu.empty_cache()
#
#
# has_npu = check_for_npu()
import importlib
import torch


def check_for_npu():
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False


def get_npu_device_string():
    from modules import shared
    if shared.cmd_opts.device_id is not None:
        return f"npu:{shared.cmd_opts.device_id}"
    return "npu:0"


def torch_npu_gc():
    with torch.npu.device(get_npu_device_string()):
        torch.npu.empty_cache()


has_npu = check_for_npu()
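A brief, hedged sketch of how this module is consumed elsewhere in the PR: memory_management.soft_empty_cache() calls torch_npu_gc(), and devices.torch_npu_set_device() pins device 0. The has_npu guard below is illustrative, not part of the diff.

# Illustrative sketch (not part of this diff): freeing NPU cache memory the
# same way memory_management.soft_empty_cache() does.
from modules import npu_specific

if npu_specific.has_npu:
    npu_specific.torch_npu_gc()   # empties the cache on the device named by get_npu_device_string()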
4 changes: 4 additions & 0 deletions requirements_npu.txt
@@ -0,0 +1,4 @@
cloudpickle
decorator
synr==0.5.0
tornado