diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 682ce8c8c6..b852dac79b 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -570,6 +570,40 @@ pub(super) fn instruction_composite_construct( instruction } +pub(super) fn instruction_composite_extract( + result_type_id: Word, + id: Word, + composite_id: Word, + indices: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::CompositeExtract); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(composite_id); + for index in indices { + instruction.add_operand(*index); + } + + instruction +} + +pub(super) fn instruction_vector_extract_dynamic( + result_type_id: Word, + id: Word, + vector_id: Word, + index_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::VectorExtractDynamic); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(vector_id); + instruction.add_operand(index_id); + + instruction +} + // // Arithmetic Instructions // diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 2b1d2c25db..91ad75edb4 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1060,11 +1060,10 @@ impl Writer { let (raw_expression, lookup_ty) = self.write_expression_raw(ir_module, ir_function, handle, block, function)?; Ok(match raw_expression { - RawExpression::Value(_id) => { - //TODO: create a local variable? - log::error!("Pointer expression {:?}", ir_function.expressions[handle]); - return Err(Error::FeatureNotImplemented("getting pointer of a value")); - } + RawExpression::Value(_id) => unimplemented!( + "Expression {:?} is not a pointer", + ir_function.expressions[handle] + ), RawExpression::Pointer(id, class) => (id, lookup_ty, class), }) } @@ -1081,78 +1080,89 @@ impl Writer { match ir_function.expressions[expr_handle] { crate::Expression::Access { base, index } => { let id = self.generate_id(); - - let (base_id, base_lookup_ty, class) = - self.write_expression_pointer(ir_module, ir_function, base, block, function)?; + let (raw_base_expression, base_lookup_ty) = + self.write_expression_raw(ir_module, ir_function, base, block, function)?; + let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); let (index_id, _) = self.write_expression(ir_module, ir_function, index, block, function)?; - let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); - let (pointer_type_id, lookup_ty) = match *base_ty_inner { - crate::TypeInner::Array { base, .. } => { - self.create_pointer_type(LookupType::Handle(base), class, &ir_module.types)? - } - crate::TypeInner::Vector { kind, width, .. } => self.create_pointer_type( - LocalType::Scalar { kind, width }.into(), - class, - &ir_module.types, - )?, - ref other => { - log::error!("Unable to index {:?}", other); - return Err(Error::FeatureNotImplemented("accessing of non-vector")); - } - }; - - block - .body - .push(super::instructions::instruction_access_chain( - pointer_type_id, - id, - base_id, - &[index_id], - )); - - Ok((RawExpression::Pointer(id, class), lookup_ty)) - } - crate::Expression::AccessIndex { base, index } => { - let id = self.generate_id(); - let (base_id, base_lookup_ty, class) = - self.write_expression_pointer(ir_module, ir_function, base, block, function)?; - - let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); - let (pointer_type_id, lookup_ty) = match *base_ty_inner { + let lookup_ty = match *base_ty_inner { crate::TypeInner::Vector { size: _, kind, width, - } => self.create_pointer_type( - LocalType::Scalar { kind, width }.into(), - class, - &ir_module.types, - )?, + } => LookupType::Local(LocalType::Scalar { kind, width }), + crate::TypeInner::Array { base, .. } => LookupType::Handle(base), + ref other => { + log::error!("Unable to index {:?}", other); + return Err(Error::FeatureNotImplemented( + "accessing index of non vector or array", + )); + } + }; + + Ok(match raw_base_expression { + RawExpression::Value(base_id) => { + if let crate::TypeInner::Array { .. } = *base_ty_inner { + return Err(Error::FeatureNotImplemented( + "accessing index of a value array", + )); + } + + let result_type_id = self.get_type_id(&ir_module.types, lookup_ty)?; + block + .body + .push(super::instructions::instruction_vector_extract_dynamic( + result_type_id, + id, + base_id, + index_id, + )); + + (RawExpression::Value(id), lookup_ty) + } + RawExpression::Pointer(base_id, class) => { + let (pointer_type_id, pointer_lookup_ty) = + self.create_pointer_type(lookup_ty, class, &ir_module.types)?; + + block + .body + .push(super::instructions::instruction_access_chain( + pointer_type_id, + id, + base_id, + &[index_id], + )); + + (RawExpression::Pointer(id, class), pointer_lookup_ty) + } + }) + } + crate::Expression::AccessIndex { base, index } => { + let id = self.generate_id(); + let (raw_base_expression, base_lookup_ty) = + self.write_expression_raw(ir_module, ir_function, base, block, function)?; + let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); + + let lookup_ty = match *base_ty_inner { + crate::TypeInner::Vector { + size: _, + kind, + width, + } => LookupType::Local(LocalType::Scalar { kind, width }), crate::TypeInner::Matrix { columns: _, rows, width, - } => { - let local_type = LocalType::Vector { - size: rows, - kind: crate::ScalarKind::Float, - width, - }; - self.create_pointer_type(local_type.into(), class, &ir_module.types)? - } + } => LookupType::Local(LocalType::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }), crate::TypeInner::Struct { block: _, ref members, - } => { - let member = &members[index as usize]; - self.create_pointer_type( - LookupType::Handle(member.ty), - class, - &ir_module.types, - )? - } + } => LookupType::Handle(members[index as usize].ty), ref other => { log::error!("Unable to access index {:?}", other); return Err(Error::FeatureNotImplemented( @@ -1161,25 +1171,44 @@ impl Writer { } }; - let const_ty_id = self.get_type_id( - &ir_module.types, - LookupType::Local(LocalType::Scalar { - kind: crate::ScalarKind::Sint, - width: 4, - }), - )?; - let const_id = self.create_constant(const_ty_id, &[index]); + Ok(match raw_base_expression { + RawExpression::Value(base_id) => { + let result_type_id = self.get_type_id(&ir_module.types, lookup_ty)?; + block + .body + .push(super::instructions::instruction_composite_extract( + result_type_id, + id, + base_id, + &[index], + )); - block - .body - .push(super::instructions::instruction_access_chain( - pointer_type_id, - id, - base_id, - &[const_id], - )); + (RawExpression::Value(id), lookup_ty) + } + RawExpression::Pointer(base_id, class) => { + let const_ty_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }), + )?; + let const_id = self.create_constant(const_ty_id, &[index]); + let (pointer_type_id, pointer_lookup_ty) = + self.create_pointer_type(lookup_ty, class, &ir_module.types)?; - Ok((RawExpression::Pointer(id, class), lookup_ty)) + block + .body + .push(super::instructions::instruction_access_chain( + pointer_type_id, + id, + base_id, + &[const_id], + )); + + (RawExpression::Pointer(id, class), pointer_lookup_ty) + } + }) } crate::Expression::GlobalVariable(handle) => { let var = &ir_module.global_variables[handle]; diff --git a/tests/snapshots/snapshots__skybox.spvasm.snap b/tests/snapshots/snapshots__skybox.spvasm.snap index 3e5289334a..a6a27f59f1 100644 --- a/tests/snapshots/snapshots__skybox.spvasm.snap +++ b/tests/snapshots/snapshots__skybox.spvasm.snap @@ -46,8 +46,8 @@ OpDecorate %180 Location 0 %26 = OpTypeStruct %27 %27 %28 = OpTypePointer Uniform %26 %25 = OpVariable %28 Uniform -%29 = OpTypePointer Uniform %27 -%30 = OpConstant %3 0 +%29 = OpConstant %3 0 +%30 = OpTypePointer Uniform %27 %36 = OpConstant %8 4.0 %37 = OpConstant %8 1.0 %42 = OpConstant %8 0.0 @@ -55,66 +55,66 @@ OpDecorate %180 Location 0 %46 = OpTypePointer Output %45 %44 = OpVariable %46 Output %48 = OpTypeMatrix %45 3 -%52 = OpTypePointer Uniform %27 -%53 = OpConstant %3 1 -%54 = OpTypePointer Uniform %7 -%55 = OpConstant %3 0 -%56 = OpTypePointer Uniform %8 -%57 = OpConstant %3 0 -%62 = OpTypePointer Uniform %27 -%63 = OpConstant %3 1 -%64 = OpTypePointer Uniform %7 -%65 = OpConstant %3 0 -%66 = OpTypePointer Uniform %8 -%67 = OpConstant %3 1 -%72 = OpTypePointer Uniform %27 -%73 = OpConstant %3 1 -%74 = OpTypePointer Uniform %7 -%75 = OpConstant %3 0 -%76 = OpTypePointer Uniform %8 -%77 = OpConstant %3 2 -%83 = OpTypePointer Uniform %27 -%84 = OpConstant %3 1 -%85 = OpTypePointer Uniform %7 -%86 = OpConstant %3 1 -%87 = OpTypePointer Uniform %8 -%88 = OpConstant %3 0 -%93 = OpTypePointer Uniform %27 -%94 = OpConstant %3 1 -%95 = OpTypePointer Uniform %7 -%96 = OpConstant %3 1 -%97 = OpTypePointer Uniform %8 -%98 = OpConstant %3 1 -%103 = OpTypePointer Uniform %27 -%104 = OpConstant %3 1 -%105 = OpTypePointer Uniform %7 -%106 = OpConstant %3 1 -%107 = OpTypePointer Uniform %8 -%108 = OpConstant %3 2 -%114 = OpTypePointer Uniform %27 -%115 = OpConstant %3 1 -%116 = OpTypePointer Uniform %7 -%117 = OpConstant %3 2 -%118 = OpTypePointer Uniform %8 -%119 = OpConstant %3 0 -%124 = OpTypePointer Uniform %27 -%125 = OpConstant %3 1 -%126 = OpTypePointer Uniform %7 -%127 = OpConstant %3 2 -%128 = OpTypePointer Uniform %8 -%129 = OpConstant %3 1 -%134 = OpTypePointer Uniform %27 -%135 = OpConstant %3 1 -%136 = OpTypePointer Uniform %7 -%137 = OpConstant %3 2 -%138 = OpTypePointer Uniform %8 -%139 = OpConstant %3 2 -%145 = OpTypePointer Function %8 -%146 = OpConstant %3 0 -%149 = OpTypePointer Function %8 -%150 = OpConstant %3 1 -%153 = OpTypePointer Function %8 -%154 = OpConstant %3 2 +%52 = OpConstant %3 1 +%53 = OpTypePointer Uniform %27 +%54 = OpConstant %3 0 +%55 = OpTypePointer Uniform %7 +%56 = OpConstant %3 0 +%57 = OpTypePointer Uniform %8 +%62 = OpConstant %3 1 +%63 = OpTypePointer Uniform %27 +%64 = OpConstant %3 0 +%65 = OpTypePointer Uniform %7 +%66 = OpConstant %3 1 +%67 = OpTypePointer Uniform %8 +%72 = OpConstant %3 1 +%73 = OpTypePointer Uniform %27 +%74 = OpConstant %3 0 +%75 = OpTypePointer Uniform %7 +%76 = OpConstant %3 2 +%77 = OpTypePointer Uniform %8 +%83 = OpConstant %3 1 +%84 = OpTypePointer Uniform %27 +%85 = OpConstant %3 1 +%86 = OpTypePointer Uniform %7 +%87 = OpConstant %3 0 +%88 = OpTypePointer Uniform %8 +%93 = OpConstant %3 1 +%94 = OpTypePointer Uniform %27 +%95 = OpConstant %3 1 +%96 = OpTypePointer Uniform %7 +%97 = OpConstant %3 1 +%98 = OpTypePointer Uniform %8 +%103 = OpConstant %3 1 +%104 = OpTypePointer Uniform %27 +%105 = OpConstant %3 1 +%106 = OpTypePointer Uniform %7 +%107 = OpConstant %3 2 +%108 = OpTypePointer Uniform %8 +%114 = OpConstant %3 1 +%115 = OpTypePointer Uniform %27 +%116 = OpConstant %3 2 +%117 = OpTypePointer Uniform %7 +%118 = OpConstant %3 0 +%119 = OpTypePointer Uniform %8 +%124 = OpConstant %3 1 +%125 = OpTypePointer Uniform %27 +%126 = OpConstant %3 2 +%127 = OpTypePointer Uniform %7 +%128 = OpConstant %3 1 +%129 = OpTypePointer Uniform %8 +%134 = OpConstant %3 1 +%135 = OpTypePointer Uniform %27 +%136 = OpConstant %3 2 +%137 = OpTypePointer Uniform %7 +%138 = OpConstant %3 2 +%139 = OpTypePointer Uniform %8 +%145 = OpConstant %3 0 +%146 = OpTypePointer Function %8 +%149 = OpConstant %3 1 +%150 = OpTypePointer Function %8 +%153 = OpConstant %3 2 +%154 = OpTypePointer Function %8 %158 = OpTypePointer Output %7 %157 = OpVariable %158 Output %170 = OpVariable %158 Output @@ -138,7 +138,7 @@ OpStore %2 %14 %21 = OpLoad %16 %15 %20 = OpBitwiseAnd %3 %21 %22 OpStore %5 %20 -%24 = OpAccessChain %29 %25 %30 +%24 = OpAccessChain %30 %25 %29 %31 = OpLoad %27 %24 %34 = OpLoad %3 %2 %35 = OpConvertSToF %8 %34 @@ -151,52 +151,52 @@ OpStore %5 %20 %43 = OpCompositeConstruct %7 %32 %38 %42 %37 %23 = OpMatrixTimesVector %7 %31 %43 OpStore %6 %23 -%51 = OpAccessChain %52 %25 %53 -%50 = OpAccessChain %54 %51 %55 -%49 = OpAccessChain %56 %50 %57 +%51 = OpAccessChain %53 %25 %52 +%50 = OpAccessChain %55 %51 %54 +%49 = OpAccessChain %57 %50 %56 %58 = OpLoad %8 %49 -%61 = OpAccessChain %62 %25 %63 -%60 = OpAccessChain %64 %61 %65 -%59 = OpAccessChain %66 %60 %67 +%61 = OpAccessChain %63 %25 %62 +%60 = OpAccessChain %65 %61 %64 +%59 = OpAccessChain %67 %60 %66 %68 = OpLoad %8 %59 -%71 = OpAccessChain %72 %25 %73 -%70 = OpAccessChain %74 %71 %75 -%69 = OpAccessChain %76 %70 %77 +%71 = OpAccessChain %73 %25 %72 +%70 = OpAccessChain %75 %71 %74 +%69 = OpAccessChain %77 %70 %76 %78 = OpLoad %8 %69 %79 = OpCompositeConstruct %45 %58 %68 %78 -%82 = OpAccessChain %83 %25 %84 -%81 = OpAccessChain %85 %82 %86 -%80 = OpAccessChain %87 %81 %88 +%82 = OpAccessChain %84 %25 %83 +%81 = OpAccessChain %86 %82 %85 +%80 = OpAccessChain %88 %81 %87 %89 = OpLoad %8 %80 -%92 = OpAccessChain %93 %25 %94 -%91 = OpAccessChain %95 %92 %96 -%90 = OpAccessChain %97 %91 %98 +%92 = OpAccessChain %94 %25 %93 +%91 = OpAccessChain %96 %92 %95 +%90 = OpAccessChain %98 %91 %97 %99 = OpLoad %8 %90 -%102 = OpAccessChain %103 %25 %104 -%101 = OpAccessChain %105 %102 %106 -%100 = OpAccessChain %107 %101 %108 +%102 = OpAccessChain %104 %25 %103 +%101 = OpAccessChain %106 %102 %105 +%100 = OpAccessChain %108 %101 %107 %109 = OpLoad %8 %100 %110 = OpCompositeConstruct %45 %89 %99 %109 -%113 = OpAccessChain %114 %25 %115 -%112 = OpAccessChain %116 %113 %117 -%111 = OpAccessChain %118 %112 %119 +%113 = OpAccessChain %115 %25 %114 +%112 = OpAccessChain %117 %113 %116 +%111 = OpAccessChain %119 %112 %118 %120 = OpLoad %8 %111 -%123 = OpAccessChain %124 %25 %125 -%122 = OpAccessChain %126 %123 %127 -%121 = OpAccessChain %128 %122 %129 +%123 = OpAccessChain %125 %25 %124 +%122 = OpAccessChain %127 %123 %126 +%121 = OpAccessChain %129 %122 %128 %130 = OpLoad %8 %121 -%133 = OpAccessChain %134 %25 %135 -%132 = OpAccessChain %136 %133 %137 -%131 = OpAccessChain %138 %132 %139 +%133 = OpAccessChain %135 %25 %134 +%132 = OpAccessChain %137 %133 %136 +%131 = OpAccessChain %139 %132 %138 %140 = OpLoad %8 %131 %141 = OpCompositeConstruct %45 %120 %130 %140 %142 = OpCompositeConstruct %48 %79 %110 %141 %143 = OpTranspose %48 %142 -%144 = OpAccessChain %145 %6 %146 +%144 = OpAccessChain %146 %6 %145 %147 = OpLoad %8 %144 -%148 = OpAccessChain %149 %6 %150 +%148 = OpAccessChain %150 %6 %149 %151 = OpLoad %8 %148 -%152 = OpAccessChain %153 %6 %154 +%152 = OpAccessChain %154 %6 %153 %155 = OpLoad %8 %152 %156 = OpCompositeConstruct %45 %147 %151 %155 %47 = OpMatrixTimesVector %45 %143 %156