mod allocator;
mod bind;
mod bundle;
mod clear;
mod compute;
mod compute_command;
mod draw;
mod memory_init;
mod query;
mod ray_tracing;
mod render;
mod render_command;
mod timestamp_writes;
mod transfer;
use std::mem::{self, ManuallyDrop};
use std::sync::Arc;
pub(crate) use self::clear::clear_texture;
pub use self::{
bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*,
render::*, render_command::RenderCommand, transfer::*,
};
pub(crate) use allocator::CommandAllocator;
pub(crate) use timestamp_writes::ArcPassTimestampWrites;
pub use timestamp_writes::PassTimestampWrites;
use self::memory_init::CommandBufferTextureMemoryActions;
use crate::device::{Device, DeviceError, MissingFeatures};
use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard;
use crate::init_tracker::BufferInitTrackerAction;
use crate::ray_tracing::{BlasAction, TlasAction};
use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet};
use crate::storage::Storage;
use crate::track::{DeviceTracker, Tracker, UsageScope};
use crate::LabelHelpers;
use crate::{api_log, global::Global, id, resource_log, Label};
use thiserror::Error;
#[cfg(feature = "trace")]
use crate::device::trace::Command as TraceCommand;
/// All-zero words used as the source data by `push_constant_clear`.
///
/// Its length (64 words) caps how many words a single `push_fn` call in
/// `push_constant_clear` writes; larger ranges are cleared in several chunks.
const PUSH_CONSTANT_CLEAR_ARRAY: &[u32] = &[0_u32; 64];
/// Recording state of a command encoder / command buffer.
///
/// - `Recording`: commands may be recorded (see `CommandEncoderStatus::record`).
/// - `Locked`: held by a pass (see `lock_encoder` / `unlock_encoder`); trying to
///   record while locked invalidates the encoder.
/// - `Finished`: `finish` has closed the HAL encoder; no more recording.
/// - `Error`: the encoder is invalid and no longer owns any recorded data.
pub(crate) enum CommandEncoderStatus {
    Recording(CommandBufferMutable),
    Locked(CommandBufferMutable),
    Finished(CommandBufferMutable),
    Error,
}
impl CommandEncoderStatus {
    /// Starts recording a command, returning a guard over the inner data.
    ///
    /// The returned `RecordingGuard` sets the status to `Error` when dropped
    /// unless `RecordingGuard::mark_successful` is called.
    ///
    /// # Errors
    /// - `Locked`: a pass holds the encoder; the status is also set to `Error`.
    /// - `NotRecording`: already finished; the status is left unchanged.
    /// - `Invalid`: the status is `Error`.
    pub(crate) fn record(&mut self) -> Result<RecordingGuard<'_>, CommandEncoderError> {
        match self {
            Self::Recording(_) => Ok(RecordingGuard { inner: self }),
            Self::Locked(_) => {
                // Recording while a pass holds the encoder invalidates it.
                *self = Self::Error;
                Err(CommandEncoderError::Locked)
            }
            Self::Finished(_) => Err(CommandEncoderError::NotRecording),
            Self::Error => Err(CommandEncoderError::Invalid),
        }
    }

    /// Returns the inner data in any non-error state (used only for tracing).
    #[cfg(feature = "trace")]
    fn get_inner(&mut self) -> Result<&mut CommandBufferMutable, CommandEncoderError> {
        match self {
            Self::Locked(inner) | Self::Finished(inner) | Self::Recording(inner) => Ok(inner),
            Self::Error => Err(CommandEncoderError::Invalid),
        }
    }

    /// Transitions `Recording` -> `Locked`.
    ///
    /// The status is temporarily replaced with `Error` and only restored for
    /// `Recording` (as `Locked`) and `Finished`; a `Locked` encoder that is
    /// locked again therefore ends up in `Error`.
    fn lock_encoder(&mut self) -> Result<(), CommandEncoderError> {
        match mem::replace(self, Self::Error) {
            Self::Recording(inner) => {
                *self = Self::Locked(inner);
                Ok(())
            }
            Self::Finished(inner) => {
                // Finished encoders keep their state; only report the misuse.
                *self = Self::Finished(inner);
                Err(CommandEncoderError::NotRecording)
            }
            // Status stays `Error` (the placeholder) in these arms.
            Self::Locked(_) => Err(CommandEncoderError::Locked),
            Self::Error => Err(CommandEncoderError::Invalid),
        }
    }

    /// Transitions `Locked` -> `Recording` and returns a recording guard.
    ///
    /// `Finished` is preserved (error `NotRecording`); an encoder that was
    /// `Recording` (i.e. not locked) is left in `Error` with `Invalid`.
    fn unlock_encoder(&mut self) -> Result<RecordingGuard<'_>, CommandEncoderError> {
        match mem::replace(self, Self::Error) {
            Self::Locked(inner) => {
                *self = Self::Recording(inner);
                Ok(RecordingGuard { inner: self })
            }
            Self::Finished(inner) => {
                *self = Self::Finished(inner);
                Err(CommandEncoderError::NotRecording)
            }
            // Status stays `Error` (the placeholder) in these arms.
            Self::Recording(_) => Err(CommandEncoderError::Invalid),
            Self::Error => Err(CommandEncoderError::Invalid),
        }
    }

    /// Ends recording: closes the HAL encoder and transitions to `Finished`.
    ///
    /// If closing the HAL encoder fails, the status is left as `Error` and the
    /// device error is returned. `Finished` is preserved (error
    /// `NotRecording`); `Locked` and `Error` leave the status as `Error`.
    fn finish(&mut self, device: &Device) -> Result<(), CommandEncoderError> {
        match mem::replace(self, Self::Error) {
            Self::Recording(mut inner) => {
                if let Err(e) = inner.encoder.close(device) {
                    Err(e.into())
                } else {
                    *self = Self::Finished(inner);
                    Ok(())
                }
            }
            Self::Finished(inner) => {
                *self = Self::Finished(inner);
                Err(CommandEncoderError::NotRecording)
            }
            Self::Locked(_) => Err(CommandEncoderError::Locked),
            Self::Error => Err(CommandEncoderError::Invalid),
        }
    }
}
/// Guard over a [`CommandEncoderStatus`] that is known to be `Recording`.
///
/// It dereferences to the inner [`CommandBufferMutable`]. If the guard is
/// dropped without [`RecordingGuard::mark_successful`] being called, the
/// status is overwritten with `CommandEncoderStatus::Error`. An early `?`
/// return while recording therefore invalidates the encoder.
pub(crate) struct RecordingGuard<'a> {
    inner: &'a mut CommandEncoderStatus,
}

impl RecordingGuard<'_> {
    /// Consumes the guard without running its `Drop` impl, leaving the status
    /// as it currently is.
    pub(crate) fn mark_successful(self) {
        mem::forget(self);
    }
}

impl Drop for RecordingGuard<'_> {
    fn drop(&mut self) {
        // Only reached when `mark_successful` was not called.
        *self.inner = CommandEncoderStatus::Error;
    }
}

impl std::ops::Deref for RecordingGuard<'_> {
    type Target = CommandBufferMutable;

    fn deref(&self) -> &Self::Target {
        // A guard is only ever constructed over a `Recording` status.
        if let CommandEncoderStatus::Recording(data) = &*self.inner {
            data
        } else {
            unreachable!()
        }
    }
}

impl std::ops::DerefMut for RecordingGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // A guard is only ever constructed over a `Recording` status.
        if let CommandEncoderStatus::Recording(data) = &mut *self.inner {
            data
        } else {
            unreachable!()
        }
    }
}
/// A HAL command encoder together with the HAL command buffers it has
/// produced so far.
pub(crate) struct CommandEncoder {
    /// The HAL encoder. Wrapped in `ManuallyDrop` so that `Drop` can move it
    /// out and hand it back to the device's command allocator.
    pub(crate) raw: ManuallyDrop<Box<dyn hal::DynCommandEncoder>>,
    /// HAL command buffers produced by `end_encoding`. `close` appends to the
    /// end; `close_and_swap` inserts before the last entry.
    pub(crate) list: Vec<Box<dyn hal::DynCommandBuffer>>,
    pub(crate) device: Arc<Device>,
    /// Whether `raw` is between `begin_encoding` and `end_encoding`; checked by
    /// `open`, `close`, `close_and_swap` and `Drop`.
    pub(crate) is_open: bool,
    /// Label passed to `begin_encoding` by `open`.
    pub(crate) hal_label: Option<String>,
}
impl CommandEncoder {
    /// Ends the currently open HAL command buffer, if any, and inserts it
    /// immediately *before* the last entry of `list` instead of appending it.
    ///
    /// NOTE(review): presumably used to place commands that must execute ahead
    /// of the most recently closed command buffer; confirm at call sites.
    ///
    /// # Panics
    /// Panics if the encoder is open while `list` is empty
    /// (`self.list.len() - 1` underflows / the insert index is out of bounds).
    fn close_and_swap(&mut self, device: &Device) -> Result<(), DeviceError> {
        if self.is_open {
            // Cleared before `end_encoding` so a failed end is not followed by a
            // `discard_encoding` in `Drop` (assumes a failed end leaves the HAL
            // encoder closed).
            self.is_open = false;
            let new = unsafe { self.raw.end_encoding() }.map_err(|e| device.handle_hal_error(e))?;
            self.list.insert(self.list.len() - 1, new);
        }
        Ok(())
    }

    /// Ends the currently open HAL command buffer, if any, and appends it to
    /// `list`. Does nothing if the encoder is not open.
    fn close(&mut self, device: &Device) -> Result<(), DeviceError> {
        if self.is_open {
            // See `close_and_swap` for why this is cleared first.
            self.is_open = false;
            let cmd_buf =
                unsafe { self.raw.end_encoding() }.map_err(|e| device.handle_hal_error(e))?;
            self.list.push(cmd_buf);
        }
        Ok(())
    }

    /// Returns the HAL encoder, first beginning a new command buffer (labelled
    /// with `hal_label`) if none is currently open.
    ///
    /// `is_open` is only set once `begin_encoding` has succeeded. Otherwise a
    /// failed begin would leave the encoder marked open, and `Drop` or `close`
    /// would then call `discard_encoding` / `end_encoding` on an encoder that
    /// never started recording.
    pub(crate) fn open(
        &mut self,
        device: &Device,
    ) -> Result<&mut dyn hal::DynCommandEncoder, DeviceError> {
        if !self.is_open {
            let hal_label = self.hal_label.as_deref();
            unsafe { self.raw.begin_encoding(hal_label) }
                .map_err(|e| device.handle_hal_error(e))?;
            self.is_open = true;
        }
        Ok(self.raw.as_mut())
    }

    /// Begins a new HAL command buffer with an explicit `hal_label`.
    ///
    /// Unlike `open`, this does not check `is_open`; the caller is expected to
    /// have closed any previous command buffer. `is_open` is set only after
    /// `begin_encoding` succeeds (see `open`).
    fn open_pass(&mut self, hal_label: Option<&str>, device: &Device) -> Result<(), DeviceError> {
        unsafe { self.raw.begin_encoding(hal_label) }.map_err(|e| device.handle_hal_error(e))?;
        self.is_open = true;
        Ok(())
    }
}
impl Drop for CommandEncoder {
    /// Discards any in-progress HAL recording, passes all produced command
    /// buffers to `reset_all`, and returns the raw encoder to the device's
    /// command allocator.
    fn drop(&mut self) {
        // An open encoder still has an unfinished recording; discard it first.
        if self.is_open {
            unsafe { self.raw.discard_encoding() };
        }
        // `mem::take` leaves `list` empty and hands the command buffers to the
        // HAL encoder to be reset.
        unsafe {
            self.raw.reset_all(mem::take(&mut self.list));
        }
        // `raw` is not touched again after this point, as `ManuallyDrop::take`
        // requires.
        let raw = unsafe { ManuallyDrop::take(&mut self.raw) };
        self.device.command_allocator.release_encoder(raw);
    }
}
/// The subset of a [`CommandBufferMutable`]'s state kept after recording,
/// produced by `CommandBufferMutable::into_baked_commands`.
pub(crate) struct BakedCommands {
    pub(crate) encoder: CommandEncoder,
    pub(crate) trackers: Tracker,
    buffer_memory_init_actions: Vec<BufferInitTrackerAction>,
    texture_memory_actions: CommandBufferTextureMemoryActions,
}
/// Mutable recording state of a command buffer, stored inside
/// [`CommandEncoderStatus`].
pub struct CommandBufferMutable {
    /// HAL encoder and the HAL command buffers recorded so far.
    pub(crate) encoder: CommandEncoder,
    /// Resource usage tracking for this command buffer.
    pub(crate) trackers: Tracker,
    buffer_memory_init_actions: Vec<BufferInitTrackerAction>,
    texture_memory_actions: CommandBufferTextureMemoryActions,
    pub(crate) pending_query_resets: QueryResetMap,
    blas_actions: Vec<BlasAction>,
    tlas_actions: Vec<TlasAction>,
    /// Recorded API commands for tracing; `Some` only when the device was
    /// tracing at creation time (see `CommandBuffer::new`).
    #[cfg(feature = "trace")]
    pub(crate) commands: Option<Vec<TraceCommand>>,
}
impl CommandBufferMutable {
    /// Opens the HAL encoder (beginning a command buffer if needed) and
    /// returns it alongside this command buffer's tracker.
    ///
    /// Both references come from disjoint fields, so they can be held at once.
    pub(crate) fn open_encoder_and_tracker(
        &mut self,
        device: &Device,
    ) -> Result<(&mut dyn hal::DynCommandEncoder, &mut Tracker), DeviceError> {
        Ok((self.encoder.open(device)?, &mut self.trackers))
    }

    /// Consumes the recording state, keeping only what `BakedCommands` needs.
    /// The remaining fields are dropped.
    pub(crate) fn into_baked_commands(self) -> BakedCommands {
        let Self {
            encoder,
            trackers,
            buffer_memory_init_actions,
            texture_memory_actions,
            ..
        } = self;
        BakedCommands {
            encoder,
            trackers,
            buffer_memory_init_actions,
            texture_memory_actions,
        }
    }
}
/// A command encoder / command buffer resource.
///
/// The recording state lives behind `data`; see [`CommandEncoderStatus`].
pub struct CommandBuffer {
    pub(crate) device: Arc<Device>,
    /// Whether the device has `wgt::Features::CLEAR_TEXTURE` (captured at creation).
    support_clear_texture: bool,
    label: String,
    pub(crate) data: Mutex<CommandEncoderStatus>,
}

impl Drop for CommandBuffer {
    fn drop(&mut self) {
        // Log-only; the recorded data is dropped with `data`.
        resource_log!("Drop {}", self.error_ident());
    }
}
impl CommandBuffer {
    /// Creates a command buffer in the `Recording` state that records into the
    /// given HAL `encoder`.
    ///
    /// The HAL label is derived from `label` via `to_hal` with the device's
    /// instance flags. With the `trace` feature, a trace command list is
    /// allocated only if the device is tracing at this moment.
    pub(crate) fn new(
        encoder: Box<dyn hal::DynCommandEncoder>,
        device: &Arc<Device>,
        label: &Label,
    ) -> Self {
        CommandBuffer {
            device: device.clone(),
            support_clear_texture: device.features.contains(wgt::Features::CLEAR_TEXTURE),
            label: label.to_string(),
            data: Mutex::new(
                rank::COMMAND_BUFFER_DATA,
                CommandEncoderStatus::Recording(CommandBufferMutable {
                    encoder: CommandEncoder {
                        raw: ManuallyDrop::new(encoder),
                        list: Vec::new(),
                        device: device.clone(),
                        is_open: false,
                        hal_label: label.to_hal(device.instance_flags).map(str::to_owned),
                    },
                    trackers: Tracker::new(),
                    buffer_memory_init_actions: Default::default(),
                    texture_memory_actions: Default::default(),
                    pending_query_resets: QueryResetMap::new(),
                    blas_actions: Default::default(),
                    tlas_actions: Default::default(),
                    #[cfg(feature = "trace")]
                    commands: if device.trace.lock().is_some() {
                        Some(Vec::new())
                    } else {
                        None
                    },
                }),
            ),
        }
    }

    /// Creates a command buffer whose status is `Error` from the start.
    ///
    /// Recording into, locking, or finishing it fails with
    /// `CommandEncoderError::Invalid`.
    pub(crate) fn new_invalid(device: &Arc<Device>, label: &Label) -> Self {
        CommandBuffer {
            device: device.clone(),
            support_clear_texture: device.features.contains(wgt::Features::CLEAR_TEXTURE),
            label: label.to_string(),
            data: Mutex::new(rank::COMMAND_BUFFER_DATA, CommandEncoderStatus::Error),
        }
    }

    /// Merges `head`'s buffer and texture states into `base` and records the
    /// resulting transitions into `raw`.
    pub(crate) fn insert_barriers_from_tracker(
        raw: &mut dyn hal::DynCommandEncoder,
        base: &mut Tracker,
        head: &Tracker,
        snatch_guard: &SnatchGuard,
    ) {
        profiling::scope!("insert_barriers");

        base.buffers.set_from_tracker(&head.buffers);
        base.textures.set_from_tracker(&head.textures);

        Self::drain_barriers(raw, base, snatch_guard);
    }

    /// Merges a usage scope's buffer and texture states into `base` and
    /// records the resulting transitions into `raw`.
    pub(crate) fn insert_barriers_from_scope(
        raw: &mut dyn hal::DynCommandEncoder,
        base: &mut Tracker,
        head: &UsageScope,
        snatch_guard: &SnatchGuard,
    ) {
        profiling::scope!("insert_barriers");

        base.buffers.set_from_usage_scope(&head.buffers);
        base.textures.set_from_usage_scope(&head.textures);

        Self::drain_barriers(raw, base, snatch_guard);
    }

    /// Drains all pending buffer and texture transitions from `base` and
    /// records them into `raw` as HAL barriers.
    ///
    /// # Panics
    /// Panics if a pending texture transition has no raw texture available
    /// under `snatch_guard` (the `unwrap` below). NOTE(review): presumably that
    /// means the texture was destroyed; confirm with the tracker.
    pub(crate) fn drain_barriers(
        raw: &mut dyn hal::DynCommandEncoder,
        base: &mut Tracker,
        snatch_guard: &SnatchGuard,
    ) {
        profiling::scope!("drain_barriers");

        let buffer_barriers = base
            .buffers
            .drain_transitions(snatch_guard)
            .collect::<Vec<_>>();

        // `drain_transitions` returns the transitions and their textures as two
        // parallel lists; pair them up positionally.
        let (transitions, textures) = base.textures.drain_transitions(snatch_guard);
        let texture_barriers = transitions
            .into_iter()
            .zip(textures)
            .map(|(pending, texture)| pending.into_hal(texture.unwrap().raw()))
            .collect::<Vec<_>>();

        unsafe {
            raw.transition_buffers(&buffer_barriers);
            raw.transition_textures(&texture_barriers);
        }
    }

    /// Merges `head` into the device-wide tracker `base` and records the
    /// resulting buffer and texture transitions into `raw`.
    pub(crate) fn insert_barriers_from_device_tracker(
        raw: &mut dyn hal::DynCommandEncoder,
        base: &mut DeviceTracker,
        head: &Tracker,
        snatch_guard: &SnatchGuard,
    ) {
        profiling::scope!("insert_barriers_from_device_tracker");

        let buffer_barriers = base
            .buffers
            .set_from_tracker_and_drain_transitions(&head.buffers, snatch_guard)
            .collect::<Vec<_>>();

        let texture_barriers = base
            .textures
            .set_from_tracker_and_drain_transitions(&head.textures, snatch_guard)
            .collect::<Vec<_>>();

        unsafe {
            raw.transition_buffers(&buffer_barriers);
            raw.transition_textures(&texture_barriers);
        }
    }
}
impl CommandBuffer {
    /// Takes the recorded data out of a finished command buffer.
    ///
    /// The status is unconditionally replaced with `Error`, so this succeeds
    /// at most once; later calls return an error.
    ///
    /// # Errors
    /// Returns `InvalidResourceError` if the status was not `Finished` (still
    /// recording, locked by a pass, or already invalid).
    pub fn take_finished(&self) -> Result<CommandBufferMutable, InvalidResourceError> {
        let status = mem::replace(&mut *self.data.lock(), CommandEncoderStatus::Error);
        match status {
            CommandEncoderStatus::Finished(command_buffer_mutable) => Ok(command_buffer_mutable),
            CommandEncoderStatus::Recording(_)
            | CommandEncoderStatus::Locked(_)
            | CommandEncoderStatus::Error => Err(InvalidResourceError(self.error_ident())),
        }
    }
}
// Crate-macro boilerplate for `CommandBuffer` (resource type, label, parent
// device, storage item). NOTE(review): the generated impls are defined by the
// macros elsewhere in the crate.
crate::impl_resource_type!(CommandBuffer);
crate::impl_labeled!(CommandBuffer);
crate::impl_parent_device!(CommandBuffer);
crate::impl_storage_item!(CommandBuffer);
/// The recorded contents of a pass, generic over the command type `C`.
///
/// Serializable when the `serde` feature is enabled.
#[doc(hidden)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BasePass<C> {
    /// Owned copy of the pass label.
    pub label: Option<String>,
    /// The pass's commands, in recording order.
    pub commands: Vec<C>,
    /// Dynamic offsets referenced by the commands.
    pub dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Raw string bytes referenced by the commands (NOTE(review): presumably
    /// debug-marker/group labels; confirm against the command types).
    pub string_data: Vec<u8>,
    /// Push-constant words referenced by the commands.
    pub push_constant_data: Vec<u32>,
}
impl<C> BasePass<C> {
    /// Creates an empty pass with an owned copy of `label`.
    ///
    /// No bound on `C` is required to build an empty pass, so none is imposed
    /// here (the previous `C: Clone` bound was unused).
    fn new(label: &Label) -> Self {
        Self {
            label: label.as_deref().map(str::to_owned),
            commands: Vec::new(),
            dynamic_offsets: Vec::new(),
            string_data: Vec::new(),
            push_constant_data: Vec::new(),
        }
    }
}
/// Errors reported by command encoder operations.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CommandEncoderError {
    /// The encoder's status is `Error`.
    #[error("Command encoder is invalid")]
    Invalid,
    /// The encoder has already been finished.
    #[error("Command encoder must be active")]
    NotRecording,
    #[error(transparent)]
    Device(#[from] DeviceError),
    /// A pass currently holds the encoder.
    #[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")]
    Locked,
    #[error(transparent)]
    InvalidColorAttachment(#[from] ColorAttachmentError),
    #[error(transparent)]
    InvalidAttachment(#[from] AttachmentError),
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    /// Pass timestamp writes use the same index for beginning and end.
    #[error(
        "begin and end indices of pass timestamp writes are both set to {idx}, which is not allowed"
    )]
    TimestampWriteIndicesEqual { idx: u32 },
    /// A pass timestamp write index failed query validation.
    #[error(transparent)]
    TimestampWritesInvalid(#[from] QueryUseError),
    /// Pass timestamp writes specified neither a beginning nor an end index.
    #[error("no begin or end indices were specified for pass timestamp writes, expected at least one to be set")]
    TimestampWriteIndicesMissing,
}
impl Global {
    /// Ends recording on the command encoder `encoder_id`.
    ///
    /// The command buffer is looked up via `encoder_id.into_command_buffer_id()`,
    /// and that id is always returned together with the error, if any. Errors:
    /// `NotRecording` if already finished, `Locked` while a pass holds the
    /// encoder, `Invalid` if the encoder is in the error state, or a device
    /// error if closing the HAL encoder fails.
    ///
    /// `_desc` is currently unused.
    pub fn command_encoder_finish(
        &self,
        encoder_id: id::CommandEncoderId,
        _desc: &wgt::CommandBufferDescriptor<Label>,
    ) -> (id::CommandBufferId, Option<CommandEncoderError>) {
        profiling::scope!("CommandEncoder::finish");

        let hub = &self.hub;
        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let error = cmd_buf.data.lock().finish(&cmd_buf.device).err();

        (encoder_id.into_command_buffer_id(), error)
    }

    /// Records the start of a debug group named `label`.
    ///
    /// The command is added to the trace when tracing. The HAL encoder is only
    /// opened, and the HAL marker only emitted, when the instance does not have
    /// `DISCARD_HAL_LABELS` set.
    ///
    /// # Errors
    /// Fails if the encoder cannot record (see `CommandEncoderStatus::record`)
    /// or if opening the HAL encoder fails. A failure after recording started
    /// invalidates the encoder, because the recording guard is dropped without
    /// `mark_successful`.
    pub fn command_encoder_push_debug_group(
        &self,
        encoder_id: id::CommandEncoderId,
        label: &str,
    ) -> Result<(), CommandEncoderError> {
        profiling::scope!("CommandEncoder::push_debug_group");
        api_log!("CommandEncoder::push_debug_group {label}");

        let hub = &self.hub;
        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.commands {
            list.push(TraceCommand::PushDebugGroup(label.to_string()));
        }

        // Only open the HAL encoder when a marker will actually be written,
        // matching `command_encoder_insert_debug_marker`.
        if !cmd_buf
            .device
            .instance_flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
        {
            let cmd_buf_raw = cmd_buf_data.encoder.open(&cmd_buf.device)?;
            unsafe {
                cmd_buf_raw.begin_debug_marker(label);
            }
        }

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }

    /// Records a single debug marker named `label`.
    ///
    /// The command is added to the trace when tracing. The HAL encoder is only
    /// opened, and the marker only emitted, without `DISCARD_HAL_LABELS`.
    ///
    /// # Errors
    /// Same as `command_encoder_push_debug_group`.
    pub fn command_encoder_insert_debug_marker(
        &self,
        encoder_id: id::CommandEncoderId,
        label: &str,
    ) -> Result<(), CommandEncoderError> {
        profiling::scope!("CommandEncoder::insert_debug_marker");
        api_log!("CommandEncoder::insert_debug_marker {label}");

        let hub = &self.hub;
        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.commands {
            list.push(TraceCommand::InsertDebugMarker(label.to_string()));
        }

        if !cmd_buf
            .device
            .instance_flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
        {
            let cmd_buf_raw = cmd_buf_data.encoder.open(&cmd_buf.device)?;
            unsafe {
                cmd_buf_raw.insert_debug_marker(label);
            }
        }

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }

    /// Records the end of the current debug group.
    ///
    /// The command is added to the trace when tracing. The HAL encoder is only
    /// opened, and the HAL marker only ended, without `DISCARD_HAL_LABELS`
    /// (mirroring `command_encoder_push_debug_group`).
    ///
    /// # Errors
    /// Same as `command_encoder_push_debug_group`.
    pub fn command_encoder_pop_debug_group(
        &self,
        encoder_id: id::CommandEncoderId,
    ) -> Result<(), CommandEncoderError> {
        profiling::scope!("CommandEncoder::pop_debug_group");
        api_log!("CommandEncoder::pop_debug_group");

        let hub = &self.hub;
        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.commands {
            list.push(TraceCommand::PopDebugGroup);
        }

        if !cmd_buf
            .device
            .instance_flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
        {
            let cmd_buf_raw = cmd_buf_data.encoder.open(&cmd_buf.device)?;
            unsafe {
                cmd_buf_raw.end_debug_marker();
            }
        }

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }

    /// Validates a pass's timestamp writes and resolves the query set id.
    ///
    /// # Errors
    /// - `MissingFeatures` if `TIMESTAMP_QUERY` is not enabled;
    /// - errors from resolving the query set or from `same_device`, converted
    ///   via `From`;
    /// - `TimestampWritesInvalid` if a provided index fails timestamp query
    ///   validation;
    /// - `TimestampWriteIndicesEqual` if both indices are set and equal;
    /// - `TimestampWriteIndicesMissing` if neither index is set.
    fn validate_pass_timestamp_writes(
        device: &Device,
        query_sets: &Storage<Fallible<QuerySet>>,
        timestamp_writes: &PassTimestampWrites,
    ) -> Result<ArcPassTimestampWrites, CommandEncoderError> {
        let &PassTimestampWrites {
            query_set,
            beginning_of_pass_write_index,
            end_of_pass_write_index,
        } = timestamp_writes;

        device.require_features(wgt::Features::TIMESTAMP_QUERY)?;

        let query_set = query_sets.get(query_set).get()?;

        query_set.same_device(device)?;

        // Validate each index that is actually present.
        for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
            .into_iter()
            .flatten()
        {
            query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?;
        }

        if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
            if begin == end {
                return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
            }
        }

        if beginning_of_pass_write_index
            .or(end_of_pass_write_index)
            .is_none()
        {
            return Err(CommandEncoderError::TimestampWriteIndicesMissing);
        }

        Ok(ArcPassTimestampWrites {
            query_set,
            beginning_of_pass_write_index,
            end_of_pass_write_index,
        })
    }
}
/// Zeroes the push-constant range `[offset, offset + size_bytes)` by calling
/// `push_fn(byte_offset, zeros)` one or more times.
///
/// Each call passes at most `PUSH_CONSTANT_CLEAR_ARRAY.len()` zero words.
/// `size_bytes` is divided by `wgt::PUSH_CONSTANT_ALIGNMENT` to get the word
/// count, so any remainder is ignored. NOTE(review): this treats
/// `PUSH_CONSTANT_ALIGNMENT` as the byte size of one `u32` word.
fn push_constant_clear<PushFn>(offset: u32, size_bytes: u32, mut push_fn: PushFn)
where
    PushFn: FnMut(u32, &[u32]),
{
    let chunk_words = PUSH_CONSTANT_CLEAR_ARRAY.len() as u32;
    let total_words = size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT;

    for start_word in (0..total_words).step_by(chunk_words as usize) {
        let words_this_call = (total_words - start_word).min(chunk_words);
        let byte_offset = offset + start_word * wgt::PUSH_CONSTANT_ALIGNMENT;
        push_fn(
            byte_offset,
            &PUSH_CONSTANT_CLEAR_ARRAY[..words_this_call as usize],
        );
    }
}
/// Remembers the most recently set value of a piece of pass state so that
/// repeated, identical "set" commands can be detected as redundant.
#[derive(Debug, Copy, Clone)]
struct StateChange<T> {
    last_state: Option<T>,
}

impl<T: Copy + PartialEq> StateChange<T> {
    /// Creates a tracker with no recorded state.
    fn new() -> Self {
        StateChange { last_state: None }
    }

    /// Records `new_state` and returns `true` if it equals the previously
    /// recorded state, i.e. setting it again would be redundant.
    fn set_and_check_redundant(&mut self, new_state: T) -> bool {
        let previous = self.last_state.replace(new_state);
        previous == Some(new_state)
    }

    /// Forgets the recorded state; the next `set_and_check_redundant` call
    /// will not be reported as redundant.
    fn reset(&mut self) {
        *self = Self::new();
    }
}

impl<T: Copy + PartialEq> Default for StateChange<T> {
    fn default() -> Self {
        StateChange { last_state: None }
    }
}
/// Per-slot redundancy tracking for `set_bind_group` commands.
#[derive(Debug)]
struct BindGroupStateChange {
    last_states: [StateChange<Option<id::BindGroupId>>; hal::MAX_BIND_GROUPS],
}

impl BindGroupStateChange {
    /// Creates a tracker with no bind group recorded in any slot.
    fn new() -> Self {
        Self {
            last_states: [StateChange::new(); hal::MAX_BIND_GROUPS],
        }
    }

    /// Returns `true` if setting `bind_group_id` at slot `index` is redundant.
    ///
    /// A call with dynamic `offsets` is never redundant: the slot's recorded
    /// state is cleared and the offsets are appended to `dynamic_offsets`.
    /// Out-of-range slots are never redundant.
    fn set_and_check_redundant(
        &mut self,
        bind_group_id: Option<id::BindGroupId>,
        index: u32,
        dynamic_offsets: &mut Vec<u32>,
        offsets: &[wgt::DynamicOffset],
    ) -> bool {
        let slot = self.last_states.get_mut(index as usize);

        if !offsets.is_empty() {
            if let Some(state) = slot {
                state.reset();
            }
            dynamic_offsets.extend_from_slice(offsets);
            return false;
        }

        match slot {
            Some(state) => state.set_and_check_redundant(bind_group_id),
            None => false,
        }
    }

    /// Forgets the recorded bind group for every slot.
    fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for BindGroupStateChange {
    fn default() -> Self {
        BindGroupStateChange::new()
    }
}
/// Attaches a [`PassErrorScope`] to a fallible value, yielding `Result<T, O>`.
///
/// NOTE(review): implementations live in the pass submodules; the exact error
/// wrapping depends on them.
trait MapPassErr<T, O> {
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, O>;
}
/// Which draw entry point a command came from; carried by `PassErrorScope::Draw`.
#[derive(Clone, Copy, Debug)]
pub enum DrawKind {
    Draw,
    DrawIndirect,
    MultiDrawIndirect,
    MultiDrawIndirectCount,
}
/// Identifies the pass command or parameter in which an error occurred.
///
/// Each variant's `Display` text (from `#[error]`) names that command.
#[derive(Clone, Copy, Debug, Error)]
pub enum PassErrorScope {
    #[error("In a bundle parameter")]
    Bundle,
    #[error("In a pass parameter")]
    Pass,
    #[error("In a set_bind_group command")]
    SetBindGroup,
    #[error("In a set_pipeline command")]
    SetPipelineRender,
    #[error("In a set_pipeline command")]
    SetPipelineCompute,
    #[error("In a set_push_constant command")]
    SetPushConstant,
    #[error("In a set_vertex_buffer command")]
    SetVertexBuffer,
    #[error("In a set_index_buffer command")]
    SetIndexBuffer,
    #[error("In a set_blend_constant command")]
    SetBlendConstant,
    #[error("In a set_stencil_reference command")]
    SetStencilReference,
    #[error("In a set_viewport command")]
    SetViewport,
    #[error("In a set_scissor_rect command")]
    SetScissorRect,
    /// A draw command; `indexed` is carried but not included in the message.
    #[error("In a draw command, kind: {kind:?}")]
    Draw { kind: DrawKind, indexed: bool },
    #[error("In a write_timestamp command")]
    WriteTimestamp,
    #[error("In a begin_occlusion_query command")]
    BeginOcclusionQuery,
    #[error("In a end_occlusion_query command")]
    EndOcclusionQuery,
    #[error("In a begin_pipeline_statistics_query command")]
    BeginPipelineStatisticsQuery,
    #[error("In a end_pipeline_statistics_query command")]
    EndPipelineStatisticsQuery,
    #[error("In a execute_bundle command")]
    ExecuteBundle,
    #[error("In a dispatch command, indirect:{indirect}")]
    Dispatch { indirect: bool },
    #[error("In a push_debug_group command")]
    PushDebugGroup,
    #[error("In a pop_debug_group command")]
    PopDebugGroup,
    #[error("In a insert_debug_marker command")]
    InsertDebugMarker,
}