diff --git a/.gitignore b/.gitignore
index 8b8235e6..1761ebd1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ util/__pycache__/
 index.html?linkid=2289031
 wget-log
 weights/icon_caption_florence_v2/
-omnitool/gradio/uploads/
\ No newline at end of file
+omnitool/gradio/uploads/
+.DS_Store
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
index 15664d31..78a71ffa 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -12,6 +12,8 @@
 import torch
 from PIL import Image
 
+os.environ["NO_PROXY"] = "localhost,127.0.0.1"
+
 yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
 caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
 # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
@@ -27,7 +29,17 @@
 OmniParser is a screen parsing tool to convert general GUI screen to structured elements. 
 """
 
-DEVICE = torch.device('cuda')
+# DEVICE = torch.device('cuda')
+# Check if MPS is available (for Mac with Apple Silicon)
+if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    DEVICE = torch.device('mps')
+# Fall back to CUDA if MPS is not available
+elif torch.cuda.is_available():
+    DEVICE = torch.device('cuda')
+# Fall back to CPU as last resort
+else:
+    DEVICE = torch.device('cpu')
+    print("Warning: Neither MPS nor CUDA is available. Using CPU instead.")
 
 # @spaces.GPU
 # @torch.inference_mode()
diff --git a/requirements.txt b/requirements.txt
index 901a27fa..ebd212a8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,14 +9,14 @@ azure-identity
 numpy==1.26.4
 opencv-python
 opencv-python-headless
-gradio
+gradio==5.25.2
 dill
 accelerate
 timm
 einops==0.8.0
 paddlepaddle
 paddleocr
-ruff==0.6.7
+ruff
 pre-commit==3.8.0
 pytest==8.3.3
 pytest-asyncio==0.23.6
diff --git a/util/utils.py b/util/utils.py
index eb7c8b25..c5b19e3f 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -46,7 +46,14 @@
 
 def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
     if not device:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        # device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+
     if model_name == "blip2":
         from transformers import Blip2Processor, Blip2ForConditionalGeneration
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
@@ -54,6 +61,10 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2
             model = Blip2ForConditionalGeneration.from_pretrained(
                 model_name_or_path, device_map=None, torch_dtype=torch.float32
             )
+        elif device == 'mps':
+            model = Blip2ForConditionalGeneration.from_pretrained(
+                model_name_or_path, device_map=None, torch_dtype=torch.float32
+            ).to(device)
         else:
             model = Blip2ForConditionalGeneration.from_pretrained(
                 model_name_or_path, device_map=None, torch_dtype=torch.float16
@@ -63,6 +74,8 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2
         processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
         if device == 'cpu':
             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
+        elif device == 'mps':
+            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True).to(device)
         else:
             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
     return {'model': model.to(device), 'processor': processor}