From d370686351444f871bf2d4f886b32a19820cec8c Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 28 Aug 2020 00:44:52 -0400 Subject: [PATCH] proc: support parameter types in Typifier --- src/arena.rs | 28 ++++--- src/front/glsl/mod.rs | 73 ++++++++++++++----- src/front/wgsl/mod.rs | 61 ++++++++++------ src/proc/typifier.rs | 165 ++++++++++++++++++++++-------------------- 4 files changed, 192 insertions(+), 135 deletions(-) diff --git a/src/arena.rs b/src/arena.rs index 3d7304644f..210133e1bf 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -9,8 +9,8 @@ use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32}; type Index = NonZeroU32; /// A strongly typed reference to a SPIR-V element. -#[cfg_attr(feature = "serialize", derive(crate::Serialize))] -#[cfg_attr(feature = "deserialize", derive(crate::Deserialize))] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(transparent) @@ -85,8 +85,8 @@ impl Handle { /// The arena can be indexed using the given handle to obtain /// a reference to the stored item. #[derive(Debug)] -#[cfg_attr(feature = "serialize", derive(crate::Serialize))] -#[cfg_attr(feature = "deserialize", derive(crate::Deserialize))] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr( any(feature = "serialize", feature = "deserialize"), serde(transparent) @@ -138,14 +138,12 @@ impl Arena { Handle::new(index) } - /// Adds a value with a check for uniqueness: returns a handle pointing to - /// an existing element if its value matches the given one, or adds a new + /// Adds a value with a custom check for uniqueness: + /// returns a handle pointing to + /// an existing element if the check succeeds, or adds a new /// element otherwise. - pub fn fetch_or_append(&mut self, value: T) -> Handle - where - T: PartialEq, - { - if let Some(index) = self.data.iter().position(|d| d == &value) { + pub fn fetch_if_or_append bool>(&mut self, value: T, fun: F) -> Handle { + if let Some(index) = self.data.iter().position(|d| fun(d, &value)) { let index = unsafe { Index::new_unchecked((index + 1) as u32) }; Handle::new(index) } else { @@ -153,6 +151,14 @@ impl Arena { } } + /// Adds a value with a check for uniqueness, where the check is plain comparison. + pub fn fetch_or_append(&mut self, value: T) -> Handle + where + T: PartialEq, + { + self.fetch_if_or_append(value, T::eq) + } + /// Get a mutable reference to an element in the arena. pub fn get_mut(&mut self, handle: Handle) -> &mut T { self.data.get_mut(handle.index.get() as usize - 1).unwrap() diff --git a/src/front/glsl/mod.rs b/src/front/glsl/mod.rs index b4061233b8..6bc39d63c6 100644 --- a/src/front/glsl/mod.rs +++ b/src/front/glsl/mod.rs @@ -151,6 +151,7 @@ impl<'a> Parser<'a> { &mut locals, &mut locals_map, ¶meter_lookup, + &[], )?; let handle = expressions.append(expr); let val = self.eval_const_expr(handle, &expressions)?; @@ -174,7 +175,7 @@ impl<'a> Parser<'a> { let mut index = 0; for field in block.fields { - let ty = self.parse_type(field.ty).unwrap(); + let ty = self.parse_type(field.ty, &[]).unwrap(); for ident in field.identifiers { let field_name = ident.ident.0; @@ -184,7 +185,7 @@ impl<'a> Parser<'a> { name: Some(field_name.clone()), origin, ty: if let Some(array_spec) = ident.array_spec { - let size = self.parse_array_size(array_spec)?; + let size = self.parse_array_size(array_spec, &[])?; self.types.fetch_or_append(Type { name: None, inner: TypeInner::Array { @@ -213,7 +214,7 @@ impl<'a> Parser<'a> { inner: TypeInner::Struct { members: fields }, }); - let size = self.parse_array_size(array_spec)?; + let size = self.parse_array_size(array_spec, &[])?; self.types.fetch_or_append(Type { name: None, inner: TypeInner::Array { @@ -271,7 +272,7 @@ impl<'a> Parser<'a> { let name = function.prototype.name.0; // Parse return type - let ty = self.parse_type(function.prototype.ty.ty); + let ty = self.parse_type(function.prototype.ty.ty, &[]); let mut parameter_types = Vec::with_capacity(function.prototype.parameters.len()); let mut parameter_lookup = FastHashMap::default(); @@ -285,10 +286,10 @@ impl<'a> Parser<'a> { for (index, parameter) in function.prototype.parameters.into_iter().enumerate() { match parameter { FunctionParameterDeclaration::Named(_ /* TODO */, decl) => { - let ty = self.parse_type(decl.ty).unwrap(); + let ty = self.parse_type(decl.ty, &[]).unwrap(); let ty = if let Some(array_spec) = decl.ident.array_spec { - let size = self.parse_array_size(array_spec)?; + let size = self.parse_array_size(array_spec, &[])?; self.types.fetch_or_append(Type { name: None, inner: TypeInner::Array { @@ -308,7 +309,7 @@ impl<'a> Parser<'a> { ); } FunctionParameterDeclaration::Unnamed(_, ty) => { - parameter_types.push(self.parse_type(ty).unwrap()); + parameter_types.push(self.parse_type(ty, &[]).unwrap()); } } } @@ -325,6 +326,7 @@ impl<'a> Parser<'a> { &mut local_variables, &mut locals_map, ¶meter_lookup, + ¶meter_types, )?; } _ => unimplemented!(), @@ -336,6 +338,7 @@ impl<'a> Parser<'a> { &mut local_variables, &mut locals_map, ¶meter_lookup, + ¶meter_types, )?); } SimpleStatement::Expression(None) => (), @@ -355,6 +358,7 @@ impl<'a> Parser<'a> { &mut local_variables, &mut locals_map, ¶meter_lookup, + ¶meter_types, ) .unwrap(); expressions.append(expr) @@ -385,13 +389,14 @@ impl<'a> Parser<'a> { locals: &mut Arena, locals_map: &mut FastHashMap>, parameter_lookup: &FastHashMap, + parameter_types: &[Handle], ) -> Result, Error> { let name = init.head.name.map(|d| d.0); let ty = { - let ty = self.parse_type(init.head.ty.ty).unwrap(); + let ty = self.parse_type(init.head.ty.ty, parameter_types).unwrap(); if let Some(array_spec) = init.head.array_specifier { - let size = self.parse_array_size(array_spec)?; + let size = self.parse_array_size(array_spec, parameter_types)?; self.types.fetch_or_append(Type { name: None, inner: TypeInner::Array { @@ -412,6 +417,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?) } else { None @@ -435,6 +441,7 @@ impl<'a> Parser<'a> { locals: &mut Arena, locals_map: &mut FastHashMap>, parameter_lookup: &FastHashMap, + parameter_types: &[Handle], ) -> Result, Error> { match initializer { Initializer::Simple(expr) => { @@ -444,6 +451,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; Ok(expressions.append(handle)) @@ -459,6 +467,7 @@ impl<'a> Parser<'a> { locals: &mut Arena, locals_map: &mut FastHashMap>, parameter_lookup: &FastHashMap, + parameter_types: &[Handle], ) -> Result { match expr { Expr::Assignment(reg, op, value) => { @@ -469,6 +478,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; expressions.append(pointer) }; @@ -479,6 +489,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; let value = match op { AssignmentOp::Equal => right, @@ -555,6 +566,7 @@ impl<'a> Parser<'a> { locals: &mut Arena, locals_map: &mut FastHashMap>, parameter_lookup: &FastHashMap, + parameter_types: &[Handle], ) -> Result { match expr { Expr::Variable(ident) => { @@ -765,8 +777,14 @@ impl<'a> Parser<'a> { }, ))), Expr::Unary(op, reg) => { - let expr = - self.parse_expression(*reg, expressions, locals, locals_map, parameter_lookup)?; + let expr = self.parse_expression( + *reg, + expressions, + locals, + locals_map, + parameter_lookup, + parameter_types, + )?; Ok(Expression::Unary { op: helpers::glsl_to_spirv_unary_op(op), expr: expressions.append(expr), @@ -779,6 +797,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; let right = self.parse_expression( *right, @@ -786,6 +805,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; Ok(Expression::Binary { @@ -828,6 +848,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, ) .unwrap(); expressions.append(expr) @@ -850,6 +871,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?, self.parse_expression( sample_args.remove(0), @@ -857,6 +879,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?, ), _ => unimplemented!(), @@ -871,6 +894,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; Ok(Expression::ImageSample { @@ -893,6 +917,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, ) .unwrap(); expressions.append(expr) @@ -909,6 +934,7 @@ impl<'a> Parser<'a> { locals, locals_map, parameter_lookup, + parameter_types, )?; expressions.append(expr) }; @@ -924,6 +950,7 @@ impl<'a> Parser<'a> { &self.globals, locals, &self.functions, + parameter_types, ) .map_err(|e| Error { kind: e.into() })?; let base_type = &self.types[type_handle]; @@ -977,10 +1004,7 @@ impl<'a> Parser<'a> { crate::TypeInner::Vector { size, kind, width } }; Ok(crate::Expression::Compose { - ty: crate::proc::Typifier::deduce_type_handle( - inner, - &mut self.types, - ), + ty: self.types.fetch_or_append(Type { name: None, inner }), components, }) } else { @@ -1005,7 +1029,11 @@ impl<'a> Parser<'a> { } // None = void - fn parse_type(&mut self, ty: TypeSpecifier) -> Option> { + fn parse_type( + &mut self, + ty: TypeSpecifier, + parameter_types: &[Handle], + ) -> Option> { let base_ty = helpers::glsl_to_spirv_type(ty.ty)?; let ty = if let Some(array_spec) = ty.array_specifier { @@ -1013,7 +1041,7 @@ impl<'a> Parser<'a> { name: None, inner: base_ty, }); - let size = self.parse_array_size(array_spec).unwrap(); + let size = self.parse_array_size(array_spec, parameter_types).unwrap(); TypeInner::Array { base: handle, @@ -1033,10 +1061,10 @@ impl<'a> Parser<'a> { fn parse_global(&mut self, head: SingleDeclaration) -> Result, Error> { let name = head.name.map(|d| d.0); let ty = { - let ty = self.parse_type(head.ty.ty).unwrap(); + let ty = self.parse_type(head.ty.ty, &[]).unwrap(); if let Some(array_spec) = head.array_specifier { - let size = self.parse_array_size(array_spec)?; + let size = self.parse_array_size(array_spec, &[])?; self.types.fetch_or_append(Type { name: None, inner: TypeInner::Array { @@ -1137,7 +1165,11 @@ impl<'a> Parser<'a> { } } - pub fn parse_array_size(&mut self, array_spec: ArraySpecifier) -> Result { + pub fn parse_array_size( + &mut self, + array_spec: ArraySpecifier, + parameter_types: &[Handle], + ) -> Result { let parameter_lookup = FastHashMap::default(); let mut locals = Arena::::new(); let mut locals_map = FastHashMap::default(); @@ -1151,6 +1183,7 @@ impl<'a> Parser<'a> { &mut locals, &mut locals_map, ¶meter_lookup, + parameter_types, )?; let handle = expressions.append(expr); diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index cd57c57677..4e34671b16 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -97,6 +97,7 @@ struct StatementContext<'input, 'temp, 'out> { types: &'out mut Arena, constants: &'out mut Arena, global_vars: &'out Arena, + parameter_types: &'out [Handle], } impl<'a> StatementContext<'a, '_, '_> { @@ -109,6 +110,7 @@ impl<'a> StatementContext<'a, '_, '_> { types: self.types, constants: self.constants, global_vars: self.global_vars, + parameter_types: self.parameter_types, } } @@ -121,6 +123,7 @@ impl<'a> StatementContext<'a, '_, '_> { constants: self.constants, global_vars: self.global_vars, local_vars: self.variables, + parameter_types: self.parameter_types, } } } @@ -133,6 +136,7 @@ struct ExpressionContext<'input, 'temp, 'out> { constants: &'out mut Arena, global_vars: &'out Arena, local_vars: &'out Arena, + parameter_types: &'out [Handle], } impl<'a> ExpressionContext<'a, '_, '_> { @@ -145,6 +149,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { constants: self.constants, global_vars: self.global_vars, local_vars: self.local_vars, + parameter_types: self.parameter_types, } } @@ -160,7 +165,8 @@ impl<'a> ExpressionContext<'a, '_, '_> { self.constants, self.global_vars, self.local_vars, - &Arena::new(), + &Arena::new(), //TODO + self.parameter_types, ) .map_err(Error::InvalidResolve) } @@ -376,7 +382,7 @@ impl Parser { inner } _ => { - let composite_ty = self.parse_type_decl(lexer, type_arena)?; + let composite_ty = self.parse_type_decl(lexer, None, type_arena)?; lexer.expect(Token::Paren('('))?; let mut components = Vec::new(); while !lexer.skip(Token::Paren(')')) { @@ -422,13 +428,13 @@ impl Parser { name: None, specialization: None, inner: crate::ConstantInner::Bool(true), - ty: Typifier::deduce_type_handle( - crate::TypeInner::Scalar { + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, width: 1, }, - ctx.types, - ), + }), }); crate::Expression::Constant(handle) } @@ -437,13 +443,13 @@ impl Parser { name: None, specialization: None, inner: crate::ConstantInner::Bool(false), - ty: Typifier::deduce_type_handle( - crate::TypeInner::Scalar { + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, width: 1, }, - ctx.types, - ), + }), }); crate::Expression::Constant(handle) } @@ -453,10 +459,10 @@ impl Parser { name: None, specialization: None, inner, - ty: Typifier::deduce_type_handle( - crate::TypeInner::Scalar { kind, width: 4 }, - ctx.types, - ), + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width: 4 }, + }), }); crate::Expression::Constant(handle) } @@ -472,7 +478,7 @@ impl Parser { expr } else { *lexer = backup; - let ty = self.parse_type_decl(lexer, ctx.types)?; + let ty = self.parse_type_decl(lexer, None, ctx.types)?; lexer.expect(Token::Paren('('))?; let mut components = Vec::new(); while !lexer.skip(Token::Paren(')')) { @@ -556,7 +562,9 @@ impl Parser { crate::TypeInner::Vector { size, kind, width } }; crate::Expression::Compose { - ty: Typifier::deduce_type_handle(inner, ctx.types), + ty: ctx + .types + .fetch_or_append(crate::Type { name: None, inner }), components, } } else { @@ -823,7 +831,7 @@ impl Parser { ) -> Result<(&'a str, Handle), Error<'a>> { let name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; - let ty = self.parse_type_decl(lexer, type_arena)?; + let ty = self.parse_type_decl(lexer, None, type_arena)?; Ok((name, ty)) } @@ -842,7 +850,7 @@ impl Parser { } let name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; - let ty = self.parse_type_decl(lexer, type_arena)?; + let ty = self.parse_type_decl(lexer, None, type_arena)?; if lexer.skip(Token::Operation('=')) { let _inner = self.parse_const_expression(lexer, type_arena, const_arena)?; //TODO @@ -890,7 +898,7 @@ impl Parser { return Err(Error::MissingMemberOffset(name)); } lexer.expect(Token::Separator(':'))?; - let ty = self.parse_type_decl(lexer, type_arena)?; + let ty = self.parse_type_decl(lexer, None, type_arena)?; lexer.expect(Token::Separator(';'))?; members.push(crate::StructMember { name: Some(name.to_owned()), @@ -903,6 +911,7 @@ impl Parser { fn parse_type_decl<'a>( &mut self, lexer: &mut Lexer<'a>, + self_name: Option<&'a str>, type_arena: &mut Arena, ) -> Result, Error<'a>> { self.scopes.push(Scope::TypeDecl); @@ -1034,13 +1043,13 @@ impl Parser { lexer.expect(Token::Paren('<'))?; let class = Self::get_storage_class(lexer.next_ident()?)?; lexer.expect(Token::Separator(','))?; - let base = self.parse_type_decl(lexer, type_arena)?; + let base = self.parse_type_decl(lexer, None, type_arena)?; lexer.expect(Token::Paren('>'))?; crate::TypeInner::Pointer { base, class } } Token::Word("array") => { lexer.expect(Token::Paren('<'))?; - let base = self.parse_type_decl(lexer, type_arena)?; + let base = self.parse_type_decl(lexer, None, type_arena)?; let size = match lexer.next() { Token::Separator(',') => { let value = lexer.next_uint_literal()?; @@ -1087,7 +1096,10 @@ impl Parser { other => return Err(Error::Unexpected(other)), }; self.scopes.pop(); - Ok(Typifier::deduce_type_handle(inner, type_arena)) + Ok(type_arena.fetch_or_append(crate::Type { + name: self_name.map(|s| s.to_string()), + inner, + })) } fn parse_statement<'a>( @@ -1269,7 +1281,7 @@ impl Parser { let return_type = if lexer.skip(Token::Word("void")) { None } else { - Some(self.parse_type_decl(lexer, &mut module.types)?) + Some(self.parse_type_decl(lexer, None, &mut module.types)?) }; let fun_handle = module.functions.append(crate::Function { @@ -1302,6 +1314,7 @@ impl Parser { types: &mut module.types, constants: &mut module.constants, global_vars: &module.global_variables, + parameter_types: &fun.parameter_types, }, )?; // done @@ -1396,7 +1409,7 @@ impl Parser { Token::Word("type") => { let name = lexer.next_ident()?; lexer.expect(Token::Operation('='))?; - let ty = self.parse_type_decl(lexer, &mut module.types)?; + let ty = self.parse_type_decl(lexer, Some(name), &mut module.types)?; self.lookup_type.insert(name.to_owned(), ty); lexer.expect(Token::Separator(';'))?; } diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index e5c5a40060..7109751be7 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -1,7 +1,4 @@ -use crate::{ - arena::{Arena, Handle}, - Type, TypeInner, VectorSize, -}; +use crate::arena::{Arena, Handle}; use thiserror::Error; @@ -29,29 +26,35 @@ impl Typifier { &mut self, expr_handle: Handle, expressions: &Arena, - types: &mut Arena, + arena: &mut Arena, constants: &Arena, global_vars: &Arena, local_vars: &Arena, functions: &Arena, + parameter_types: &[Handle], ) -> Result, ResolveError> { + #[derive(Debug)] + enum Resolution { + Handle(crate::Handle), + Value(crate::TypeInner), + } + if self.types.len() <= expr_handle.index() { for (eh, expr) in expressions.iter().skip(self.types.len()) { - let ty = match *expr { + let resolution = match *expr { crate::Expression::Access { base, .. } => { - match types[self.types[base.index()]].inner { - crate::TypeInner::Array { base, .. } => base, + match arena[self.types[base.index()]].inner { + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), ref other => panic!("Can't access into {:?}", other), } } crate::Expression::AccessIndex { base, index } => { - match types[self.types[base.index()]].inner { + match arena[self.types[base.index()]].inner { crate::TypeInner::Vector { size, kind, width } => { if index >= size as u32 { return Err(ResolveError::InvalidAccessIndex); } - let inner = crate::TypeInner::Scalar { kind, width }; - Self::deduce_type_handle(inner, types) + Resolution::Value(crate::TypeInner::Scalar { kind, width }) } crate::TypeInner::Matrix { columns, @@ -62,89 +65,92 @@ impl Typifier { if index >= columns as u32 { return Err(ResolveError::InvalidAccessIndex); } - let inner = crate::TypeInner::Vector { + Resolution::Value(crate::TypeInner::Vector { size: rows, kind, width, - }; - Self::deduce_type_handle(inner, types) + }) } - crate::TypeInner::Array { base, .. } => base, + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), crate::TypeInner::Struct { ref members } => { - members + let member = members .get(index as usize) - .ok_or(ResolveError::InvalidAccessIndex)? - .ty + .ok_or(ResolveError::InvalidAccessIndex)?; + Resolution::Handle(member.ty) } ref other => panic!("Can't access into {:?}", other), } } - crate::Expression::Constant(h) => constants[h].ty, - crate::Expression::Compose { ty, .. } => ty, - crate::Expression::FunctionParameter(_) => unimplemented!(), - crate::Expression::GlobalVariable(h) => global_vars[h].ty, - crate::Expression::LocalVariable(h) => local_vars[h].ty, + crate::Expression::Constant(h) => Resolution::Handle(constants[h].ty), + crate::Expression::Compose { ty, .. } => Resolution::Handle(ty), + crate::Expression::FunctionParameter(index) => { + Resolution::Handle(parameter_types[index as usize]) + } + crate::Expression::GlobalVariable(h) => Resolution::Handle(global_vars[h].ty), + crate::Expression::LocalVariable(h) => Resolution::Handle(local_vars[h].ty), crate::Expression::Load { .. } => unimplemented!(), crate::Expression::ImageSample { image, .. } | crate::Expression::ImageLoad { image, .. } => { let image = self.resolve( image, expressions, - types, + arena, constants, global_vars, local_vars, functions, + parameter_types, )?; - let inner = match types[image].inner { - TypeInner::Image { + Resolution::Value(match arena[image].inner { + crate::TypeInner::Image { kind, class: crate::ImageClass::Depth, .. - } => TypeInner::Scalar { kind, width: 4 }, - TypeInner::Image { kind, .. } => TypeInner::Vector { + } => crate::TypeInner::Scalar { kind, width: 4 }, + crate::TypeInner::Image { kind, .. } => crate::TypeInner::Vector { kind, width: 4, - size: VectorSize::Quad, + size: crate::VectorSize::Quad, }, _ => unreachable!(), - }; - - types.fetch_or_append(Type { name: None, inner }) + }) + } + crate::Expression::Unary { expr, .. } => { + Resolution::Handle(self.types[expr.index()]) } - crate::Expression::Unary { expr, .. } => self.types[expr.index()], crate::Expression::Binary { op, left, right } => match op { crate::BinaryOperator::Add | crate::BinaryOperator::Subtract | crate::BinaryOperator::Divide - | crate::BinaryOperator::Modulo => self.types[left.index()], + | crate::BinaryOperator::Modulo => { + Resolution::Handle(self.types[left.index()]) + } crate::BinaryOperator::Multiply => { let ty_left = self.types[left.index()]; let ty_right = self.types[right.index()]; if ty_left == ty_right { - ty_left - } else if let crate::TypeInner::Scalar { .. } = types[ty_right].inner { - ty_left - } else if let crate::TypeInner::Scalar { .. } = types[ty_left].inner { - ty_right + Resolution::Handle(ty_left) + } else if let crate::TypeInner::Scalar { .. } = arena[ty_right].inner { + Resolution::Handle(ty_left) + } else if let crate::TypeInner::Scalar { .. } = arena[ty_left].inner { + Resolution::Handle(ty_right) } else if let crate::TypeInner::Matrix { columns, kind, width, .. - } = types[ty_left].inner + } = arena[ty_left].inner { - let inner = crate::TypeInner::Vector { + Resolution::Value(crate::TypeInner::Vector { size: columns, kind, width, - }; - Self::deduce_type_handle(inner, types) + }) } else { panic!( "Incompatible arguments {:?} x {:?}", - types[ty_left], types[ty_right] + arena[ty_left], arena[ty_right] ); } } @@ -155,60 +161,61 @@ impl Typifier { | crate::BinaryOperator::Greater | crate::BinaryOperator::GreaterEqual | crate::BinaryOperator::LogicalAnd - | crate::BinaryOperator::LogicalOr => self.types[left.index()], + | crate::BinaryOperator::LogicalOr => { + Resolution::Handle(self.types[left.index()]) + } crate::BinaryOperator::And | crate::BinaryOperator::ExclusiveOr | crate::BinaryOperator::InclusiveOr | crate::BinaryOperator::ShiftLeftLogical | crate::BinaryOperator::ShiftRightLogical - | crate::BinaryOperator::ShiftRightArithmetic => self.types[left.index()], + | crate::BinaryOperator::ShiftRightArithmetic => { + Resolution::Handle(self.types[left.index()]) + } }, crate::Expression::Intrinsic { .. } => unimplemented!(), crate::Expression::Transpose(expr) => { let ty_handle = self.types[expr.index()]; - let inner = match types[ty_handle].inner { + match arena[ty_handle].inner { crate::TypeInner::Matrix { columns, rows, kind, width, - } => crate::TypeInner::Matrix { + } => Resolution::Value(crate::TypeInner::Matrix { columns: rows, rows: columns, kind, width, - }, + }), ref other => panic!("incompatible transpose of {:?}", other), - }; - types.fetch_or_append(Type { name: None, inner }) + } } crate::Expression::DotProduct(left_expr, _) => { let left_ty = self.types[left_expr.index()]; - let inner = match types[left_ty].inner { + match arena[left_ty].inner { crate::TypeInner::Vector { kind, size: _, width, - } => crate::TypeInner::Scalar { kind, width }, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), ref other => panic!("incompatible dot of {:?}", other), - }; - types.fetch_or_append(Type { name: None, inner }) + } } crate::Expression::CrossProduct(_, _) => unimplemented!(), crate::Expression::As(expr, kind) => { let ty_handle = self.types[expr.index()]; - let inner = match types[ty_handle].inner { + match arena[ty_handle].inner { crate::TypeInner::Scalar { kind: _, width } => { - crate::TypeInner::Scalar { kind, width } + Resolution::Value(crate::TypeInner::Scalar { kind, width }) } crate::TypeInner::Vector { kind: _, size, width, - } => crate::TypeInner::Vector { kind, size, width }, + } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }), ref other => panic!("incompatible as of {:?}", other), - }; - types.fetch_or_append(Type { name: None, inner }) + } } crate::Expression::Derivative { .. } => unimplemented!(), crate::Expression::Call { @@ -217,42 +224,40 @@ impl Typifier { } => match name.as_str() { "distance" | "length" | "dot" => { let ty_handle = self.types[arguments[0].index()]; - let inner = match types[ty_handle].inner { + match arena[ty_handle].inner { crate::TypeInner::Vector { kind, width, .. } => { - crate::TypeInner::Scalar { kind, width } + Resolution::Value(crate::TypeInner::Scalar { kind, width }) } ref other => panic!("Unexpected argument {:?}", other), - }; - Self::deduce_type_handle(inner, types) + } } "normalize" | "fclamp" | "max" | "reflect" | "pow" | "clamp" | "mix" => { - self.types[arguments[0].index()] + Resolution::Handle(self.types[arguments[0].index()]) } _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }), }, crate::Expression::Call { origin: crate::FunctionOrigin::Local(handle), arguments: _, - } => functions[handle] - .return_type - .ok_or(ResolveError::FunctionReturnsVoid)?, + } => { + let ty = functions[handle] + .return_type + .ok_or(ResolveError::FunctionReturnsVoid)?; + Resolution::Handle(ty) + } }; - log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, ty); - self.types.push(ty); + log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); + self.types.push(match resolution { + Resolution::Handle(h) => h, + Resolution::Value(inner) => arena + .fetch_if_or_append(crate::Type { name: None, inner }, |a, b| { + a.inner == b.inner + }), + }); } } Ok(self.types[expr_handle.index()]) } - - pub fn deduce_type_handle( - inner: crate::TypeInner, - arena: &mut Arena, - ) -> Handle { - if let Some((token, _)) = arena.iter().find(|(_, ty)| ty.inner == inner) { - return token; - } - arena.append(crate::Type { name: None, inner }) - } } #[derive(Clone, Debug, Error)]