mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] use local bindings in triton.cc (#1932)
Another follow up with the relative imports this time dealing with the bindings.
This commit is contained in:
@@ -65,7 +65,7 @@ enum backend_t {
|
||||
|
||||
void init_triton_runtime(py::module &&m) {
|
||||
// wrap backend_t
|
||||
py::enum_<backend_t>(m, "backend")
|
||||
py::enum_<backend_t>(m, "backend", py::module_local())
|
||||
.value("HOST", HOST)
|
||||
.value("CUDA", CUDA)
|
||||
.value("ROCM", ROCM)
|
||||
@@ -164,12 +164,14 @@ void init_triton_ir(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION")
|
||||
py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION",
|
||||
py::module_local())
|
||||
.value("PAD_ZERO", mlir::triton::PaddingOption::PAD_ZERO)
|
||||
.value("PAD_NAN", mlir::triton::PaddingOption::PAD_NAN)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER")
|
||||
py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER",
|
||||
py::module_local())
|
||||
.value("NONE", mlir::triton::CacheModifier::NONE)
|
||||
.value("CA", mlir::triton::CacheModifier::CA)
|
||||
.value("CG", mlir::triton::CacheModifier::CG)
|
||||
@@ -178,20 +180,21 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("WT", mlir::triton::CacheModifier::WT)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
|
||||
py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC", py::module_local())
|
||||
.value("ACQUIRE_RELEASE", mlir::triton::MemSemantic::ACQUIRE_RELEASE)
|
||||
.value("ACQUIRE", mlir::triton::MemSemantic::ACQUIRE)
|
||||
.value("RELEASE", mlir::triton::MemSemantic::RELEASE)
|
||||
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY")
|
||||
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
|
||||
py::module_local())
|
||||
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
|
||||
.value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST)
|
||||
.value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")
|
||||
py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP", py::module_local())
|
||||
.value("ADD", mlir::triton::RMWOp::ADD)
|
||||
.value("FADD", mlir::triton::RMWOp::FADD)
|
||||
.value("AND", mlir::triton::RMWOp::AND)
|
||||
@@ -203,7 +206,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("UMIN", mlir::triton::RMWOp::UMIN)
|
||||
.value("UMAX", mlir::triton::RMWOp::UMAX);
|
||||
|
||||
py::class_<mlir::MLIRContext>(m, "context")
|
||||
py::class_<mlir::MLIRContext>(m, "context", py::module_local())
|
||||
.def(py::init<>())
|
||||
.def("load_triton", [](mlir::MLIRContext &self) {
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
@@ -259,7 +262,7 @@ void init_triton_ir(py::module &&m) {
|
||||
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||
// // .def("get", &ir::undef_value::get, ret::reference);
|
||||
|
||||
py::class_<mlir::Type>(m, "type")
|
||||
py::class_<mlir::Type>(m, "type", py::module_local())
|
||||
.def("is_integer", &mlir::Type::isInteger)
|
||||
.def("is_fp16", &mlir::Type::isF16)
|
||||
.def("__str__", [](mlir::Type &self) {
|
||||
@@ -269,13 +272,13 @@ void init_triton_ir(py::module &&m) {
|
||||
return os.str();
|
||||
});
|
||||
|
||||
py::class_<mlir::FunctionType>(m, "function_type")
|
||||
py::class_<mlir::FunctionType>(m, "function_type", py::module_local())
|
||||
.def("param_types", [](mlir::FunctionType &self) {
|
||||
return std::vector<mlir::Type>(self.getInputs().begin(),
|
||||
self.getInputs().end());
|
||||
});
|
||||
|
||||
py::class_<mlir::Location>(m, "location")
|
||||
py::class_<mlir::Location>(m, "location", py::module_local())
|
||||
.def("__str__", [](mlir::Location &self) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
@@ -283,7 +286,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return os.str();
|
||||
});
|
||||
|
||||
py::class_<mlir::Value>(m, "value")
|
||||
py::class_<mlir::Value>(m, "value", py::module_local())
|
||||
.def("set_attr",
|
||||
[](mlir::Value &self, std::string &name,
|
||||
mlir::Attribute &attr) -> void {
|
||||
@@ -307,14 +310,15 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("get_type", &mlir::Value::getType);
|
||||
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument",
|
||||
py::module_local());
|
||||
|
||||
py::class_<mlir::Region>(m, "region")
|
||||
py::class_<mlir::Region>(m, "region", py::module_local())
|
||||
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
|
||||
.def("size", [](mlir::Region &self) { return self.getBlocks().size(); })
|
||||
.def("empty", &mlir::Region::empty);
|
||||
|
||||
py::class_<mlir::Block>(m, "block")
|
||||
py::class_<mlir::Block>(m, "block", py::module_local())
|
||||
.def("arg",
|
||||
[](mlir::Block &self, int index) -> mlir::BlockArgument {
|
||||
return self.getArgument(index);
|
||||
@@ -383,12 +387,14 @@ void init_triton_ir(py::module &&m) {
|
||||
// .value("retune", eattr::retune)
|
||||
// .value("not_implemented", eattr::not_implemented);
|
||||
|
||||
py::class_<mlir::Attribute>(m, "attribute");
|
||||
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr");
|
||||
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr");
|
||||
py::class_<mlir::Attribute>(m, "attribute", py::module_local());
|
||||
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr",
|
||||
py::module_local());
|
||||
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr",
|
||||
py::module_local());
|
||||
|
||||
// Ops
|
||||
py::class_<mlir::OpState>(m, "OpState")
|
||||
py::class_<mlir::OpState>(m, "OpState", py::module_local())
|
||||
.def("set_attr",
|
||||
[](mlir::OpState &self, std::string &name,
|
||||
mlir::Attribute &attr) -> void { self->setAttr(name, attr); })
|
||||
@@ -427,23 +433,27 @@ void init_triton_ir(py::module &&m) {
|
||||
return mlir::succeeded(mlir::verify(self.getOperation()));
|
||||
});
|
||||
// scf Ops
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp", py::module_local())
|
||||
.def("get_induction_var", &mlir::scf::ForOp::getInductionVar);
|
||||
|
||||
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
|
||||
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp", py::module_local())
|
||||
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
||||
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
||||
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
|
||||
.def("get_else_yield", &mlir::scf::IfOp::elseYield);
|
||||
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
|
||||
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
|
||||
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp",
|
||||
py::module_local());
|
||||
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp",
|
||||
py::module_local())
|
||||
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
|
||||
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp",
|
||||
py::module_local());
|
||||
|
||||
// dynamic_attr is used to transfer ownership of the MLIR context to the
|
||||
// module
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::dynamic_attr())
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::module_local(),
|
||||
py::dynamic_attr())
|
||||
.def("dump", &mlir::ModuleOp::dump)
|
||||
.def("str",
|
||||
[](mlir::ModuleOp &self) -> std::string {
|
||||
@@ -523,7 +533,8 @@ void init_triton_ir(py::module &&m) {
|
||||
},
|
||||
ret::take_ownership);
|
||||
|
||||
py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function")
|
||||
py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function",
|
||||
py::module_local())
|
||||
// .def_property_readonly("attrs", &ir::function::attrs)
|
||||
// .def("add_attr", &ir::function::add_attr);
|
||||
.def("args",
|
||||
@@ -571,9 +582,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
|
||||
.def("reset_type", &mlir::triton::FuncOp::setType);
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint",
|
||||
py::module_local());
|
||||
|
||||
py::class_<TritonOpBuilder>(m, "builder", py::dynamic_attr())
|
||||
py::class_<TritonOpBuilder>(m, "builder", py::module_local(),
|
||||
py::dynamic_attr())
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
// getters
|
||||
.def("create_module",
|
||||
@@ -1507,7 +1520,7 @@ void init_triton_ir(py::module &&m) {
|
||||
offsets);
|
||||
});
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
py::class_<mlir::PassManager>(m, "pass_manager", py::module_local())
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
.def("enable_debug",
|
||||
[](mlir::PassManager &self) {
|
||||
|
||||
Reference in New Issue
Block a user