diff --git a/player/src/lib.rs b/player/src/lib.rs index ca3a4b6a57..c67c605e58 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -99,7 +99,7 @@ impl GlobalPlay for wgc::global::Global { base, timestamp_writes, } => { - self.command_encoder_run_compute_pass_impl::( + self.command_encoder_run_compute_pass_with_unresolved_commands::( encoder, base.as_ref(), timestamp_writes.as_ref(), diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 49492ac037..4ee48f0086 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,3 +1,4 @@ +use crate::command::compute_command::{ArcComputeCommand, ComputeCommand}; use crate::device::DeviceError; use crate::resource::Resource; use crate::snatch::SnatchGuard; @@ -20,7 +21,6 @@ use crate::{ hal_label, id, id::DeviceId, init_tracker::MemoryInitKind, - pipeline, resource::{self}, storage::Storage, track::{Tracker, UsageConflict, UsageScope}, @@ -39,59 +39,6 @@ use thiserror::Error; use std::sync::Arc; use std::{fmt, mem, str}; -#[doc(hidden)] -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum ComputeCommand { - SetBindGroup { - index: u32, - num_dynamic_offsets: usize, - bind_group_id: id::BindGroupId, - }, - SetPipeline(id::ComputePipelineId), - - /// Set a range of push constants to values stored in [`BasePass::push_constant_data`]. - SetPushConstant { - /// The byte offset within the push constant storage to write to. This - /// must be a multiple of four. - offset: u32, - - /// The number of bytes to write. This must be a multiple of four. - size_bytes: u32, - - /// Index in [`BasePass::push_constant_data`] of the start of the data - /// to be written. - /// - /// Note: this is not a byte offset like `offset`. Rather, it is the - /// index of the first `u32` element in `push_constant_data` to read. - values_offset: u32, - }, - - Dispatch([u32; 3]), - DispatchIndirect { - buffer_id: id::BufferId, - offset: wgt::BufferAddress, - }, - PushDebugGroup { - color: u32, - len: usize, - }, - PopDebugGroup, - InsertDebugMarker { - color: u32, - len: usize, - }, - WriteTimestamp { - query_set_id: id::QuerySetId, - query_index: u32, - }, - BeginPipelineStatisticsQuery { - query_set_id: id::QuerySetId, - query_index: u32, - }, - EndPipelineStatisticsQuery, -} - #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct ComputePass { base: BasePass, @@ -185,7 +132,7 @@ pub enum ComputePassErrorInner { #[error(transparent)] Encoder(#[from] CommandEncoderError), #[error("Bind group at index {0:?} is invalid")] - InvalidBindGroup(usize), + InvalidBindGroup(u32), #[error("Device {0:?} is invalid")] InvalidDevice(DeviceId), #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")] @@ -250,7 +197,7 @@ impl PrettyError for ComputePassErrorInner { pub struct ComputePassError { pub scope: PassErrorScope, #[source] - inner: ComputePassErrorInner, + pub(super) inner: ComputePassErrorInner, } impl PrettyError for ComputePassError { fn fmt_pretty(&self, fmt: &mut ErrorFormatter) { @@ -347,7 +294,8 @@ impl Global { encoder_id: id::CommandEncoderId, pass: &ComputePass, ) -> Result<(), ComputePassError> { - self.command_encoder_run_compute_pass_impl::( + // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally. + self.command_encoder_run_compute_pass_with_unresolved_commands::( encoder_id, pass.base.as_ref(), pass.timestamp_writes.as_ref(), @@ -355,11 +303,33 @@ impl Global { } #[doc(hidden)] - pub fn command_encoder_run_compute_pass_impl( + pub fn command_encoder_run_compute_pass_with_unresolved_commands( &self, encoder_id: id::CommandEncoderId, base: BasePassRef, timestamp_writes: Option<&ComputePassTimestampWrites>, + ) -> Result<(), ComputePassError> { + let resolved_commands = + ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?; + + self.command_encoder_run_compute_pass_impl::( + encoder_id, + BasePassRef { + label: base.label, + commands: &resolved_commands, + dynamic_offsets: base.dynamic_offsets, + string_data: base.string_data, + push_constant_data: base.push_constant_data, + }, + timestamp_writes, + ) + } + + fn command_encoder_run_compute_pass_impl( + &self, + encoder_id: id::CommandEncoderId, + base: BasePassRef>, + timestamp_writes: Option<&ComputePassTimestampWrites>, ) -> Result<(), ComputePassError> { profiling::scope!("CommandEncoder::run_compute_pass"); let pass_scope = PassErrorScope::Pass(encoder_id); @@ -382,7 +352,13 @@ impl Global { #[cfg(feature = "trace")] if let Some(ref mut list) = cmd_buf_data.commands { list.push(crate::device::trace::Command::RunComputePass { - base: BasePass::from_ref(base), + base: BasePass { + label: base.label.map(str::to_string), + commands: base.commands.iter().map(Into::into).collect(), + dynamic_offsets: base.dynamic_offsets.to_vec(), + string_data: base.string_data.to_vec(), + push_constant_data: base.push_constant_data.to_vec(), + }, timestamp_writes: timestamp_writes.cloned(), }); } @@ -402,9 +378,7 @@ impl Global { let raw = encoder.open().map_pass_err(pass_scope)?; let bind_group_guard = hub.bind_groups.read(); - let pipeline_guard = hub.compute_pipelines.read(); let query_set_guard = hub.query_sets.read(); - let buffer_guard = hub.buffers.read(); let mut state = State { binder: Binder::new(), @@ -482,19 +456,21 @@ impl Global { // be inserted before texture reads. let mut pending_discard_init_fixups = SurfacesInDiscardState::new(); + // TODO: We should be draining the commands here, avoiding extra copies in the process. + // (A command encoder can't be executed twice!) for command in base.commands { - match *command { - ComputeCommand::SetBindGroup { + match command { + ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets, - bind_group_id, + bind_group, } => { - let scope = PassErrorScope::SetBindGroup(bind_group_id); + let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); let max_bind_groups = cmd_buf.limits.max_bind_groups; - if index >= max_bind_groups { + if index >= &max_bind_groups { return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { - index, + index: *index, max: max_bind_groups, }) .map_pass_err(scope); @@ -507,13 +483,9 @@ impl Global { ); dynamic_offset_count += num_dynamic_offsets; - let bind_group = tracker - .bind_groups - .add_single(&*bind_group_guard, bind_group_id) - .ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize)) - .map_pass_err(scope)?; + let bind_group = tracker.bind_groups.insert_single(bind_group.clone()); bind_group - .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits) + .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits) .map_pass_err(scope)?; buffer_memory_init_actions.extend( @@ -535,14 +507,14 @@ impl Global { let entries = state .binder - .assign_group(index as usize, bind_group, &temp_offsets); + .assign_group(*index as usize, bind_group, &temp_offsets); if !entries.is_empty() && pipeline_layout.is_some() { let pipeline_layout = pipeline_layout.as_ref().unwrap().raw(); for (i, e) in entries.iter().enumerate() { if let Some(group) = e.group.as_ref() { let raw_bg = group .raw(&snatch_guard) - .ok_or(ComputePassErrorInner::InvalidBindGroup(i)) + .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32)) .map_pass_err(scope)?; unsafe { raw.set_bind_group( @@ -556,16 +528,13 @@ impl Global { } } } - ComputeCommand::SetPipeline(pipeline_id) => { + ArcComputeCommand::SetPipeline(pipeline) => { + let pipeline_id = pipeline.as_info().id(); let scope = PassErrorScope::SetPipelineCompute(pipeline_id); state.pipeline = Some(pipeline_id); - let pipeline: &pipeline::ComputePipeline = tracker - .compute_pipelines - .add_single(&*pipeline_guard, pipeline_id) - .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id)) - .map_pass_err(scope)?; + tracker.compute_pipelines.insert_single(pipeline.clone()); unsafe { raw.set_compute_pipeline(pipeline.raw()); @@ -589,7 +558,7 @@ impl Global { if let Some(group) = e.group.as_ref() { let raw_bg = group .raw(&snatch_guard) - .ok_or(ComputePassErrorInner::InvalidBindGroup(i)) + .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32)) .map_pass_err(scope)?; unsafe { raw.set_bind_group( @@ -625,7 +594,7 @@ impl Global { } } } - ComputeCommand::SetPushConstant { + ArcComputeCommand::SetPushConstant { offset, size_bytes, values_offset, @@ -636,7 +605,7 @@ impl Global { let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; let data_slice = - &base.push_constant_data[(values_offset as usize)..values_end_offset]; + &base.push_constant_data[(*values_offset as usize)..values_end_offset]; let pipeline_layout = state .binder @@ -651,7 +620,7 @@ impl Global { pipeline_layout .validate_push_constant_ranges( wgt::ShaderStages::COMPUTE, - offset, + *offset, end_offset_bytes, ) .map_pass_err(scope)?; @@ -660,12 +629,12 @@ impl Global { raw.set_push_constants( pipeline_layout.raw(), wgt::ShaderStages::COMPUTE, - offset, + *offset, data_slice, ); } } - ComputeCommand::Dispatch(groups) => { + ArcComputeCommand::Dispatch(groups) => { let scope = PassErrorScope::Dispatch { indirect: false, pipeline: state.pipeline, @@ -690,7 +659,7 @@ impl Global { { return Err(ComputePassErrorInner::Dispatch( DispatchError::InvalidGroupSize { - current: groups, + current: *groups, limit: groups_size_limit, }, )) @@ -698,10 +667,11 @@ impl Global { } unsafe { - raw.dispatch(groups); + raw.dispatch(*groups); } } - ComputeCommand::DispatchIndirect { buffer_id, offset } => { + ArcComputeCommand::DispatchIndirect { buffer, offset } => { + let buffer_id = buffer.as_info().id(); let scope = PassErrorScope::Dispatch { indirect: true, pipeline: state.pipeline, @@ -713,29 +683,25 @@ impl Global { .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION) .map_pass_err(scope)?; - let indirect_buffer = state + state .scope .buffers - .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT) + .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT) + .map_pass_err(scope)?; + check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT) .map_pass_err(scope)?; - check_buffer_usage( - buffer_id, - indirect_buffer.usage, - wgt::BufferUsages::INDIRECT, - ) - .map_pass_err(scope)?; let end_offset = offset + mem::size_of::() as u64; - if end_offset > indirect_buffer.size { + if end_offset > buffer.size { return Err(ComputePassErrorInner::IndirectBufferOverrun { - offset, + offset: *offset, end_offset, - buffer_size: indirect_buffer.size, + buffer_size: buffer.size, }) .map_pass_err(scope); } - let buf_raw = indirect_buffer + let buf_raw = buffer .raw .get(&snatch_guard) .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id)) @@ -744,9 +710,9 @@ impl Global { let stride = 3 * 4; // 3 integers, x/y/z group size buffer_memory_init_actions.extend( - indirect_buffer.initialization_status.read().create_action( - indirect_buffer, - offset..(offset + stride), + buffer.initialization_status.read().create_action( + buffer, + *offset..(*offset + stride), MemoryInitKind::NeedsInitializedMemory, ), ); @@ -756,15 +722,15 @@ impl Global { raw, &mut intermediate_trackers, &*bind_group_guard, - Some(indirect_buffer.as_info().tracker_index()), + Some(buffer.as_info().tracker_index()), &snatch_guard, ) .map_pass_err(scope)?; unsafe { - raw.dispatch_indirect(buf_raw, offset); + raw.dispatch_indirect(buf_raw, *offset); } } - ComputeCommand::PushDebugGroup { color: _, len } => { + ArcComputeCommand::PushDebugGroup { color: _, len } => { state.debug_scope_depth += 1; if !discard_hal_labels { let label = @@ -776,7 +742,7 @@ impl Global { } string_offset += len; } - ComputeCommand::PopDebugGroup => { + ArcComputeCommand::PopDebugGroup => { let scope = PassErrorScope::PopDebugGroup; if state.debug_scope_depth == 0 { @@ -790,7 +756,7 @@ impl Global { } } } - ComputeCommand::InsertDebugMarker { color: _, len } => { + ArcComputeCommand::InsertDebugMarker { color: _, len } => { if !discard_hal_labels { let label = str::from_utf8(&base.string_data[string_offset..string_offset + len]) @@ -799,49 +765,43 @@ impl Global { } string_offset += len; } - ComputeCommand::WriteTimestamp { - query_set_id, + ArcComputeCommand::WriteTimestamp { + query_set, query_index, } => { + let query_set_id = query_set.as_info().id(); let scope = PassErrorScope::WriteTimestamp; device .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .map_pass_err(scope)?; - let query_set: &resource::QuerySet = tracker - .query_sets - .add_single(&*query_set_guard, query_set_id) - .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id)) - .map_pass_err(scope)?; + let query_set = tracker.query_sets.insert_single(query_set.clone()); query_set - .validate_and_write_timestamp(raw, query_set_id, query_index, None) + .validate_and_write_timestamp(raw, query_set_id, *query_index, None) .map_pass_err(scope)?; } - ComputeCommand::BeginPipelineStatisticsQuery { - query_set_id, + ArcComputeCommand::BeginPipelineStatisticsQuery { + query_set, query_index, } => { + let query_set_id = query_set.as_info().id(); let scope = PassErrorScope::BeginPipelineStatisticsQuery; - let query_set: &resource::QuerySet = tracker - .query_sets - .add_single(&*query_set_guard, query_set_id) - .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id)) - .map_pass_err(scope)?; + let query_set = tracker.query_sets.insert_single(query_set.clone()); query_set .validate_and_begin_pipeline_statistics_query( raw, query_set_id, - query_index, + *query_index, None, &mut active_query, ) .map_pass_err(scope)?; } - ComputeCommand::EndPipelineStatisticsQuery => { + ArcComputeCommand::EndPipelineStatisticsQuery => { let scope = PassErrorScope::EndPipelineStatisticsQuery; end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query) diff --git a/wgpu-core/src/command/compute_command.rs b/wgpu-core/src/command/compute_command.rs new file mode 100644 index 0000000000..49fdbbec24 --- /dev/null +++ b/wgpu-core/src/command/compute_command.rs @@ -0,0 +1,322 @@ +use std::sync::Arc; + +use crate::{ + binding_model::BindGroup, + hal_api::HalApi, + id, + pipeline::ComputePipeline, + resource::{Buffer, QuerySet}, +}; + +use super::{ComputePassError, ComputePassErrorInner, PassErrorScope}; + +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ComputeCommand { + SetBindGroup { + index: u32, + num_dynamic_offsets: usize, + bind_group_id: id::BindGroupId, + }, + + SetPipeline(id::ComputePipelineId), + + /// Set a range of push constants to values stored in `push_constant_data`. + SetPushConstant { + /// The byte offset within the push constant storage to write to. This + /// must be a multiple of four. + offset: u32, + + /// The number of bytes to write. This must be a multiple of four. + size_bytes: u32, + + /// Index in `push_constant_data` of the start of the data + /// to be written. + /// + /// Note: this is not a byte offset like `offset`. Rather, it is the + /// index of the first `u32` element in `push_constant_data` to read. + values_offset: u32, + }, + + Dispatch([u32; 3]), + + DispatchIndirect { + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + }, + + PushDebugGroup { + color: u32, + len: usize, + }, + + PopDebugGroup, + + InsertDebugMarker { + color: u32, + len: usize, + }, + + WriteTimestamp { + query_set_id: id::QuerySetId, + query_index: u32, + }, + + BeginPipelineStatisticsQuery { + query_set_id: id::QuerySetId, + query_index: u32, + }, + + EndPipelineStatisticsQuery, +} + +impl ComputeCommand { + /// Resolves all ids in a list of commands into the corresponding resource Arc. + /// + // TODO: Once resolving is done on-the-fly during recording, this function should be only needed with the replay feature: + // #[cfg(feature = "replay")] + pub fn resolve_compute_command_ids( + hub: &crate::hub::Hub, + commands: &[ComputeCommand], + ) -> Result>, ComputePassError> { + let buffers_guard = hub.buffers.read(); + let bind_group_guard = hub.bind_groups.read(); + let query_set_guard = hub.query_sets.read(); + let pipelines_guard = hub.compute_pipelines.read(); + + let resolved_commands: Vec> = commands + .iter() + .map(|c| -> Result, ComputePassError> { + Ok(match *c { + ComputeCommand::SetBindGroup { + index, + num_dynamic_offsets, + bind_group_id, + } => ArcComputeCommand::SetBindGroup { + index, + num_dynamic_offsets, + bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| { + ComputePassError { + scope: PassErrorScope::SetBindGroup(bind_group_id), + inner: ComputePassErrorInner::InvalidBindGroup(index), + } + })?, + }, + + ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline( + pipelines_guard + .get_owned(pipeline_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::SetPipelineCompute(pipeline_id), + inner: ComputePassErrorInner::InvalidPipeline(pipeline_id), + })?, + ), + + ComputeCommand::SetPushConstant { + offset, + size_bytes, + values_offset, + } => ArcComputeCommand::SetPushConstant { + offset, + size_bytes, + values_offset, + }, + + ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim), + + ComputeCommand::DispatchIndirect { buffer_id, offset } => { + ArcComputeCommand::DispatchIndirect { + buffer: buffers_guard.get_owned(buffer_id).map_err(|_| { + ComputePassError { + scope: PassErrorScope::Dispatch { + indirect: true, + pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again. + }, + inner: ComputePassErrorInner::InvalidBuffer(buffer_id), + } + })?, + offset, + } + } + + ComputeCommand::PushDebugGroup { color, len } => { + ArcComputeCommand::PushDebugGroup { color, len } + } + + ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup, + + ComputeCommand::InsertDebugMarker { color, len } => { + ArcComputeCommand::InsertDebugMarker { color, len } + } + + ComputeCommand::WriteTimestamp { + query_set_id, + query_index, + } => ArcComputeCommand::WriteTimestamp { + query_set: query_set_guard.get_owned(query_set_id).map_err(|_| { + ComputePassError { + scope: PassErrorScope::WriteTimestamp, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + } + })?, + query_index, + }, + + ComputeCommand::BeginPipelineStatisticsQuery { + query_set_id, + query_index, + } => ArcComputeCommand::BeginPipelineStatisticsQuery { + query_set: query_set_guard.get_owned(query_set_id).map_err(|_| { + ComputePassError { + scope: PassErrorScope::BeginPipelineStatisticsQuery, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + } + })?, + query_index, + }, + + ComputeCommand::EndPipelineStatisticsQuery => { + ArcComputeCommand::EndPipelineStatisticsQuery + } + }) + }) + .collect::, ComputePassError>>()?; + Ok(resolved_commands) + } +} + +/// Equivalent to `ComputeCommand` but the Ids resolved into resource Arcs. +#[derive(Clone, Debug)] +pub enum ArcComputeCommand { + SetBindGroup { + index: u32, + num_dynamic_offsets: usize, + bind_group: Arc>, + }, + + SetPipeline(Arc>), + + /// Set a range of push constants to values stored in `push_constant_data`. + SetPushConstant { + /// The byte offset within the push constant storage to write to. This + /// must be a multiple of four. + offset: u32, + + /// The number of bytes to write. This must be a multiple of four. + size_bytes: u32, + + /// Index in `push_constant_data` of the start of the data + /// to be written. + /// + /// Note: this is not a byte offset like `offset`. Rather, it is the + /// index of the first `u32` element in `push_constant_data` to read. + values_offset: u32, + }, + + Dispatch([u32; 3]), + + DispatchIndirect { + buffer: Arc>, + offset: wgt::BufferAddress, + }, + + PushDebugGroup { + color: u32, + len: usize, + }, + + PopDebugGroup, + + InsertDebugMarker { + color: u32, + len: usize, + }, + + WriteTimestamp { + query_set: Arc>, + query_index: u32, + }, + + BeginPipelineStatisticsQuery { + query_set: Arc>, + query_index: u32, + }, + + EndPipelineStatisticsQuery, +} + +#[cfg(feature = "trace")] +impl From<&ArcComputeCommand> for ComputeCommand { + fn from(value: &ArcComputeCommand) -> Self { + use crate::resource::Resource as _; + + match value { + ArcComputeCommand::SetBindGroup { + index, + num_dynamic_offsets, + bind_group, + } => ComputeCommand::SetBindGroup { + index: *index, + num_dynamic_offsets: *num_dynamic_offsets, + bind_group_id: bind_group.as_info().id(), + }, + + ArcComputeCommand::SetPipeline(pipeline) => { + ComputeCommand::SetPipeline(pipeline.as_info().id()) + } + + ArcComputeCommand::SetPushConstant { + offset, + size_bytes, + values_offset, + } => ComputeCommand::SetPushConstant { + offset: *offset, + size_bytes: *size_bytes, + values_offset: *values_offset, + }, + + ArcComputeCommand::Dispatch(dim) => ComputeCommand::Dispatch(*dim), + + ArcComputeCommand::DispatchIndirect { buffer, offset } => { + ComputeCommand::DispatchIndirect { + buffer_id: buffer.as_info().id(), + offset: *offset, + } + } + + ArcComputeCommand::PushDebugGroup { color, len } => ComputeCommand::PushDebugGroup { + color: *color, + len: *len, + }, + + ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup, + + ArcComputeCommand::InsertDebugMarker { color, len } => { + ComputeCommand::InsertDebugMarker { + color: *color, + len: *len, + } + } + + ArcComputeCommand::WriteTimestamp { + query_set, + query_index, + } => ComputeCommand::WriteTimestamp { + query_set_id: query_set.as_info().id(), + query_index: *query_index, + }, + + ArcComputeCommand::BeginPipelineStatisticsQuery { + query_set, + query_index, + } => ComputeCommand::BeginPipelineStatisticsQuery { + query_set_id: query_set.as_info().id(), + query_index: *query_index, + }, + + ArcComputeCommand::EndPipelineStatisticsQuery => { + ComputeCommand::EndPipelineStatisticsQuery + } + } + } +} diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index e812b4f8fd..d53f47bf42 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -3,6 +3,7 @@ mod bind; mod bundle; mod clear; mod compute; +mod compute_command; mod draw; mod memory_init; mod query; @@ -13,7 +14,8 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*, + bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*, + render::*, transfer::*, }; pub(crate) use allocator::CommandAllocator; diff --git a/wgpu-core/src/track/buffer.rs b/wgpu-core/src/track/buffer.rs index 8ea92d4841..9a52a53253 100644 --- a/wgpu-core/src/track/buffer.rs +++ b/wgpu-core/src/track/buffer.rs @@ -245,6 +245,22 @@ impl BufferUsageScope { .get(id) .map_err(|_| UsageConflict::BufferInvalid { id })?; + self.insert_merge_single(buffer.clone(), new_state) + .map(|_| buffer) + } + + /// Merge a single state into the UsageScope, using an already resolved buffer. + /// + /// If the resulting state is invalid, returns a usage + /// conflict with the details of the invalid state. + /// + /// If the ID is higher than the length of internal vectors, + /// the vectors will be extended. A call to set_size is not needed. + pub fn insert_merge_single( + &mut self, + buffer: Arc>, + new_state: BufferUses, + ) -> Result<(), UsageConflict> { let index = buffer.info.tracker_index().as_usize(); self.allow_index(index); @@ -260,12 +276,12 @@ impl BufferUsageScope { index, BufferStateProvider::Direct { state: new_state }, ResourceMetadataProvider::Direct { - resource: Cow::Owned(buffer.clone()), + resource: Cow::Owned(buffer), }, )?; } - Ok(buffer) + Ok(()) } } diff --git a/wgpu-core/src/track/metadata.rs b/wgpu-core/src/track/metadata.rs index 3e71e0e084..d6e8d6f906 100644 --- a/wgpu-core/src/track/metadata.rs +++ b/wgpu-core/src/track/metadata.rs @@ -87,16 +87,18 @@ impl ResourceMetadata { /// Add the resource with the given index, epoch, and reference count to the /// set. /// + /// Returns a reference to the newly inserted resource. + /// (This allows avoiding a clone/reference count increase in many cases.) + /// /// # Safety /// /// The given `index` must be in bounds for this `ResourceMetadata`'s /// existing tables. See `tracker_assert_in_bounds`. #[inline(always)] - pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc) { + pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc) -> &Arc { self.owned.set(index, true); - unsafe { - *self.resources.get_unchecked_mut(index) = Some(resource); - } + let resource_dst = unsafe { self.resources.get_unchecked_mut(index) }; + resource_dst.insert(resource) } /// Get the resource with the given index. diff --git a/wgpu-core/src/track/stateless.rs b/wgpu-core/src/track/stateless.rs index 26caf165cb..25ffc027ee 100644 --- a/wgpu-core/src/track/stateless.rs +++ b/wgpu-core/src/track/stateless.rs @@ -158,16 +158,17 @@ impl StatelessTracker { /// /// If the ID is higher than the length of internal vectors, /// the vectors will be extended. A call to set_size is not needed. - pub fn insert_single(&mut self, resource: Arc) { + /// + /// Returns a reference to the newly inserted resource. + /// (This allows avoiding a clone/reference count increase in many cases.) + pub fn insert_single(&mut self, resource: Arc) -> &Arc { let index = resource.as_info().tracker_index().as_usize(); self.allow_index(index); self.tracker_assert_in_bounds(index); - unsafe { - self.metadata.insert(index, resource); - } + unsafe { self.metadata.insert(index, resource) } } /// Adds the given resource to the tracker.