Clean up the storage classes (#245)

This commit is contained in:
Dzmitry Malyshau
2020-10-27 09:10:55 -04:00
committed by GitHub
parent c1830901c7
commit ce49afa391
14 changed files with 184 additions and 112 deletions

View File

@@ -112,6 +112,7 @@ bitflags::bitflags! {
const IMAGE_LOAD_STORE = 1 << 8;
const CONSERVATIVE_DEPTH = 1 << 9;
const TEXTURE_1D = 1 << 10;
const PUSH_CONSTANT = 1 << 11;
}
}
@@ -364,7 +365,7 @@ pub fn write<'a>(
}
let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => true,
StorageClass::Storage | StorageClass::Uniform => true,
_ => false,
};
@@ -557,14 +558,15 @@ pub fn write<'a>(
let name = if let Some(ref binding) = global.binding {
let prefix = match global.class {
StorageClass::Constant => "const",
StorageClass::Function => "fn",
StorageClass::Input => "in",
StorageClass::Output => "out",
StorageClass::Private => "priv",
StorageClass::StorageBuffer => "buffer",
StorageClass::Storage => "buffer",
StorageClass::Uniform => "uniform",
StorageClass::Handle => "handle",
StorageClass::WorkGroup => "wg",
StorageClass::PushConstant => "pc",
};
match binding {
@@ -606,7 +608,7 @@ pub fn write<'a>(
}
let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => {
StorageClass::Storage | StorageClass::Uniform => {
Some(format!("global_block_{}", handle.index()))
}
_ => None,
@@ -1492,22 +1494,24 @@ fn write_storage_class(
manager: &mut FeaturesManager,
) -> Result<&'static str, Error> {
Ok(match class {
StorageClass::Constant => "",
StorageClass::Function => "",
StorageClass::Input => "in ",
StorageClass::Output => "out ",
StorageClass::Private => "",
StorageClass::StorageBuffer => {
StorageClass::Storage => {
manager.request(Features::BUFFER_STORAGE);
"buffer "
}
StorageClass::Uniform => "uniform ",
StorageClass::Handle => "uniform ",
StorageClass::WorkGroup => {
manager.request(Features::COMPUTE_SHADER);
"shared "
}
StorageClass::PushConstant => {
manager.request(Features::PUSH_CONSTANT);
""
}
})
}

View File

@@ -254,9 +254,7 @@ impl<'a> TypedGlobalVariable<'a> {
let (space_qualifier, reference) = match ty.inner {
crate::TypeInner::Struct { .. } => match var.class {
crate::StorageClass::Constant
| crate::StorageClass::Uniform
| crate::StorageClass::StorageBuffer => {
crate::StorageClass::Uniform | crate::StorageClass::Storage => {
let space = if self.usage.contains(crate::GlobalUse::STORE) {
"device "
} else {
@@ -837,9 +835,13 @@ impl<W: Write> Writer<W> {
let base_name = module.types[base].name.or_index(base);
let class_name = match class {
Sc::Input | Sc::Output => continue,
Sc::Constant | Sc::Uniform => "constant",
Sc::StorageBuffer => "device",
Sc::Private | Sc::Function | Sc::WorkGroup => "",
Sc::Uniform => "constant",
Sc::Storage => "device",
Sc::Handle
| Sc::Private
| Sc::Function
| Sc::WorkGroup
| Sc::PushConstant => "",
};
write!(self.out, "typedef {} {} *{}", class_name, base_name, name)?;
}

View File

@@ -24,6 +24,10 @@ impl PhysicalLayout {
sink.extend(iter::once(self.bound));
sink.extend(iter::once(self.instruction_schema));
}
pub(super) fn supports_storage_buffers(&self) -> bool {
self.version >= 0x10300
}
}
impl LogicalLayout {

View File

@@ -417,14 +417,19 @@ impl Writer {
fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass {
match class {
crate::StorageClass::Constant => spirv::StorageClass::UniformConstant,
crate::StorageClass::Handle => spirv::StorageClass::UniformConstant,
crate::StorageClass::Function => spirv::StorageClass::Function,
crate::StorageClass::Input => spirv::StorageClass::Input,
crate::StorageClass::Output => spirv::StorageClass::Output,
crate::StorageClass::Private => spirv::StorageClass::Private,
crate::StorageClass::StorageBuffer => spirv::StorageClass::StorageBuffer,
crate::StorageClass::Uniform => spirv::StorageClass::Uniform,
crate::StorageClass::Storage if self.physical_layout.supports_storage_buffers() => {
spirv::StorageClass::StorageBuffer
}
crate::StorageClass::Storage | crate::StorageClass::Uniform => {
spirv::StorageClass::Uniform
}
crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup,
crate::StorageClass::PushConstant => spirv::StorageClass::PushConstant,
}
}

View File

@@ -596,9 +596,7 @@ pomelo! {
// single_type_qualifier ::= invariant_qualifier;
// single_type_qualifier ::= precise_qualifier;
storage_qualifier ::= Const {
StorageClass::Constant
}
// storage_qualifier ::= Const
// storage_qualifier ::= InOut;
storage_qualifier ::= In {
StorageClass::Input

View File

@@ -43,21 +43,6 @@ pub fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
}
}
pub fn map_storage_class(word: spirv::Word) -> Result<crate::StorageClass, Error> {
use spirv::StorageClass as Sc;
match Sc::from_u32(word) {
Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant),
Some(Sc::Function) => Ok(crate::StorageClass::Function),
Some(Sc::Input) => Ok(crate::StorageClass::Input),
Some(Sc::Output) => Ok(crate::StorageClass::Output),
Some(Sc::Private) => Ok(crate::StorageClass::Private),
Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer),
Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform),
Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup),
_ => Err(Error::UnsupportedStorageClass(word)),
}
}
pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
use spirv::Dim as D;
match D::from_u32(word) {

View File

@@ -2274,19 +2274,39 @@ impl<I: Iterator<Item = u32>> Parser<I> {
.future_decor
.remove(&id)
.ok_or(Error::InvalidBinding(id))?;
let class = map_storage_class(storage_class)?;
let class = {
use spirv::StorageClass as Sc;
match Sc::from_u32(storage_class) {
Some(Sc::Function) => crate::StorageClass::Function,
Some(Sc::Input) => crate::StorageClass::Input,
Some(Sc::Output) => crate::StorageClass::Output,
Some(Sc::Private) => crate::StorageClass::Private,
Some(Sc::UniformConstant) => crate::StorageClass::Handle,
Some(Sc::StorageBuffer) => crate::StorageClass::Storage,
Some(Sc::Uniform) => {
if self
.lookup_storage_buffer_types
.contains(&lookup_type.handle)
{
crate::StorageClass::Storage
} else {
crate::StorageClass::Uniform
}
}
Some(Sc::Workgroup) => crate::StorageClass::WorkGroup,
Some(Sc::PushConstant) => crate::StorageClass::PushConstant,
_ => return Err(Error::UnsupportedStorageClass(storage_class)),
}
};
let binding = match (class, &module.types[lookup_type.handle].inner) {
(crate::StorageClass::Input, &crate::TypeInner::Struct { .. })
| (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None,
_ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?),
};
let is_storage = match module.types[lookup_type.handle].inner {
crate::TypeInner::Struct { .. } => match class {
crate::StorageClass::StorageBuffer => true,
_ => self
.lookup_storage_buffer_types
.contains(&lookup_type.handle),
},
crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage,
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..

View File

@@ -4,8 +4,9 @@ pub fn map_storage_class(word: &str) -> Result<crate::StorageClass, Error<'_>> {
match word {
"in" => Ok(crate::StorageClass::Input),
"out" => Ok(crate::StorageClass::Output),
"private" => Ok(crate::StorageClass::Private),
"uniform" => Ok(crate::StorageClass::Uniform),
"storage_buffer" => Ok(crate::StorageClass::StorageBuffer),
"storage" => Ok(crate::StorageClass::Storage),
_ => Err(Error::UnknownStorageClass(word)),
}
}

View File

@@ -930,8 +930,17 @@ impl Parser {
lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?;
let access = match class {
Some(crate::StorageClass::StorageBuffer) => crate::StorageAccess::all(),
Some(crate::StorageClass::Constant) => crate::StorageAccess::LOAD,
Some(crate::StorageClass::Storage) => crate::StorageAccess::all(),
Some(crate::StorageClass::Handle) => {
match type_arena[ty].inner {
//TODO: RW textures
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => crate::StorageAccess::LOAD,
_ => crate::StorageAccess::empty(),
}
}
_ => crate::StorageAccess::empty(),
};
if lexer.skip(Token::Operation('=')) {
@@ -1708,7 +1717,7 @@ impl Parser {
crate::BuiltIn::Position => crate::StorageClass::Output,
_ => unimplemented!(),
},
_ => crate::StorageClass::Private,
_ => crate::StorageClass::Handle,
},
};
let var_handle = module.global_variables.append(crate::GlobalVariable {

View File

@@ -108,14 +108,24 @@ pub enum ShaderStage {
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
pub enum StorageClass {
Constant,
/// Function locals.
Function,
/// Pipeline input, per invocation.
Input,
/// Pipeline output, per invocation, mutable.
Output,
/// Private data, per invocation, mutable.
Private,
StorageBuffer,
Uniform,
/// Workgroup shared data, mutable.
WorkGroup,
/// Uniform buffer data.
Uniform,
/// Storage buffer data, potentially mutable.
Storage,
/// Opaque handles, such as samplers and images.
Handle,
/// Push constants.
PushConstant,
}
/// Built-in inputs and outputs.

View File

@@ -218,7 +218,7 @@ mod tests {
fn global_use_scan() {
let test_global = GlobalVariable {
name: None,
class: StorageClass::Constant,
class: StorageClass::Uniform,
binding: None,
ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()),
interpolation: None,

View File

@@ -21,8 +21,11 @@ pub enum GlobalVariableError {
InvalidType,
#[error("Interpolation is not valid")]
InvalidInterpolation,
#[error("Storage access flags are invalid")]
InvalidStorageAccess,
#[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
InvalidStorageAccess {
allowed: crate::StorageAccess,
seen: crate::StorageAccess,
},
#[error("Binding decoration is missing or not applicable")]
InvalidBinding,
#[error("Binding is out of range")]
@@ -63,8 +66,12 @@ pub enum ValidationError {
InvalidTypeWidth(crate::ScalarKind, crate::Bytes),
#[error("The type handle {0:?} can not be resolved")]
UnresolvedType(Handle<crate::Type>),
#[error("Global variable {0:?} is invalid: {1:?}")]
GlobalVariable(Handle<crate::GlobalVariable>, GlobalVariableError),
#[error("Global variable {handle:?} '{name}' is invalid: {error:?}")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
name: String,
error: GlobalVariableError,
},
#[error("Function {0:?} is invalid: {1:?}")]
Function(Handle<crate::Function>, FunctionError),
#[error("Entry point {name} at {stage:?} is invalid: {error:?}")]
@@ -75,6 +82,41 @@ pub enum ValidationError {
},
}
impl crate::GlobalVariable {
fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> {
match self.interpolation {
Some(_) => Err(GlobalVariableError::InvalidInterpolation),
None => Ok(()),
}
}
fn check_resource(&self) -> Result<(), GlobalVariableError> {
match self.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
}
}
Some(crate::Binding::Location(_)) | None => {
return Err(GlobalVariableError::InvalidBinding)
}
}
self.forbid_interpolation()
}
}
fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
let mut storage_usage = crate::GlobalUse::empty();
if access.contains(crate::StorageAccess::LOAD) {
storage_usage |= crate::GlobalUse::LOAD;
}
if access.contains(crate::StorageAccess::STORE) {
storage_usage |= crate::GlobalUse::STORE;
}
storage_usage
}
impl Validator {
/// Construct a new validator instance.
pub fn new() -> Self {
@@ -89,15 +131,13 @@ impl Validator {
types: &Arena<crate::Type>,
) -> Result<(), GlobalVariableError> {
log::debug!("var {:?}", var);
let is_storage = match var.class {
let allowed_storage_access = match var.class {
crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
crate::StorageClass::Input | crate::StorageClass::Output => {
match var.binding {
Some(crate::Binding::BuiltIn(_)) => {
// validated per entry point
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
var.forbid_interpolation()?
}
Some(crate::Binding::Location(loc)) => {
if loc > MAX_LOCATIONS {
@@ -117,56 +157,53 @@ impl Validator {
match types[var.ty].inner {
//TODO: check the member types
crate::TypeInner::Struct { members: _ } => {
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
var.forbid_interpolation()?
}
_ => return Err(GlobalVariableError::InvalidType),
}
}
}
false
crate::StorageAccess::empty()
}
crate::StorageClass::Constant
| crate::StorageClass::StorageBuffer
| crate::StorageClass::Uniform => {
match var.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
}
}
Some(crate::Binding::Location(_)) | None => {
return Err(GlobalVariableError::InvalidBinding)
}
}
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
//TODO: prevent `Uniform` storage class with `STORE` access
crate::StorageClass::Storage => {
var.check_resource()?;
crate::StorageAccess::all()
}
crate::StorageClass::Uniform => {
var.check_resource()?;
crate::StorageAccess::empty()
}
crate::StorageClass::Handle => {
var.check_resource()?;
match types[var.ty].inner {
crate::TypeInner::Struct { .. }
| crate::TypeInner::Image {
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => true,
_ => false,
} => crate::StorageAccess::all(),
_ => crate::StorageAccess::empty(),
}
}
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
if var.binding.is_some() {
return Err(GlobalVariableError::InvalidBinding);
}
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
false
var.forbid_interpolation()?;
crate::StorageAccess::empty()
}
crate::StorageClass::PushConstant => {
//TODO
return Err(GlobalVariableError::InvalidStorageAccess {
allowed: crate::StorageAccess::empty(),
seen: crate::StorageAccess::empty(),
});
}
};
if !is_storage && !var.storage_access.is_empty() {
return Err(GlobalVariableError::InvalidStorageAccess);
if !allowed_storage_access.contains(var.storage_access) {
return Err(GlobalVariableError::InvalidStorageAccess {
allowed: allowed_storage_access,
seen: var.storage_access,
});
}
Ok(())
@@ -291,26 +328,19 @@ impl Validator {
location_out_mask |= mask;
crate::GlobalUse::LOAD | crate::GlobalUse::STORE
}
crate::StorageClass::Constant => crate::GlobalUse::LOAD,
crate::StorageClass::Uniform | crate::StorageClass::StorageBuffer => {
//TODO: built-in checks?
let mut storage_usage = crate::GlobalUse::empty();
if var.storage_access.contains(crate::StorageAccess::LOAD) {
storage_usage |= crate::GlobalUse::LOAD;
}
if var.storage_access.contains(crate::StorageAccess::STORE) {
storage_usage |= crate::GlobalUse::STORE;
}
if storage_usage.is_empty() {
// its a uniform buffer
crate::GlobalUse::LOAD
} else {
storage_usage
}
}
crate::StorageClass::Uniform => crate::GlobalUse::LOAD,
crate::StorageClass::Storage => storage_usage(var.storage_access),
crate::StorageClass::Handle => match module.types[var.ty].inner {
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => storage_usage(var.storage_access),
_ => crate::GlobalUse::LOAD,
},
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
crate::GlobalUse::all()
}
crate::StorageClass::PushConstant => crate::GlobalUse::LOAD,
};
if !allowed_usage.contains(usage) {
log::warn!("\tUsage error for: {:?}", var);
@@ -384,7 +414,11 @@ impl Validator {
for (var_handle, var) in module.global_variables.iter() {
self.validate_global_var(var, &module.types)
.map_err(|e| ValidationError::GlobalVariable(var_handle, e))?;
.map_err(|error| ValidationError::GlobalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
}
for (fun_handle, fun) in module.functions.iter() {

View File

@@ -61,8 +61,8 @@ type Particles = struct {
};
[[group(0), binding(0)]] var<uniform> params : SimParams;
[[group(0), binding(1)]] var<storage_buffer> particlesA : Particles;
[[group(0), binding(2)]] var<storage_buffer> particlesB : Particles;
[[group(0), binding(1)]] var<storage> particlesA : Particles;
[[group(0), binding(2)]] var<storage> particlesB : Particles;
[[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3<u32>;

View File

@@ -14,8 +14,8 @@ fn main() -> void {
# fragment
[[location(0)]] var<in> v_uv : vec2<f32>;
[[group(0), binding(0)]] var<uniform> u_texture : texture_sampled_2d<f32>;
[[group(0), binding(1)]] var<uniform> u_sampler : sampler;
[[group(0), binding(0)]] var u_texture : texture_sampled_2d<f32>;
[[group(0), binding(1)]] var u_sampler : sampler;
[[location(0)]] var<out> o_color : vec4<f32>;
[[stage(fragment)]]