diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index ddeff5d5c5..306bf17958 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -644,16 +644,10 @@ impl<'function> Context<'function> { let (pointer, ptr_meta) = self.lower_expect(program, tgt, true, body)?; let (mut value, value_meta) = self.lower_expect(program, value, false, body)?; - let ptr_kind = match *program.resolve_type(self, pointer, ptr_meta)? { - TypeInner::Pointer { base, .. } => { - program.module.types[base].inner.scalar_kind() - } - TypeInner::ValuePointer { kind, .. } => Some(kind), - ref ty => ty.scalar_kind(), - }; + let scalar_components = self.expr_scalar_components(program, pointer, ptr_meta)?; - if let Some(kind) = ptr_kind { - self.implicit_conversion(program, &mut value, value_meta, kind)?; + if let Some((kind, width)) = scalar_components { + self.implicit_conversion(program, &mut value, value_meta, kind, width)?; } if let Expression::Swizzle { @@ -808,13 +802,14 @@ impl<'function> Context<'function> { Ok((Some(handle), meta)) } - pub fn expr_scalar_kind( + pub fn expr_scalar_components( &mut self, program: &mut Program, expr: Handle, meta: SourceMetadata, - ) -> Result, ErrorKind> { - Ok(program.resolve_type(self, expr, meta)?.scalar_kind()) + ) -> Result, ErrorKind> { + let ty = program.resolve_type(self, expr, meta)?; + Ok(scalar_components(ty)) } pub fn expr_power( @@ -824,8 +819,8 @@ impl<'function> Context<'function> { meta: SourceMetadata, ) -> Result, ErrorKind> { Ok(self - .expr_scalar_kind(program, expr, meta)? - .and_then(type_power)) + .expr_scalar_components(program, expr, meta)? + .and_then(|(kind, _)| type_power(kind))) } pub fn get_expression(&self, expr: Handle) -> &Expression { @@ -838,6 +833,7 @@ impl<'function> Context<'function> { expr: &mut Handle, meta: SourceMetadata, kind: ScalarKind, + width: crate::Bytes, ) -> Result<(), ErrorKind> { if let (Some(tgt_power), Some(expr_power)) = (type_power(kind), self.expr_power(program, *expr, meta)?) @@ -846,7 +842,7 @@ impl<'function> Context<'function> { *expr = self.expressions.append(Expression::As { expr: *expr, kind, - convert: None, + convert: Some(width), }) } } @@ -862,19 +858,22 @@ impl<'function> Context<'function> { right: &mut Handle, right_meta: SourceMetadata, ) -> Result<(), ErrorKind> { - let left_kind = self.expr_scalar_kind(program, *left, left_meta)?; - let right_kind = self.expr_scalar_kind(program, *right, right_meta)?; + let left_components = self.expr_scalar_components(program, *left, left_meta)?; + let right_components = self.expr_scalar_components(program, *right, right_meta)?; - if let (Some((left_power, left_kind)), Some((right_power, right_kind))) = ( - left_kind.and_then(|kind| Some((type_power(kind)?, kind))), - right_kind.and_then(|kind| Some((type_power(kind)?, kind))), + if let ( + Some((left_power, left_width, left_kind)), + Some((right_power, right_width, right_kind)), + ) = ( + left_components.and_then(|(kind, width)| Some((type_power(kind)?, width, kind))), + right_components.and_then(|(kind, width)| Some((type_power(kind)?, width, kind))), ) { match left_power.cmp(&right_power) { std::cmp::Ordering::Less => { *left = self.expressions.append(Expression::As { expr: *left, kind: right_kind, - convert: None, + convert: Some(right_width), }) } std::cmp::Ordering::Equal => {} @@ -882,7 +881,7 @@ impl<'function> Context<'function> { *right = self.expressions.append(Expression::As { expr: *right, kind: left_kind, - convert: None, + convert: Some(left_width), }) } } @@ -910,6 +909,16 @@ impl<'function> Context<'function> { } } +pub fn scalar_components(ty: &TypeInner) -> Option<(ScalarKind, crate::Bytes)> { + match *ty { + TypeInner::Scalar { kind, width } => Some((kind, width)), + TypeInner::Vector { kind, width, .. } => Some((kind, width)), + TypeInner::Matrix { width, .. } => Some((ScalarKind::Float, width)), + TypeInner::ValuePointer { kind, width, .. } => Some((kind, width)), + _ => None, + } +} + pub fn type_power(kind: ScalarKind) -> Option { Some(match kind { ScalarKind::Sint => 0, diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 42ac28a6ad..b40a5fd334 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -75,9 +75,9 @@ impl Program<'_> { } match self.module.types[ty].inner { - TypeInner::Vector { size, kind, .. } if vector_size.is_none() => { + TypeInner::Vector { size, kind, width } if vector_size.is_none() => { let (mut value, meta) = args[0]; - ctx.implicit_conversion(self, &mut value, meta, kind)?; + ctx.implicit_conversion(self, &mut value, meta, kind, width)?; ctx.add_expression(Expression::Splat { size, value }, body) } @@ -108,13 +108,23 @@ impl Program<'_> { body, ) } - TypeInner::Matrix { columns, rows, .. } => { + TypeInner::Matrix { + columns, + rows, + width, + } => { // TODO: casts // `Expression::As` doesn't support matrix width // casts so we need to do some extra work for casts let (mut value, meta) = args[0]; - ctx.implicit_conversion(self, &mut value, meta, ScalarKind::Float)?; + ctx.implicit_conversion( + self, + &mut value, + meta, + ScalarKind::Float, + width, + )?; let column = match *self.resolve_type(ctx, args[0].0, args[0].1)? { TypeInner::Scalar { .. } => ctx .add_expression(Expression::Splat { size: rows, value }, body), @@ -176,8 +186,9 @@ impl Program<'_> { let mut components = Vec::with_capacity(args.len()); for (mut arg, meta) in args.iter().copied() { - if let Some(kind) = self.module.types[ty].inner.scalar_kind() { - ctx.implicit_conversion(self, &mut arg, meta, kind)?; + let scalar_components = scalar_components(&self.module.types[ty].inner); + if let Some((kind, width)) = scalar_components { + ctx.implicit_conversion(self, &mut arg, meta, kind, width)?; } components.push(arg) } @@ -773,8 +784,10 @@ impl Program<'_> { } } - if let Some(kind) = self.module.types[*parameter].inner.scalar_kind() { - ctx.implicit_conversion(self, &mut handle, meta, kind)?; + let scalar_components = + scalar_components(&self.module.types[*parameter].inner); + if let Some((kind, width)) = scalar_components { + ctx.implicit_conversion(self, &mut handle, meta, kind, width)?; } arguments.push(handle) diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index bb30e9bb2b..8e0a8d0452 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -1,6 +1,6 @@ use super::{ ast::{ - Context, FunctionCall, FunctionCallKind, GlobalLookup, GlobalLookupKind, HirExpr, + self, Context, FunctionCall, FunctionCallKind, GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout, TypeQualifier, }, error::ErrorKind, @@ -591,9 +591,16 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> { .map::, _>(|_| { let (mut expr, init_meta) = self.parse_initializer(ty, ctx.ctx, ctx.body)?; - if let Some(kind) = self.program.module.types[ty].inner.scalar_kind() { - ctx.ctx - .implicit_conversion(self.program, &mut expr, init_meta, kind)?; + let scalar_components = + ast::scalar_components(&self.program.module.types[ty].inner); + if let Some((kind, width)) = scalar_components { + ctx.ctx.implicit_conversion( + self.program, + &mut expr, + init_meta, + kind, + width, + )?; } meta = meta.union(&init_meta); diff --git a/tests/out/wgsl/246-collatz-comp.wgsl b/tests/out/wgsl/246-collatz-comp.wgsl index 08d58c28a2..0b1410dc9c 100644 --- a/tests/out/wgsl/246-collatz-comp.wgsl +++ b/tests/out/wgsl/246-collatz-comp.wgsl @@ -9,11 +9,10 @@ var gl_GlobalInvocationID: vec3; fn collatz_iterations(n: u32) -> u32 { var n1: u32; - var i: u32; + var i: u32 = 0u; var local: u32; n1 = n; - i = u32(0); loop { let _e7: u32 = n1; if (!((_e7 != u32(1)))) { diff --git a/tests/out/wgsl/constant-array-size-vert.wgsl b/tests/out/wgsl/constant-array-size-vert.wgsl index 96c2472e1c..5a7aba0abc 100644 --- a/tests/out/wgsl/constant-array-size-vert.wgsl +++ b/tests/out/wgsl/constant-array-size-vert.wgsl @@ -7,11 +7,10 @@ struct Data { var global: Data; fn function() -> vec4 { - var sum: vec4; + var sum: vec4 = vec4(0.0, 0.0, 0.0, 0.0); var i: i32 = 0; var local: i32; - sum = vec4(f32(0)); loop { let _e9: i32 = i; if (!((_e9 < 42))) {