From 43180a70beba50ba91bfecc4a38ac0dcea19386e Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 19 Jan 2024 15:52:13 +0100 Subject: [PATCH] feat(compiler): Add infrastructure for type inference Type inference is implemented through the two classes `ForwardTypeInferenceAnalysis` and `BackwardTypeInferenceAnalysis`, which can be used as forward and backward dataflow analyses with the MLIR sparse dataflow analysis framework. Both classes rely on a type resolver, which must be a class inheriting `TypeResolver` and that specifies which types are to be considered as unresolved and that resolves the actual types for the values related to an operation based on the previous state of type inference. The type inference state for an operation is represented by an instance of the class `LocalInferenceState`, which maps the values related to an operation to instances of `InferredType` (either indicating the inferred type as an `mlir::Type` or indicating that no type has been inferred, yet). The local type inference by a type resolver can be implemented with type constraints (instances of sub-classes of `TypeConstraint`), which can be combined into a `TypeConstraintSet`. The latter provides a function that attempts to apply the constraints until the resulting type inference state converges. There are multiple, predefined type constraint classes for common constraints (e.g., if two values must have the same type or the same element type). These exist both as static constraints and as dynamic constraints. Some pre-defined type constraints depend on a class that yields a pair of values for which the contraints shall be applied (e.g., yielding two operands or an operand and a result, etc.). --- .../Analysis/TypeInferenceAnalysis.h | 978 ++++++++++++++++++ 1 file changed, 978 insertions(+) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h new file mode 100644 index 000000000..0851766de --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/TypeInferenceAnalysis.h @@ -0,0 +1,978 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +// Infrastructure for type inference +// +// The type inference framework below allows for the implementation of +// custom type inference passes that propagate the types of a +// partially-typed input IR throughout its operations following the +// flow of data. +// +// The framework relies on three main concepts: +// +// - A type resolver inheriting from `TypeResolver` that (1) specifies +// which types are to be considered as unresolved and that (2) +// resolves the actual types for the values related to an operation +// based on a previous state of type inference. +// +// - The propagation of types is implemented through the two classes +// `ForwardTypeInferenceAnalysis` and +// `BackwardTypeInferenceAnalysis`, which can be used as forward and +// backward dataflow analyses with the MLIR sparse dataflow analysis +// framework. These analyses invoke the above-mentioned type +// resolver for an operation every time the inferred type of a value +// referenced by that operation changes. +// +// - The type inference state for an operation is represented by an +// instance of the class `LocalInferenceState`, which maps the +// values related to an operation to instances of `InferredType` +// (either indicating the inferred type as an `mlir::Type` or +// indicating that no type has been inferred, yet). +// +// Additionally, the local rules specifying the relationship between +// the state of inference before invocation of the type resolver and +// the state of inference the invocation can be implemented with type +// constraints (instances of sub-classes of `TypeConstraint`). Type +// constraints can be combined into a `TypeConstraintSet`, which +// provides a method `converge()`, that attempts to apply the +// constraints until the resulting type inference state converges. +// +// There are multiple, predefined type constraint classes for common +// constraints (e.g., if two values must have the same type or the +// same element type). These exist both as static constraints and as +// dynamic constraints. Some pre-defined type constraints depend on a +// class that yields a pair of values for which the contraints shall +// be applied (e.g., yielding two operands or an operand and a result, +// etc.). +// +// The global state of type inference after running the type inference +// analyses is contained in the `DataFlowSolver` to which the analyses +// where added. The local state of inference for an operation can be +// obtained at any stage of the anlysis via the helper methods of +// `TypeInferenceUtils`. + +#ifndef CONCRETELANG_ANALYSIS_TYPEINFERENCEANALYSIS_H +#define CONCRETELANG_ANALYSIS_TYPEINFERENCEANALYSIS_H + +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { +// Encapsulated the state of inference for a single type +class InferredType { +public: + InferredType() : type(std::nullopt) {} + InferredType(std::optional t) : type(t) {} + InferredType(mlir::Type t) : type(t) {} + + mlir::Type getType() const { return type.value(); } + void setType(std::optional t) { type = t; } + bool hasType() const { return type.has_value(); } + + bool operator==(const InferredType &other) const { + return other.type == type; + } + bool operator!=(const InferredType &other) const { + return other.type != type; + } + +protected: + std::optional type; +}; + +inline raw_ostream &operator<<(llvm::raw_ostream &os, InferredType t) { + if (t.hasType()) + os << t.getType(); + else + os << "(none)"; + + return os; +} + +// A local inference state is a mapping of values to inferred types +// used to represent the state of type inference for all values +// related to an operation (e.g., types inferred for operands, +// results, block arguments, etc.) +class LocalInferenceState { +public: + using MapT = llvm::DenseMap; + + LocalInferenceState() : changed(false) {} + + // Updates the inferred type for a value + void set(mlir::Value v, InferredType t) { + if (find(v) != t) + changed = true; + + state[v] = t; + } + + // Looks up the inferred type as an `InferredType` for a value. If + // no type has been inferred or set for the value, an empty inferred + // type is returned. + InferredType find(mlir::Value v) const { + auto it = state.find(v); + + return (it != state.end()) ? it->second : InferredType(std::nullopt); + } + + // Looks up the inferred type as an MLIR type for a value. If no + // type has been inferred or set for the value, an empty type is + // returned. + mlir::Type lookup(mlir::Value v) const { + InferredType t = this->find(v); + + return t.hasType() ? t.getType() : mlir::Type(); + } + + // Looks up the inferred types for an entire range of values. For + // values whose type has not been inferred, empty types are + // returned. + llvm::SmallVector lookup(mlir::ValueRange r) const { + return llvm::to_vector( + llvm::map_range(r, [&](mlir::Value v) { return this->lookup(v); })); + } + + // Resets the change flag + void setUnchanged() { changed = false; } + + // Indicates if the inferred type for any of the values has changed + // or if the inferred type for a new value has been added since the + // creation of the inference set or since the last time the change + // flag was reset. + bool hasChanged() { return changed; } + + // Helper class representing the values for which types have been + // inferred + class ValueSet { + public: + ValueSet(const MapT &map) : map(map) {} + + template class IteratorImpl { + public: + IteratorImpl(MapIteratorT it) : it(it) {} + mlir::Value operator*() { return it->first; } + mlir::Value *operator->() { return &it->first; } + IteratorImpl operator++() { return IteratorImpl(it++); } + bool operator==(const IteratorImpl &other) { return other.it == it; } + bool operator!=(const IteratorImpl &other) { return other.it != it; } + + protected: + MapIteratorT it; + }; + + using const_iterator = IteratorImpl; + + const_iterator cbegin() const { return const_iterator(map.begin()); } + const_iterator cend() const { return const_iterator(map.end()); } + const_iterator begin() const { return const_iterator(map.begin()); } + const_iterator end() const { return const_iterator(map.end()); } + + protected: + const MapT ↦ + }; + + void dump() const { + llvm::dbgs() << "LocalInferenceState:\n"; + + for (auto [v, t] : state) { + llvm::dbgs() << " " << v << ": " << t << "\n"; + } + } + + const ValueSet getValues() const { return ValueSet(state); } + +protected: + MapT state; + bool changed; +}; + +// Lattice value for type inference analysis, encapsulating a single +// inferred type +class TypeInferenceLatticeValue { +public: + TypeInferenceLatticeValue(const InferredType &t = {}) : inferredType(t) {} + + static TypeInferenceLatticeValue getPessimisticValueState(mlir::Value value) { + return TypeInferenceLatticeValue(); + } + + bool operator==(const TypeInferenceLatticeValue &rhs) { + return rhs.getInferredType() == inferredType; + } + + static TypeInferenceLatticeValue join(const TypeInferenceLatticeValue &lhs, + const TypeInferenceLatticeValue &rhs) { + return TypeInferenceLatticeValue(lhs.getInferredType().hasType() + ? lhs.getInferredType() + : rhs.getInferredType()); + } + + static TypeInferenceLatticeValue meet(const TypeInferenceLatticeValue &lhs, + const TypeInferenceLatticeValue &rhs) { + return join(lhs, rhs); + } + + const InferredType &getInferredType() const { return inferredType; } + void setType(const InferredType &t) { this->inferredType = t; } + + void print(raw_ostream &os) const { os << inferredType; } + +protected: + InferredType inferredType; +}; + +// Lattice for forward and backward type inference analysis +class TypeInferenceLattice + : public mlir::dataflow::Lattice { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TypeInferenceLattice) + + using Lattice::Lattice; +}; + +// Collection of common operations for forward and backward type +// inference and subsequent transformations (e.g., rewriting the IR +// with the inferred types). +class TypeInferenceUtils { +public: + // Invokes the callback function `fn` to all values related to the + // operation `op`. Iteration over the related value stops if `fn` + // returns `false`. + // + // Returns `true` if `fn` returned `true` for all related values, + // otherwise `false`. + static bool iterateRelatedValues(mlir::Operation *op, + llvm::function_ref fn) { + auto applyToOperandsAndResults = [&](mlir::Operation *op) { + for (mlir::Value v : op->getOperands()) { + if (!fn(v)) + return false; + } + + for (mlir::Value v : op->getResults()) { + if (!fn(v)) + return false; + } + + return true; + }; + + if (!applyToOperandsAndResults(op)) + return false; + + for (mlir::Region &r : op->getRegions()) { + for (mlir::Value v : r.getArguments()) { + if (!fn(v)) + return false; + } + + if (!op->hasTrait()) { + // Also map return-like terminators, as they are excluded from + // forward and backward analysis and not visited individually + for (mlir::Block &b : r.getBlocks()) { + mlir::Operation *terminator = b.getTerminator(); + + if (terminator->hasTrait()) + if (!applyToOperandsAndResults(terminator)) + return false; + } + } + } + + return true; + } + + // Looks up the inferred types for all values related to the + // operation `op`, i.e., the inferred types for operands, results, + // region arguments and operands of terminators used in directly + // nested regions. + static LocalInferenceState getLocalInferenceState( + mlir::Operation *op, + llvm::function_ref lookup) { + LocalInferenceState map; + + iterateRelatedValues(op, [&](mlir::Value v) { + map.set(v, getInferredType(v, lookup)); + return true; + }); + + return map; + } + + // Retrieves the inferred types for all values related to `op` using + // `DataFlowSolver::lookupState` invoked on `solver`. + static LocalInferenceState + getLocalInferenceState(const DataFlowSolver &solver, mlir::Operation *op) { + LocalInferenceState map = + TypeInferenceUtils::getLocalInferenceState(op, [&](mlir::Value v) { + return solver.lookupState(v); + }); + + return map; + } + + // If `t` is a tensor or memref type, return a tensor or memref type + // with the same shape, but with `elementType` as the element + // type. If `t` is a scalar type, simply return `elementType`. + static mlir::Type applyElementType(mlir::Type elementType, mlir::Type t) { + if (mlir::RankedTensorType rtt = + llvm::dyn_cast(t)) { + return mlir::RankedTensorType::get( + rtt.getShape(), applyElementType(elementType, rtt.getElementType())); + } else if (mlir::MemRefType mrt = llvm::dyn_cast(t)) { + return mlir::MemRefType::get( + mrt.getShape(), applyElementType(elementType, mrt.getElementType())); + } else { + return elementType; + } + } + +protected: + // Looks up the lattice value for `v` using `lookup` and returns the + // result as an instance of `InferredType` + static InferredType getInferredType( + mlir::Value v, + llvm::function_ref lookup) { + const TypeInferenceLattice *latticeOperand = lookup(v); + + return InferredType(latticeOperand->getValue().getInferredType()); + } +}; + +// A type resolver is responsible for the actual inference of the +// types related to an operation based on the previous inference state +class TypeResolver { +public: + virtual ~TypeResolver() {} + + // Resolve the types for all values related to `op` based on the + // previous inference state `prevState` and the current state + // `currState`. + virtual LocalInferenceState resolve(mlir::Operation *op, + const LocalInferenceState &prevState) = 0; + + // Returns true for any type that is considered unresolved by the + // resolver + virtual bool isUnresolvedType(mlir::Type t) const = 0; +}; + +// Base class for all type constraints that can be added to a +// `TypeConstraintSet`. +class TypeConstraint { +public: + virtual ~TypeConstraint() {} + + // Apply the constraint on the types inferred previously (based on + // types inferred during previous visits of related operations or + // the operation itself) and those inferred during the convergence + // of the current visit of the operation. + virtual void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &nextState, + const LocalInferenceState &prevState) = 0; + +protected: + // Returns the type inferred for `v`. This is either the type of `v` + // if `v` has a non-unresolved type, or the type specified in + // `currState`, or the type specified in `prevState`, or + // `std::nullopt` if the type of `v` is unresolved and no type has + // been inferred. + static InferredType getFirstTypeInOrder(TypeResolver &resolver, + const LocalInferenceState &currState, + const LocalInferenceState &prevState, + mlir::Value v) { + if (!resolver.isUnresolvedType(v.getType())) + return v.getType(); + + InferredType sit = currState.find(v); + + if (sit.hasType()) + return sit; + + return prevState.find(v); + } +}; + +// Yields the operands with the static indices i and j of an operation +template class OperandValueYield { +public: + static std::pair yield(mlir::Operation *op) { + return {op->getOperand(i), op->getOperand(j)}; + } +}; + +// Yields the results with the static indices i and j of an +// operation +template class ResultValueYield { +public: + static std::pair yield(mlir::Operation *op) { + return {op->getResult(i), op->getResult(j)}; + } +}; + +// Yields the operand and result with the static indices i and j of an +// operation +template class OperandAndResultValueYield { +public: + static std::pair yield(mlir::Operation *op) { + return {op->getOperand(i), op->getResult(j)}; + } +}; + +// Yields the i-th and j-th operands of an operation, where `i` and +// `j` are defined dynamically +class DynamicOperandValueYield { +public: + DynamicOperandValueYield(int i, int j) : i(i), j(j) {} + + std::pair yield(mlir::Operation *op) { + return {op->getOperand(i), op->getOperand(j)}; + } + +protected: + int i; + int j; +}; + +// Yields the i-th and j-th results of an operation, where `i` and +// `j` are defined dynamically +class DynamicResultValueYield { +public: + DynamicResultValueYield(int i, int j) : i(i), j(j) {} + + std::pair yield(mlir::Operation *op) { + return {op->getResult(i), op->getResult(j)}; + } + +protected: + int i; + int j; +}; + +// Yields the i-th operand and j-th result of an operation, where `i` +// and `j` are defined dynamically +class DynamicOperandAndResultValueYield { +public: + DynamicOperandAndResultValueYield(int i, int j) : i(i), j(j) {} + + std::pair yield(mlir::Operation *op) { + return {op->getOperand(i), op->getResult(j)}; + } + +protected: + int i; + int j; +}; + +// Yield the return values of two callback functions `f̀` and `g` +class DynamicFunctorYield { +public: + DynamicFunctorYield(std::function f, + std::function g) + : f(f), g(g) {} + + std::pair yield(mlir::Operation *op) { + return {f(), g()}; + } + +protected: + std::function f; + std::function g; +}; + +// Empty type constraint that accepts any previously inferred type +class NoTypeConstraint : public TypeConstraint { +public: + void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) override { + for (mlir::Value v : prevState.getValues()) { + InferredType t = getFirstTypeInOrder(resolver, currState, prevState, v); + currState.set(v, t); + } + } +}; + +// Base class for statically-defined type constraints, assuming `YieldT::yield` +// is a static function +template class StaticTypeConstraint : public TypeConstraint { +public: + StaticTypeConstraint() = default; + +protected: + std::pair yieldValues(mlir::Operation *op) { + return YieldT::yield(op); + } +}; + +// Base class for dynamically-defined type constraints, assuming `YieldT::yield` +// is a non-static function +template class DynamicTypeConstraint : public TypeConstraint { +public: + template + DynamicTypeConstraint(YieldTCtorArgTs &&...ctorArgs) + : yield(std::forward(ctorArgs)...) {} + +protected: + std::pair yieldValues(mlir::Operation *op) { + return yield.yield(op); + } + + YieldT yield; +}; + +// Constraint that ensures that the types inferred for two values +// match exactly. If different types have been inferred previously for +// the two values, precedence is given to the type inferred for the +// first value yielded by `BaseConstraintT::yield`. +// +// If the types for the two values are fixed, but contradictory, an +// assertion is triggered. +template +class SameTypeConstraintBase : public BaseConstraintT { +public: + using BaseConstraintT::BaseConstraintT; + + void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) override { + auto [leftValue, rightValue] = this->yieldValues(op); + + mlir::Type leftType = leftValue.getType(); + mlir::Type rightType = rightValue.getType(); + + bool leftUnresolved = resolver.isUnresolvedType(leftType); + bool rightUnresolved = resolver.isUnresolvedType(rightType); + + // Priority is given from left to right, i.e., the type of the + // first value yielded by `yield()` takes precendence over the + // type for the second value + if (!leftUnresolved && rightUnresolved) { + currState.set(leftValue, leftType); + currState.set(rightValue, leftType); + } else if (leftUnresolved && !rightUnresolved) { + currState.set(leftValue, rightType); + currState.set(rightValue, rightType); + } else if (!leftUnresolved && !rightUnresolved) { + assert(leftType == rightType && "Constraint cannot be matched, as " + "operands have different, fixed types"); + currState.set(leftValue, leftType); + currState.set(rightValue, rightType); + } else { + // Both unresolved + InferredType leftCurrType = + this->getFirstTypeInOrder(resolver, currState, prevState, leftValue); + InferredType rightCurrType = + this->getFirstTypeInOrder(resolver, currState, prevState, rightValue); + + if (leftCurrType.hasType()) { + currState.set(leftValue, leftCurrType); + currState.set(rightValue, leftCurrType); + } else if (rightCurrType.hasType()) { + currState.set(leftValue, rightCurrType); + currState.set(rightValue, rightCurrType); + } + } + } +}; + +// Same type constraint for static yield operators +template +using SameTypeConstraint = SameTypeConstraintBase>; + +// Same type constraint for dynamic yield operators +template +using DynamicSameTypeConstraint = + SameTypeConstraintBase>; + +// Constraint that ensures that the types inferred for two values use +// the same element type. If different types have been inferred +// previously for the two values, precedence is given to the element +// type inferred for the first value yielded by +// `BaseConstraintT::yield`. +// +// If the element types for the two values are fixed, but +// contradictory, an assertion is triggered. +template +class SameElementTypeConstraintBase : public BaseConstraintT { +public: + using BaseConstraintT::BaseConstraintT; + + void apply(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) override { + auto [leftValue, rightValue] = this->yieldValues(op); + + mlir::Type leftType = leftValue.getType(); + mlir::Type rightType = rightValue.getType(); + + bool leftUnresolved = resolver.isUnresolvedType(leftType); + bool rightUnresolved = resolver.isUnresolvedType(rightType); + + // Priority is given from left to right, i.e., the type of the + // first value yielded by `yield()` takes precendence over the + // type for the second value + if (!leftUnresolved && rightUnresolved) { + currState.set(leftValue, leftType); + currState.set(rightValue, TypeInferenceUtils::applyElementType( + getElementType(leftType), rightType)); + } else if (leftUnresolved && !rightUnresolved) { + currState.set(leftValue, TypeInferenceUtils::applyElementType( + getElementType(rightType), leftType)); + currState.set(rightValue, rightType); + } else if (!leftUnresolved && !rightUnresolved) { + assert(getElementType(leftType) == getElementType(rightType) && + "Constraint cannot be matched, as fixed types of the values" + "are incompatible types"); + currState.set(leftValue, leftType); + currState.set(rightValue, rightType); + } else { + // Both unresolved + InferredType leftCurrType = + this->getFirstTypeInOrder(resolver, currState, prevState, leftValue); + InferredType rightCurrType = + this->getFirstTypeInOrder(resolver, currState, prevState, rightValue); + + if (leftCurrType.hasType()) { + currState.set(leftValue, leftCurrType); + currState.set(rightValue, + TypeInferenceUtils::applyElementType( + getElementType(leftCurrType.getType()), rightType)); + } else if (rightCurrType.hasType()) { + currState.set(leftValue, + TypeInferenceUtils::applyElementType( + getElementType(rightCurrType.getType()), leftType)); + currState.set(rightValue, rightCurrType); + } + } + } + +protected: + static mlir::Type getElementType(mlir::Type t) { + if (mlir::RankedTensorType rtt = + llvm::dyn_cast(t)) { + return rtt.getElementType(); + } else if (mlir::MemRefType mrt = llvm::dyn_cast(t)) { + return mrt.getElementType(); + } else { + return t; + } + } +}; + +// Same element type constraint for static yield operators +template +using SameElementTypeConstraint = + SameElementTypeConstraintBase>; + +// Same element type constraint for dynamic yield operators +template +using DynamicSameElementTypeConstraint = + SameElementTypeConstraintBase>; + +// Type constraint ensuring that two operands with the statically-defined +// indexes `a` and `b` of an operation have the same type +template +using SameOperandTypeConstraint = SameTypeConstraint>; + +// Type constraint ensuring that two results with the statically-defined +// indexes `a` and `b` of an operation have the same type +template +using SameResultTypeConstraint = SameTypeConstraint>; + +// Type constraint ensuring that the operand with the +// statically-defined index `a` and the result with the +// statically-defined index `b` of an operation have the same type +template +using SameOperandAndResultTypeConstraint = + SameTypeConstraint>; + +// Type constraint ensuring that two operands with the statically-defined +// indexes `a` and `b` of an operation have the same element type +template +using SameOperandElementTypeConstraint = + SameElementTypeConstraint>; + +// Type constraint ensuring that two results with the statically-defined +// indexes `a` and `b` of an operation have the same element type +template +using SameResultElementTypeConstraint = + SameElementTypeConstraint>; + +// Type constraint ensuring that the operand with the +// statically-defined index `a` and the result with the +// statically-defined index `b` of an operation have the same element type +template +using SameOperandAndResultElementTypeConstraint = SameElementTypeConstraint< + OperandAndResultValueYield>; + +namespace { +namespace impl { +// Specialization needs to happen at namespace scope +template +void addConstraint(std::vector> &constraints, + ArgTs &&...args) { + std::unique_ptr constraint = + std::make_unique(std::forward(args)...); + constraints.push_back(std::move(constraint)); +} + +template +void addConstraints(std::vector> &constraints, + ArgTs &&...args) { + addConstraint(constraints, std::forward(args)...); +} + +template +void addConstraints(std::vector> &constraints, + ArgTs &&...args) { + addConstraint(constraints, std::forward(args)...); + addConstraints( + constraints, std::forward(args)...); +} + +}; // namespace impl +}; // namespace + +// Set of type constraints to be applied on one visit of an +// operation. The template parameter `maxApplications` sets a limit to +// the number of rounds the set of constraints is applied in +// `converge()`. +template class TypeConstraintSet { +public: + TypeConstraintSet() {} + + // Instantiates a constraint of the type `ConstraintT` with the + // arguments `args` and adds the constraint to the set + template + void addConstraint(ArgTs &&...args) { + impl::addConstraint(constraints, + std::forward(args)...); + } + + // Instantiates the constraints of types specified by `ConstraintTs` with the + // arguments `args` and adds them to the set + template + void addConstraints(ArgTs &&...args) { + impl::addConstraints( + constraints, std::forward(args)...); + } + + // Applies all type constraints in order of their addition for a + // maximum of `maxApplications` rounds. Initial state needs to be + // provided in `prevState`. The resulting state is contained in + // `currState` afterward th call. + // + // If the rules converge on less or equal to `maxApplications` + // rounds, the function return `true`, otherwise `false`. + bool converge(mlir::Operation *op, TypeResolver &resolver, + LocalInferenceState &currState, + const LocalInferenceState &prevState) { + // Initialize with fixed types + for (mlir::Value v : op->getOperands()) { + if (!resolver.isUnresolvedType(v.getType())) + currState.set(v, v.getType()); + } + + for (mlir::Value v : op->getResults()) { + if (!resolver.isUnresolvedType(v.getType())) + currState.set(v, v.getType()); + } + + // Apply type constraints until state converges or maximum number + // of iterations has been reached + for (int i = 0; i < maxApplications; i++) { + currState.setUnchanged(); + + for (auto &constraint : constraints) { + constraint->apply(op, resolver, currState, prevState); + } + + if (!currState.hasChanged()) + return true; + } + + return !currState.hasChanged(); + } + +protected: + std::vector> constraints; +}; + +// Base class for forward and backward type analysis with shared +// infrastructure. The template parameter `AnalysisT` must either be +// `mlir::dataflow::SparseDataFlowAnalysis` or +// `mlir::dataflow::SparseBackwardDataFlowAnalysis`. +template +class TypeInferenceAnalysisBase : public AnalysisT { +public: + // Constructor passing on `constructorArgs` to the analysis class + template + TypeInferenceAnalysisBase(TypeResolver &resolver, + mlir::DataFlowSolver &solver, + ConstructorArgTs &&...constructorArgs) + : AnalysisT(solver, std::forward(constructorArgs)...), + resolver(resolver) {} + +protected: + // Returns the types inferred for all values related to the + // operation `op` from the respective lattice values. For values, + // for which no type has been inferred so far, a new lattice with an + // empty inferred types is created. + LocalInferenceState getCurrentInferredTypes(mlir::Operation *op) { + LocalInferenceState map = + TypeInferenceUtils::getLocalInferenceState(op, [&](mlir::Value v) { + return this->template getOrCreate(v); + }); + + return map; + } + + // Dumps the inferred type for all values related to `op` on + // `llvm::dbgs()`. + void debugPrintOp(mlir::Operation *op) { + LocalInferenceState inferredTypes = getCurrentInferredTypes(op); + + llvm::dbgs() << "Inference state for " << op->getName() << " \n"; + for (mlir::Value operand : op->getOperands()) { + InferredType operandType = inferredTypes.find(operand); + llvm::dbgs() << " Operand: " << operandType << "\n"; + } + + for (mlir::Value result : op->getResults()) { + InferredType resultType = inferredTypes.find(result); + llvm::dbgs() << " Result: " << resultType << "\n"; + } + + for (mlir::Region &r : op->getRegions()) { + for (mlir::Value v : r.getArguments()) { + InferredType argType = inferredTypes.find(v); + llvm::dbgs() << " RegionArg: " << argType << "\n"; + } + } + } + + // Updates the lattice for the inferred type for the value `v` using + // its inferred type specified by `state`. + void updateLatticeValuesFromState(LocalInferenceState &state, mlir::Value v) { + TypeInferenceLattice *lattice = + this->template getOrCreate(v); + auto latticeValue = lattice->getValue(); + latticeValue.setType(state.find(v)); + mlir::ChangeResult res = lattice->join(latticeValue); + this->propagateIfChanged(lattice, res); + } + + // Visits a single operation: looks up the lattice values for the + // values related to the operation, extracts the inferred types, + // invokes the type resolver and updates the lattice values with the + // newly inferred types. + void doVisitOperation(Operation *op) { + // Skip operations that do not use unresolved types + if (TypeInferenceUtils::iterateRelatedValues(op, [&](mlir::Value v) { + return !resolver.isUnresolvedType(v.getType()); + })) { + return; + } + + // Retrieve the current state of inference from the lattice values + const LocalInferenceState inferredTypes = getCurrentInferredTypes(op); + LocalInferenceState state; + + // Treat special op types for type inference + mlir::TypeSwitch(op) + .Case([&](auto op) { + state.set(op->getOperand(0), op->getResult(0).getType()); + state.set(op->getResult(0), op->getResult(0).getType()); + }) + .template Case([&](auto op) { + state.set(op->getOperand(0), op->getOperand(0).getType()); + state.set(op->getResult(0), op->getOperand(0).getType()); + }) + + // Actually resolve the types based on the current state of + // inference + .Default([&](auto op) { state = resolver.resolve(op, inferredTypes); }); + + // Store the resulting state in the lattice values + TypeInferenceUtils::iterateRelatedValues(op, [&](mlir::Value v) { + updateLatticeValuesFromState(state, v); + return true; + }); + } + + // Redirects the request for the initial inferred type of an + // `mlir::Value` to the type resolver by invoking the resolver with + // the owning operation of the value if it's a block argument. + void doInitializeForValue(mlir::Value v) { + if (mlir::BlockArgument ba = llvm::dyn_cast(v)) { + doVisitOperation(ba.getOwner()->getParentOp()); + } + } + + TypeResolver &resolver; +}; + +// Type inference analysis running forward by following the flow of +// data from sources to sinks. For each operation encountered, the +// type resolver is invoked in order to update types until +// convergence. May be combined with `BackwardTypeInferenceAnalysis` +// for bidirectional type inference. +class ForwardTypeInferenceAnalysis + : public TypeInferenceAnalysisBase< + mlir::dataflow::SparseDataFlowAnalysis> { +public: + ForwardTypeInferenceAnalysis(mlir::DataFlowSolver &solver, + TypeResolver &resolver) + : TypeInferenceAnalysisBase(resolver, solver) {} + + void setToEntryState(TypeInferenceLattice *lattice) override { + TypeInferenceAnalysisBase::doInitializeForValue(lattice->getPoint()); + } + + void + visitOperation(Operation *op, + ArrayRef latticeOperandTypes, + ArrayRef latticeResultTypes) override { + TypeInferenceAnalysisBase::doVisitOperation(op); + } +}; + +// Type inference analysis running backward by tracking back the flow +// of data from sinks to sources. For each operation encountered, the +// type resolver is invoked in order to update types until +// convergence. May be combined with `ForwardTypeInferenceAnalysis` +// for bidirectional type inference. +class BackwardTypeInferenceAnalysis + : public TypeInferenceAnalysisBase< + mlir::dataflow::SparseBackwardDataFlowAnalysis< + TypeInferenceLattice>> { +public: + BackwardTypeInferenceAnalysis(mlir::DataFlowSolver &solver, + mlir::SymbolTableCollection &symTabs, + TypeResolver &resolver) + : TypeInferenceAnalysisBase(resolver, solver, symTabs) {} + + void setToExitState(TypeInferenceLattice *lattice) override { + TypeInferenceAnalysisBase::doInitializeForValue(lattice->getPoint()); + } + + void visitOperation( + Operation *op, ArrayRef latticeOperandTypes, + ArrayRef latticeResultTypes) override { + TypeInferenceAnalysisBase::doVisitOperation(op); + } + + void visitBranchOperand(mlir::OpOperand &operand) override {} +}; + +} // namespace concretelang +} // namespace mlir + +#endif