Skip to content

Commit ed22e6a

Browse files
Desktop: Add rudimentary support for custom WGPU adapter selection (#3201)
* rudimentary custom wgpu adapter selection * WgpuContextBuilder * wasm fix * fix wasm warnings * Clean up * Review suggestions * fix
1 parent 4e47b5d commit ed22e6a

File tree

8 files changed

+175
-48
lines changed

8 files changed

+175
-48
lines changed

desktop/src/gpu_context.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use graphite_desktop_wrapper::{WgpuContext, WgpuContextBuilder, WgpuFeatures};
2+
3+
pub(super) async fn create_wgpu_context() -> WgpuContext {
4+
let wgpu_context_builder = WgpuContextBuilder::new().with_features(WgpuFeatures::PUSH_CONSTANTS);
5+
6+
// TODO: add a cli flag to list adapters and exit instead of always printing
7+
println!("\nAvailable WGPU adapters:\n{}", wgpu_context_builder.available_adapters_fmt().await);
8+
9+
// TODO: make this configurable via cli flags instead
10+
let wgpu_context = match std::env::var("GRAPHITE_WGPU_ADAPTER").ok().and_then(|s| s.parse().ok()) {
11+
None => wgpu_context_builder.build().await,
12+
Some(adapter_index) => {
13+
tracing::info!("Overriding WGPU adapter selection with adapter index {adapter_index}");
14+
wgpu_context_builder.build_with_adapter_selection(|_| Some(adapter_index)).await
15+
}
16+
}
17+
.expect("Failed to create WGPU context");
18+
19+
// TODO: add a cli flag to list adapters and exit instead of always printing
20+
println!("Using WGPU adapter: {:?}", wgpu_context.adapter.get_info());
21+
22+
wgpu_context
23+
}

desktop/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ use std::process::exit;
22
use tracing_subscriber::EnvFilter;
33
use winit::event_loop::EventLoop;
44

5-
use graphite_desktop_wrapper::WgpuContext;
6-
75
pub(crate) mod consts;
86

97
mod app;
@@ -14,6 +12,8 @@ mod native_window;
1412
mod persist;
1513
mod render;
1614

15+
mod gpu_context;
16+
1717
use app::App;
1818
use cef::CefHandler;
1919
use event::CreateAppEventSchedulerEventLoopExt;
@@ -31,7 +31,7 @@ fn main() {
3131
return;
3232
}
3333

34-
let wgpu_context = futures::executor::block_on(WgpuContext::new()).unwrap();
34+
let wgpu_context = futures::executor::block_on(gpu_context::create_wgpu_context());
3535

3636
let event_loop = EventLoop::new().unwrap();
3737
let (app_event_sender, app_event_receiver) = std::sync::mpsc::channel();

desktop/wrapper/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ use graphite_editor::messages::prelude::{FrontendMessage, Message};
55
// TODO: Remove usage of this reexport in desktop create and remove this line
66
pub use graphene_std::Color;
77

8-
pub use wgpu_executor::Context as WgpuContext;
8+
pub use wgpu_executor::WgpuContext;
9+
pub use wgpu_executor::WgpuContextBuilder;
910
pub use wgpu_executor::WgpuExecutor;
11+
pub use wgpu_executor::WgpuFeatures;
1012

1113
pub mod messages;
1214
use messages::{DesktopFrontendMessage, DesktopWrapperMessage};

node-graph/graph-craft/src/wasm_application_io.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ impl WasmApplicationIo {
143143
io
144144
}
145145
#[cfg(all(not(target_family = "wasm"), feature = "wgpu"))]
146-
pub fn new_with_context(context: wgpu_executor::Context) -> Self {
146+
pub fn new_with_context(context: wgpu_executor::WgpuContext) -> Self {
147147
#[cfg(feature = "wgpu")]
148148
let executor = WgpuExecutor::with_context(context);
149149

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,151 @@
11
use std::sync::Arc;
2-
use wgpu::{Device, Instance, Queue};
2+
use wgpu::{Adapter, Backends, Device, Features, Instance, Queue};
33

44
#[derive(Debug, Clone)]
55
pub struct Context {
66
pub device: Arc<Device>,
77
pub queue: Arc<Queue>,
88
pub instance: Arc<Instance>,
9-
pub adapter: Arc<wgpu::Adapter>,
9+
pub adapter: Arc<Adapter>,
1010
}
1111

1212
impl Context {
1313
pub async fn new() -> Option<Self> {
14-
// Instantiates instance of WebGPU
15-
let instance_descriptor = wgpu::InstanceDescriptor {
16-
backends: wgpu::Backends::all(),
17-
..Default::default()
18-
};
19-
let instance = Instance::new(&instance_descriptor);
14+
ContextBuilder::new().build().await
15+
}
16+
}
2017

21-
let adapter_options = wgpu::RequestAdapterOptions {
18+
#[derive(Default)]
19+
pub struct ContextBuilder {
20+
backends: Backends,
21+
features: Features,
22+
}
23+
impl ContextBuilder {
24+
pub fn new() -> Self {
25+
Self {
26+
backends: Backends::all(),
27+
features: Features::empty(),
28+
}
29+
}
30+
pub fn with_backends(mut self, backends: Backends) -> Self {
31+
self.backends = backends;
32+
self
33+
}
34+
pub fn with_features(mut self, features: Features) -> Self {
35+
self.features = features;
36+
self
37+
}
38+
}
39+
#[cfg(not(target_family = "wasm"))]
40+
impl ContextBuilder {
41+
pub async fn build(self) -> Option<Context> {
42+
self.build_with_adapter_selection_inner(None::<fn(&[Adapter]) -> Option<usize>>).await
43+
}
44+
pub async fn build_with_adapter_selection<S>(self, select: S) -> Option<Context>
45+
where
46+
S: Fn(&[Adapter]) -> Option<usize>,
47+
{
48+
self.build_with_adapter_selection_inner(Some(select)).await
49+
}
50+
pub async fn available_adapters_fmt(&self) -> impl std::fmt::Display {
51+
let instance = self.build_instance();
52+
fmt::AvailableAdaptersFormatter(instance.enumerate_adapters(self.backends))
53+
}
54+
}
55+
#[cfg(target_family = "wasm")]
56+
impl ContextBuilder {
57+
pub async fn build(self) -> Option<Context> {
58+
let instance = self.build_instance();
59+
let adapter = self.request_adapter(&instance).await?;
60+
let (device, queue) = self.request_device(&adapter).await?;
61+
Some(Context {
62+
device: Arc::new(device),
63+
queue: Arc::new(queue),
64+
adapter: Arc::new(adapter),
65+
instance: Arc::new(instance),
66+
})
67+
}
68+
}
69+
impl ContextBuilder {
70+
fn build_instance(&self) -> Instance {
71+
Instance::new(&wgpu::InstanceDescriptor {
72+
backends: self.backends,
73+
..Default::default()
74+
})
75+
}
76+
async fn request_adapter(&self, instance: &Instance) -> Option<Adapter> {
77+
let request_adapter_options = wgpu::RequestAdapterOptions {
2278
power_preference: wgpu::PowerPreference::HighPerformance,
2379
compatible_surface: None,
2480
force_fallback_adapter: false,
2581
};
26-
// `request_adapter` instantiates the general connection to the GPU
27-
let adapter = instance.request_adapter(&adapter_options).await.ok()?;
82+
instance.request_adapter(&request_adapter_options).await.ok()
83+
}
84+
async fn request_device(&self, adapter: &Adapter) -> Option<(Device, Queue)> {
85+
let device_descriptor = wgpu::DeviceDescriptor {
86+
label: None,
87+
required_features: self.features,
88+
required_limits: adapter.limits(),
89+
memory_hints: Default::default(),
90+
trace: wgpu::Trace::Off,
91+
};
92+
adapter.request_device(&device_descriptor).await.ok()
93+
}
94+
}
95+
#[cfg(not(target_family = "wasm"))]
96+
impl ContextBuilder {
97+
async fn build_with_adapter_selection_inner<S>(self, select: Option<S>) -> Option<Context>
98+
where
99+
S: Fn(&[Adapter]) -> Option<usize>,
100+
{
101+
let instance = self.build_instance();
102+
103+
let selected_adapter = if let Some(select) = select {
104+
self.select_adapter(&instance, select)
105+
} else if cfg!(target_os = "windows") {
106+
self.select_adapter(&instance, |adapters: &[Adapter]| adapters.iter().position(|a| a.get_info().backend == wgpu::Backend::Dx12))
107+
} else {
108+
None
109+
};
28110

29-
let required_limits = adapter.limits();
30-
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
31-
// `features` being the available features.
32-
let (device, queue) = adapter
33-
.request_device(&wgpu::DeviceDescriptor {
34-
label: None,
35-
#[cfg(target_family = "wasm")]
36-
required_features: wgpu::Features::empty(),
37-
#[cfg(not(target_family = "wasm"))]
38-
required_features: wgpu::Features::PUSH_CONSTANTS,
39-
required_limits,
40-
memory_hints: Default::default(),
41-
trace: wgpu::Trace::Off,
42-
})
43-
.await
44-
.ok()?;
111+
let adapter = if let Some(adapter) = selected_adapter { adapter } else { self.request_adapter(&instance).await? };
45112

46-
Some(Self {
113+
let (device, queue) = self.request_device(&adapter).await?;
114+
Some(Context {
47115
device: Arc::new(device),
48116
queue: Arc::new(queue),
49117
adapter: Arc::new(adapter),
50118
instance: Arc::new(instance),
51119
})
52120
}
121+
fn select_adapter<S>(&self, instance: &Instance, select: S) -> Option<Adapter>
122+
where
123+
S: Fn(&[Adapter]) -> Option<usize>,
124+
{
125+
let mut adapters = instance.enumerate_adapters(self.backends);
126+
let selected_index = select(&adapters)?;
127+
if selected_index >= adapters.len() {
128+
return None;
129+
}
130+
Some(adapters.remove(selected_index))
131+
}
132+
}
133+
#[cfg(not(target_family = "wasm"))]
134+
mod fmt {
135+
use super::*;
136+
137+
pub(super) struct AvailableAdaptersFormatter(pub(super) Vec<Adapter>);
138+
impl std::fmt::Display for AvailableAdaptersFormatter {
139+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140+
for (i, adapter) in self.0.iter().enumerate() {
141+
let info = adapter.get_info();
142+
writeln!(
143+
f,
144+
"[{}] {:?} {:?} (Name: {}, Driver: {}, Device: {})",
145+
i, info.backend, info.device_type, info.name, info.driver, info.device,
146+
)?;
147+
}
148+
Ok(())
149+
}
150+
}
53151
}

node-graph/wgpu-executor/src/lib.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ pub mod texture_upload;
44

55
use crate::shader_runtime::ShaderRuntime;
66
use anyhow::Result;
7-
pub use context::Context;
87
use dyn_any::StaticType;
98
use futures::lock::Mutex;
109
use glam::UVec2;
@@ -16,9 +15,14 @@ use vello::{AaConfig, AaSupport, RenderParams, Renderer, RendererOptions, Scene}
1615
use wgpu::util::TextureBlitter;
1716
use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect};
1817

18+
pub use context::Context as WgpuContext;
19+
pub use context::ContextBuilder as WgpuContextBuilder;
20+
pub use wgpu::Backends as WgpuBackends;
21+
pub use wgpu::Features as WgpuFeatures;
22+
1923
#[derive(dyn_any::DynAny)]
2024
pub struct WgpuExecutor {
21-
pub context: Context,
25+
pub context: WgpuContext,
2226
vello_renderer: Mutex<Renderer>,
2327
pub shader_runtime: ShaderRuntime,
2428
}
@@ -182,10 +186,10 @@ impl WgpuExecutor {
182186

183187
impl WgpuExecutor {
184188
pub async fn new() -> Option<Self> {
185-
Self::with_context(Context::new().await?)
189+
Self::with_context(WgpuContext::new().await?)
186190
}
187191

188-
pub fn with_context(context: Context) -> Option<Self> {
192+
pub fn with_context(context: WgpuContext) -> Option<Self> {
189193
let vello_renderer = Renderer::new(
190194
&context.device,
191195
RendererOptions {

node-graph/wgpu-executor/src/shader_runtime/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
use crate::Context;
1+
use crate::WgpuContext;
22
use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime;
33

44
pub mod per_pixel_adjust_runtime;
55

66
pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex";
77

88
pub struct ShaderRuntime {
9-
context: Context,
9+
context: WgpuContext,
1010
per_pixel_adjust: PerPixelAdjustShaderRuntime,
1111
}
1212

1313
impl ShaderRuntime {
14-
pub fn new(context: &Context) -> Self {
14+
pub fn new(context: &WgpuContext) -> Self {
1515
Self {
1616
context: context.clone(),
1717
per_pixel_adjust: PerPixelAdjustShaderRuntime::new(),

node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::Context;
1+
use crate::WgpuContext;
22
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
33
use futures::lock::Mutex;
44
use graphene_core::raster_types::{GPU, Raster};
@@ -31,7 +31,7 @@ impl ShaderRuntime {
3131
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
3232
let pipeline = cache
3333
.entry(shaders.fragment_shader_name.to_owned())
34-
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders));
34+
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, shaders));
3535

3636
let arg_buffer = args.map(|args| {
3737
let device = &self.context.device;
@@ -58,7 +58,7 @@ pub struct PerPixelAdjustGraphicsPipeline {
5858
}
5959

6060
impl PerPixelAdjustGraphicsPipeline {
61-
pub fn new(context: &Context, info: &Shaders) -> Self {
61+
pub fn new(context: &WgpuContext, info: &Shaders) -> Self {
6262
let device = &context.device;
6363
let name = info.fragment_shader_name.to_owned();
6464

@@ -67,7 +67,7 @@ impl PerPixelAdjustGraphicsPipeline {
6767
// TODO workaround to naga removing `:`
6868
let fragment_name = fragment_name.replace(":", "");
6969
let shader_module = device.create_shader_module(ShaderModuleDescriptor {
70-
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
70+
label: Some(&format!("PerPixelAdjust {name} wgsl shader")),
7171
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
7272
});
7373

@@ -107,16 +107,16 @@ impl PerPixelAdjustGraphicsPipeline {
107107
}]
108108
};
109109
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
110-
label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)),
110+
label: Some(&format!("PerPixelAdjust {name} PipelineLayout")),
111111
bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor {
112-
label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)),
112+
label: Some(&format!("PerPixelAdjust {name} BindGroupLayout 0")),
113113
entries,
114114
})],
115115
push_constant_ranges: &[],
116116
});
117117

118118
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
119-
label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
119+
label: Some(&format!("PerPixelAdjust {name} Pipeline")),
120120
layout: Some(&pipeline_layout),
121121
vertex: VertexState {
122122
module: &shader_module,
@@ -155,7 +155,7 @@ impl PerPixelAdjustGraphicsPipeline {
155155
}
156156
}
157157

158-
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
158+
pub fn dispatch(&self, context: &WgpuContext, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
159159
assert_eq!(self.has_uniform, arg_buffer.is_some());
160160
let device = &context.device;
161161
let name = self.name.as_str();

0 commit comments

Comments
 (0)