From 902915a57b7fa3cf031a9ff73ffa7805adf12394 Mon Sep 17 00:00:00 2001 From: jackfiled Date: Tue, 3 Jun 2025 16:03:17 +0800 Subject: [PATCH] feat: toy tutorial chapter 4. Signed-off-by: jackfiled --- CMakeLists.txt | 15 +- examples/multiply_transpose.hello | 8 +- include/Dialect.h | 2 + include/Passes.h | 20 +++ include/hello/CMakeLists.txt | 5 + include/hello/Ops.td | 35 +++- include/hello/ShapeInferenceInterface.h | 18 ++ include/hello/ShapeInferenceInterface.td | 18 ++ lib/Dialect.cpp | 214 +++++++++++++++++------ lib/MLIRGen.cpp | 9 + lib/ShapeInferencePass.cpp | 95 ++++++++++ main.cpp | 9 +- 12 files changed, 380 insertions(+), 68 deletions(-) create mode 100644 include/Passes.h create mode 100644 include/hello/ShapeInferenceInterface.h create mode 100644 include/hello/ShapeInferenceInterface.td create mode 100644 lib/ShapeInferencePass.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 54e1361..754304f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,23 +35,30 @@ mlir_tablegen(HelloCombine.inc -gen-rewriters) include_directories(${CMAKE_BINARY_DIR}) add_public_tablegen_target(HelloCombineIncGen) -add_library(SyntaxNode STATIC +add_library(HelloDialect STATIC lib/SyntaxNode.cpp lib/Dialect.cpp lib/MLIRGen.cpp lib/HelloCombine.cpp + lib/ShapeInferencePass.cpp + include/SyntaxNode.h include/Parser.h include/Lexer.h + include/Dialect.h + include/MLIRGen.h + include/Passes.h ) -add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen) +add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen) -target_link_libraries(SyntaxNode +target_link_libraries(HelloDialect PRIVATE MLIRSupport MLIRAnalysis MLIRFunctionInterfaces + MLIRCallInterfaces + MLIRCastInterfaces MLIRIR MLIRParser MLIRSideEffectInterfaces @@ -61,6 +68,6 @@ add_executable(hello-mlir main.cpp) target_link_libraries(hello-mlir PRIVATE - SyntaxNode + HelloDialect LLVMSupport LLVMCore) \ No newline at end of file diff --git a/examples/multiply_transpose.hello b/examples/multiply_transpose.hello index e174055..9d39162 100644 --- a/examples/multiply_transpose.hello +++ b/examples/multiply_transpose.hello @@ -16,11 +16,5 @@ def main() { # reuse the previously specialized and inferred version and return <3, 2>. var d = multiply_transpose(b, a); - # A new call with <3, 2> (instead of <2, 3>) for both dimensions will - # trigger another specialization of `multiply_transpose`. - var e = multiply_transpose(c, d); - - # Finally, calling into `multiply_transpose` with incompatible shapes - # (<2, 3> and <3, 2>) will trigger a shape inference error. - var f = multiply_transpose(a, c); + print(d); } \ No newline at end of file diff --git a/include/Dialect.h b/include/Dialect.h index 7a8924b..ca7ebfa 100644 --- a/include/Dialect.h +++ b/include/Dialect.h @@ -9,8 +9,10 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "hello/ShapeInferenceInterface.h" /// Include the auto-generated header file containing the declaration of the toy /// dialect. diff --git a/include/Passes.h b/include/Passes.h new file mode 100644 index 0000000..a92cbdb --- /dev/null +++ b/include/Passes.h @@ -0,0 +1,20 @@ +// +// Created by ricardo on 02/06/25. 
+//
+
+#ifndef PASSES_H
+#define PASSES_H
+
+#include <memory>
+
+namespace mlir
+{
+    class Pass;
+
+    namespace hello
+    {
+        std::unique_ptr<Pass> createShapeInferencePass();
+    }
+}
+
+#endif //PASSES_H
diff --git a/include/hello/CMakeLists.txt b/include/hello/CMakeLists.txt
index f53458d..764fbdb 100644
--- a/include/hello/CMakeLists.txt
+++ b/include/hello/CMakeLists.txt
@@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
 mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
 mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
 add_public_tablegen_target(HelloOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(HelloInterfaceIncGen)
diff --git a/include/hello/Ops.td b/include/hello/Ops.td
index 25b38cc..77b3dc5 100644
--- a/include/hello/Ops.td
+++ b/include/hello/Ops.td
@@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/CastInterfaces.td"
+include "hello/ShapeInferenceInterface.td"
 
 def Hello_Dialect : Dialect {
   let name = "hello";
   let cppNamespace = "::mlir::hello";
 }
 
+
+
 class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
 
@@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
 // AddOp
 //===----------------------------------------------------------------------===//
 
-def AddOp : Hello_Op<"add"> {
+def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "element-wise addition operation";
   let description = [{
     The "add" operation performs element-wise addition between two tensors.
 
@@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
 // GenericCallOp
 //===----------------------------------------------------------------------===//
 
-def GenericCallOp : Hello_Op<"generic_call"> {
+def GenericCallOp : Hello_Op<"generic_call",
+    [DeclareOpInterfaceMethods<CallOpInterface>]> {
   let summary = "generic call operation";
   let description = [{
     Generic calls represent calls to a user defined function that needs to
 
@@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
 // MulOp
 //===----------------------------------------------------------------------===//
 
-def MulOp : Hello_Op<"mul"> {
+def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "element-wise multiplication operation";
   let description = [{
     The "mul" operation performs element-wise multiplication between two
 
@@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
-def TransposeOp : Hello_Op<"transpose", [Pure]> {
+def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "transpose operation";
 
   let arguments = (ins F64Tensor:$input);
 
@@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
   let hasCanonicalizer = 1;
 }
 
+def CastOp : Hello_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+    Pure,
+    SameOperandsAndResultShape
+  ]> {
+  let summary = "shape cast operation";
+
+  let description = [{
+    The "cast" operation converts a tensor from one type to an equivalent type
+    without changing any data elements. The source and destination types
+    must both be tensor types with the same element type. If both are ranked,
+    then the shapes are required to match.
+    The operation is invalid if converting to a mismatching constant dimension.
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+}
+
 #endif
diff --git a/include/hello/ShapeInferenceInterface.h b/include/hello/ShapeInferenceInterface.h
new file mode 100644
index 0000000..b026734
--- /dev/null
+++ b/include/hello/ShapeInferenceInterface.h
@@ -0,0 +1,18 @@
+//
+// Created by ricardo on 02/06/25.
+//
+
+#ifndef SHAPEINFERENCEINTERFACE_H
+#define SHAPEINFERENCEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir
+{
+    namespace hello
+    {
+#include "hello/ShapeInferenceInterface.h.inc"
+    }
+}
+
+#endif //SHAPEINFERENCEINTERFACE_H
diff --git a/include/hello/ShapeInferenceInterface.td b/include/hello/ShapeInferenceInterface.td
new file mode 100644
index 0000000..1ec8a3b
--- /dev/null
+++ b/include/hello/ShapeInferenceInterface.td
@@ -0,0 +1,18 @@
+#ifndef SHAPE_INFERENCE_INTERFACE
+#define SHAPE_INFERENCE_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let description = [{
+    Interface to access a registered method to infer the return types for an
+    operation that can be used during type inference.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Infer and set the output shape for the current operation.",
+                    "void", "inferShapes">
+  ];
+}
+
+#endif
diff --git a/lib/Dialect.cpp b/lib/Dialect.cpp
index ce7c2c7..83e8875 100644
--- a/lib/Dialect.cpp
+++ b/lib/Dialect.cpp
@@ -6,6 +6,49 @@
 #include "hello/Dialect.cpp.inc"
 
 #include <mlir/Interfaces/FunctionImplementation.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/Transforms/InliningUtils.h>
+
+struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
+{
+    using DialectInlinerInterface::DialectInlinerInterface;
+
+    bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
+    {
+        return true;
+    }
+
+    bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
+                         mlir::IRMapping& valueMapping) const override
+    {
+        return true;
+    }
+
+    bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
+                         mlir::IRMapping& valueMapping) const override
+    {
+        return true;
+    }
+
+    void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
+    {
+        // `hello.return` is the only terminator of hello functions.
+        auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);
+
+        assert(returnOp.getNumOperands() == returnValues.size());
+
+        for (const auto& it : llvm::enumerate(returnOp.getOperands()))
+        {
+            returnValues[it.index()].replaceAllUsesWith(it.value());
+        }
+    }
+
+    mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
+                                               mlir::Location conversionLoc) const override
+    {
+        return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
+    }
+};
 
 void mlir::hello::HelloDialect::initialize()
 {
@@ -13,6 +56,7 @@ void mlir::hello::HelloDialect::initialize()
 #define GET_OP_LIST
 #include "hello/Ops.cpp.inc"
         >();
+    addInterfaces<HelloDialectInlinerInterface>();
 }
 
 using namespace mlir;
 using namespace mlir::hello;
 
@@ -25,43 +69,43 @@
 static ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result)
 {
     SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-    llvm::SMLoc operandsLoc = parser.getCurrentLocation();
-    mlir::Type type;
+    SMLoc operandsLoc = parser.getCurrentLocation();
+    Type type;
     if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
         parser.parseOptionalAttrDict(result.attributes) ||
         parser.parseColonType(type))
-        return mlir::failure();
+        return failure();
 
     // If the type is a function type, it contains the input and result types of
     // this operation.
-    if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type))
+    if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
     {
         if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                    result.operands))
-            return mlir::failure();
+            return failure();
         result.addTypes(funcType.getResults());
-        return mlir::success();
+        return success();
     }
 
     // Otherwise, the parsed type is the type of both operands and results.
     if (parser.resolveOperands(operands, type, result.operands))
-        return mlir::failure();
+        return failure();
     result.addTypes(type);
-    return mlir::success();
+    return success();
 }
 
 /// A generalized printer for binary operations. It prints in two different
 /// forms depending on if all of the types match.
-static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
+static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
 {
     printer << " " << op->getOperands();
     printer.printOptionalAttrDict(op->getAttrs());
     printer << " : ";
 
     // If all of the types are the same, print the type directly.
-    mlir::Type resultType = *op->result_type_begin();
+    Type resultType = *op->result_type_begin();
     if (llvm::all_of(op->getOperandTypes(),
-                     [=](mlir::Type type) { return type == resultType; }))
+                     [=](Type type) { return type == resultType; }))
     {
         printer << resultType;
         return;
     }
 
@@ -78,8 +122,8 @@ static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
 /// Build a constant operation.
 /// The builder is passed as an argument, so is the state that this method is
 /// expected to fill in order to build the operation.
-void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
-                                    double value)
+void ConstantOp::build(OpBuilder& builder, OperationState& state,
+                       double value)
 {
     auto dataType = RankedTensorType::get({}, builder.getF64Type());
     auto dataAttribute = DenseElementsAttr::get(dataType, value);
 
@@ -93,10 +137,10 @@ void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
 /// or `false` on success. This allows for easily chaining together a set of
 /// parser rules. These rules are used to populate an `mlir::OperationState`
 /// similarly to the `build` methods described above.
-mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
-                                    mlir::OperationState& result)
+ParseResult ConstantOp::parse(OpAsmParser& parser,
+                              OperationState& result)
 {
-    mlir::DenseElementsAttr value;
+    DenseElementsAttr value;
     if (parser.parseOptionalAttrDict(result.attributes) ||
         parser.parseAttribute(value, "value", result.attributes))
         return failure();
 
@@ -107,7 +151,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
 /// The 'OpAsmPrinter' class is a stream that allows for formatting
 /// strings, attributes, operands, types, etc.
-void ConstantOp::print(mlir::OpAsmPrinter& printer)
+void ConstantOp::print(OpAsmPrinter& printer)
 {
     printer << " ";
     printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
 
@@ -116,17 +160,17 @@
 /// Verifier for the constant operation. This corresponds to the
 /// `let hasVerifier = 1` in the op definition.
-mlir::LogicalResult ConstantOp::verify()
+LogicalResult ConstantOp::verify()
 {
     // If the return type of the constant is not an unranked tensor, the shape
     // must match the shape of the attribute holding the data.
-    auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
+    auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
     if (!resultType)
         return success();
 
     // Check that the rank of the attribute type matches the rank of the constant
     // result type.
-    auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
+    auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
     if (attrType.getRank() != resultType.getRank())
     {
         return emitOpError("return type must match the one of the attached value "
@@ -145,57 +189,82 @@ mlir::LogicalResult ConstantOp::verify()
                << " != " << resultType.getShape()[dim];
         }
     }
-    return mlir::success();
+    return success();
 }
 
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//
 
-void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                  mlir::Value lhs, mlir::Value rhs)
+void AddOp::build(OpBuilder& builder, OperationState& state,
+                  Value lhs, Value rhs)
 {
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands({lhs, rhs});
 }
 
-mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser,
-                               mlir::OperationState& result)
+ParseResult AddOp::parse(OpAsmParser& parser,
+                         OperationState& result)
 {
     return parseBinaryOp(parser, result);
 }
 
-void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
+
+void AddOp::inferShapes()
+{
+    getResult().setType(getLhs().getType());
+}
 
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//
 
-void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                          StringRef callee, ArrayRef<mlir::Value> arguments)
+void GenericCallOp::build(OpBuilder& builder, OperationState& state,
+                          StringRef callee, ArrayRef<Value> arguments)
 {
     // Generic call always returns an unranked Tensor initially.
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands(arguments);
     state.addAttribute("callee",
-                       mlir::SymbolRefAttr::get(builder.getContext(), callee));
+                       SymbolRefAttr::get(builder.getContext(), callee));
+}
+
+CallInterfaceCallable GenericCallOp::getCallableForCallee()
+{
+    return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+}
+
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
+{
+    (*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
+}
+
+Operation::operand_range GenericCallOp::getArgOperands()
+{
+    return getInputs();
+}
+
+MutableOperandRange GenericCallOp::getArgOperandsMutable()
+{
+    return getInputsMutable();
 }
 
 //===----------------------------------------------------------------------===//
 // FuncOp
 //===----------------------------------------------------------------------===//
 
-void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                   llvm::StringRef name, mlir::FunctionType type,
-                   llvm::ArrayRef<mlir::NamedAttribute> attrs)
+void FuncOp::build(OpBuilder& builder, OperationState& state,
+                   StringRef name, FunctionType type,
+                   ArrayRef<NamedAttribute> attrs)
 {
     // FunctionOpInterface provides a convenient `build` method that will populate
     // the state of our FuncOp, and create an entry block.
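    // The trailing type.getInputs() argument seeds the entry block with one
    // block argument per function input, mirroring the function signature.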
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); } -mlir::ParseResult FuncOp::parse(OpAsmParser& parser, - OperationState& result) +ParseResult FuncOp::parse(OpAsmParser& parser, + OperationState& result) { // Dispatch to the FunctionOpInterface provided utility method that parses the // function operation. @@ -208,17 +277,17 @@ mlir::ParseResult FuncOp::parse(OpAsmParser& parser, return builder.getFunctionType(argTypes, results); }; - return mlir::function_interface_impl::parseFunctionOp( + return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } -void FuncOp::print(mlir::OpAsmPrinter& p) +void FuncOp::print(OpAsmPrinter& p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp( + function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } @@ -227,26 +296,31 @@ void FuncOp::print(mlir::OpAsmPrinter& p) // MulOp //===----------------------------------------------------------------------===// -void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, - mlir::Value lhs, mlir::Value rhs) +void MulOp::build(OpBuilder& builder, OperationState& state, + Value lhs, Value rhs) { state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands({lhs, rhs}); } -mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser, - mlir::OperationState& result) +ParseResult MulOp::parse(OpAsmParser& parser, + OperationState& result) { return parseBinaryOp(parser, result); } -void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); } +void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); } + +void MulOp::inferShapes() +{ + getResult().setType(getLhs().getType()); +} //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// -mlir::LogicalResult ReturnOp::verify() +LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. @@ -265,15 +339,15 @@ mlir::LogicalResult ReturnOp::verify() // If the operation does not have an input, we are done. if (!hasOperand()) - return mlir::success(); + return success(); auto inputType = *operand_type_begin(); auto resultType = results.front(); // Check that the result type of the function matches the operand type. 
-    if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
-        llvm::isa<mlir::UnrankedTensorType>(resultType))
-        return mlir::success();
+    if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
+        llvm::isa<UnrankedTensorType>(resultType))
+        return success();
 
     return emitError() << "type of return operand (" << inputType
                        << ") doesn't match function result type (" << resultType
                        << ")";
 }
 
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
-void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                        mlir::Value value)
+void TransposeOp::build(OpBuilder& builder, OperationState& state,
+                        Value value)
 {
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands(value);
 }
 
-mlir::LogicalResult TransposeOp::verify()
+LogicalResult TransposeOp::verify()
 {
     auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
     auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
     if (!inputType || !resultType)
-        return mlir::success();
+        return success();
 
     auto inputShape = inputType.getShape();
     if (!std::equal(inputShape.begin(), inputShape.end(),
                     resultType.getShape().rbegin()))
     {
         return emitError()
             << "expected result shape to be a transpose of the input";
     }
 
-    return mlir::success();
+    return success();
+}
+
+void TransposeOp::inferShapes()
+{
+    // Transpose reverses the shape of the tensor.
+    auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
+    // We assume that transpose is only applied to matrices.
+    SmallVector<int64_t> dimensions(llvm::reverse(tensorType.getShape()));
+    getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
+}
+
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
+{
+    if (inputs.size() != 1 || outputs.size() != 1)
+    {
+        return false;
+    }
+
+    const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
+    const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
+
+    if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
+    {
+        return false;
+    }
+
+    // If both types are ranked, their shapes must match exactly; otherwise a
+    // ranked tensor may be cast to or from an unranked one.
+    return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
+}
+
+void CastOp::inferShapes()
+{
+    getResult().setType(getInput().getType());
 }
diff --git a/lib/MLIRGen.cpp b/lib/MLIRGen.cpp
index 19d70fc..c53e91f 100644
--- a/lib/MLIRGen.cpp
+++ b/lib/MLIRGen.cpp
@@ -44,7 +44,9 @@ namespace
         theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
         for (Function& f : moduleAST)
+        {
             mlirGen(f);
+        }
 
         // Verify the module after we have finished constructing it, this will check
         // the structural properties of the IR and invoke any specific verifiers we
@@ -160,6 +162,13 @@ namespace
                 function.getFunctionType().getInputs(), getType(ValueType{})));
         }
 
+        // Set all functions except 'main' to private visibility,
+        // so that the inliner can inline them into 'main' and discard them.
+        if (funcAST.getPrototype()->getName() != "main")
+        {
+            function.setPrivate();
+        }
+
         return function;
     }
 
diff --git a/lib/ShapeInferencePass.cpp b/lib/ShapeInferencePass.cpp
new file mode 100644
index 0000000..b5b6b9e
--- /dev/null
+++ b/lib/ShapeInferencePass.cpp
@@ -0,0 +1,95 @@
+//
+// Created by ricardo on 02/06/25.
+//
+
+#include <mlir/Pass/Pass.h>
+#include <llvm/Support/Debug.h>
+
+#include "Dialect.h"
+#include "Passes.h"
+
+
+namespace mlir::hello
+{
+#include "hello/ShapeInferenceInterface.cpp.inc"
+}
+
+
+#define DEBUG_TYPE "ShapeInference"
+
+namespace
+{
+    struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
+    {
+        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
+
+
+        void runOnOperation() override
+        {
+            mlir::hello::FuncOp operation = getOperation();
+            llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;
+
+            operation.walk([&](mlir::Operation* op)
+            {
+                if (isDynamicShapes(op))
+                {
+                    opWorkList.insert(op);
+                }
+            });
+
+            while (!opWorkList.empty())
+            {
+                auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
+                if (nextOperation == opWorkList.end())
+                {
+                    break;
+                }
+
+                mlir::Operation* op = *nextOperation;
+                opWorkList.erase(op);
+                LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+
+                if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
+                {
+                    shapeInference.inferShapes();
+                }
+                else
+                {
+                    op->emitError(
+                        std::string("Failed to infer shape for operation '") + op->getName().getIdentifier().str()
+                        + "' that does not implement the shape inference interface.");
+                    signalPassFailure();
+                    return;
+                }
+            }
+
+            if (!opWorkList.empty())
+            {
+                operation.emitError("Failed to infer shapes, ") << opWorkList.size() <<
+                    " operations could not be inferred.\n";
+                signalPassFailure();
+            }
+        }
+
+        // An operation is ready for inference once all of its operands
+        // already have a ranked (inferred) type.
+        static bool isOperationInferred(mlir::Operation* op)
+        {
+            return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
+            {
+                return llvm::isa<mlir::RankedTensorType>(operandType);
+            });
+        }
+
+        static bool isDynamicShapes(mlir::Operation* op)
+        {
+            return llvm::any_of(op->getResultTypes(), [](mlir::Type resultType)
+            {
+                return !llvm::isa<mlir::RankedTensorType>(resultType);
+            });
+        }
+    };
+}
+
+std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
+{
+    return std::make_unique<ShapeInferencePass>();
+}
diff --git a/main.cpp b/main.cpp
index 8d51c76..0ae2363 100644
--- a/main.cpp
+++ b/main.cpp
@@ -12,6 +12,7 @@
 
 #include "Dialect.h"
 #include "MLIRGen.h"
+#include "Passes.h"
 
 namespace mlir
 {
@@ -119,7 +120,13 @@ int dumpMLIR()
         return 1;
     }
 
-    manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass());
+    // Inline all functions into 'main' (every other function was marked private).
+    manager.addPass(mlir::createInlinerPass());
+    // Then run canonicalization, CSE and shape inference on each function.
+    mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
+    functionPassManager.addPass(mlir::createCanonicalizerPass());
+    functionPassManager.addPass(mlir::createCSEPass());
+    functionPassManager.addPass(mlir::hello::createShapeInferencePass());
 
     if (mlir::failed(manager.run(*module)))
     {
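
For reference, a sketch of what this pipeline is expected to do on
examples/multiply_transpose.hello: the inliner folds `multiply_transpose` into
`main` (materializing `hello.cast` ops for the ranked-to-unranked argument
conversions), canonicalization and CSE remove the duplicated work, and the
shape inference pass then rewrites the remaining `tensor<*xf64>` types into
ranked ones. The sketch below assumes a `hello.print` op from the earlier
chapters, elides the constant data, and approximates the op syntax; trivial
same-type `hello.cast` ops may also survive, since this chapter adds no cast
folding:

    hello.func @main() {
      %0 = hello.constant dense<...> : tensor<2x3xf64>
      %1 = hello.constant dense<...> : tensor<2x3xf64>
      %2 = hello.transpose(%1 : tensor<2x3xf64>) to tensor<3x2xf64>
      %3 = hello.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
      %4 = hello.mul %2, %3 : tensor<3x2xf64>
      hello.print %4 : tensor<3x2xf64>
      hello.return
    }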