feat: toy tutorial chapter 4.

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
Author: jackfiled  2025-06-03 16:03:17 +08:00
Parent: eacf20fe3c
Commit: 902915a57b
Signed by: jackfiled (GPG Key ID: 5F7234760472A46A)
12 changed files with 380 additions and 68 deletions


@ -35,23 +35,30 @@ mlir_tablegen(HelloCombine.inc -gen-rewriters)
include_directories(${CMAKE_BINARY_DIR})
add_public_tablegen_target(HelloCombineIncGen)
add_library(SyntaxNode STATIC
add_library(HelloDialect STATIC
lib/SyntaxNode.cpp
lib/Dialect.cpp
lib/MLIRGen.cpp
lib/HelloCombine.cpp
lib/ShapeInferencePass.cpp
include/SyntaxNode.h
include/Parser.h
include/Lexer.h
include/Dialect.h
include/MLIRGen.h
include/Passes.h
)
add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen)
add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
target_link_libraries(SyntaxNode
target_link_libraries(HelloDialect
PRIVATE
MLIRSupport
MLIRAnalysis
MLIRFunctionInterfaces
MLIRCallInterfaces
MLIRCastInterfaces
MLIRIR
MLIRParser
MLIRSideEffectInterfaces
@ -61,6 +68,6 @@ add_executable(hello-mlir main.cpp)
target_link_libraries(hello-mlir
PRIVATE
SyntaxNode
HelloDialect
LLVMSupport
LLVMCore)


@ -16,11 +16,5 @@ def main() {
# reuse the previously specialized and inferred version and return <3, 2>.
var d = multiply_transpose(b, a);
# A new call with <3, 2> (instead of <2, 3>) for both dimensions will
# trigger another specialization of `multiply_transpose`.
var e = multiply_transpose(c, d);
# Finally, calling into `multiply_transpose` with incompatible shapes
# (<2, 3> and <3, 2>) will trigger a shape inference error.
var f = multiply_transpose(a, c);
print(d);
}


@ -9,8 +9,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "hello/ShapeInferenceInterface.h"
/// Include the auto-generated header file containing the declaration of the hello
/// dialect.

include/Passes.h (new file, 20 lines)

@ -0,0 +1,20 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef PASSES_H
#define PASSES_H
#include <memory>
namespace mlir
{
class Pass;
namespace hello
{
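/// Create a pass that performs shape inference on the operations inside a hello.func.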
std::unique_ptr<Pass> createShapeInferencePass();
}
}
#endif //PASSES_H


@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(HelloOpsIncGen)
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(HelloInterfaceIncGen)


@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "hello/ShapeInferenceInterface.td"
def Hello_Dialect : Dialect {
let name = "hello";
let cppNamespace = "::mlir::hello";
}
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
// AddOp
//===----------------------------------------------------------------------===//
def AddOp : Hello_Op<"add"> {
def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
// GenericCallOp
//===----------------------------------------------------------------------===//
def GenericCallOp : Hello_Op<"generic_call"> {
def GenericCallOp : Hello_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to
@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
// MulOp
//===----------------------------------------------------------------------===//
def MulOp : Hello_Op<"mul"> {
def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation";
let description = [{
The "mul" operation performs element-wise multiplication between two
@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
// TransposeOp
//===----------------------------------------------------------------------===//
def TransposeOp : Hello_Op<"transpose", [Pure]> {
def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
let hasCanonicalizer = 1;
}
def CastOp : Hello_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
Pure,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked,
then shape is required to match. The operation is invalid if converting
to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
#endif
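To make the new CastOp concrete: the ODS-generated builder takes the result type followed by the input value, which is exactly how the inliner hook later in this commit materializes casts. A minimal C++ sketch (`builder`, `loc`, and `value` are assumed to already be in scope; the shapes are illustrative):
// Sketch: cast an f64 tensor value to a concrete ranked type.
auto rankedType = mlir::RankedTensorType::get({2, 3}, builder.getF64Type());
mlir::Value casted = builder.create<mlir::hello::CastOp>(loc, rankedType, value);
// With an unranked input this prints roughly as:
//   %0 = hello.cast %value : tensor<*xf64> to tensor<2x3xf64>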


@ -0,0 +1,18 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef SHAPEINFERENCEINTERFACE_H
#define SHAPEINFERENCEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
namespace mlir
{
namespace hello
{
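/// Include the auto-generated declarations of the shape inference op interface.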
#include "hello/ShapeInferenceInterface.h.inc"
}
}
#endif //SHAPEINFERENCEINTERFACE_H


@ -0,0 +1,18 @@
#ifndef SHAPE_INFERENCE_INTERFACE
#define SHAPE_INFERENCE_INTERFACE
include "mlir/IR/OpBase.td"
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
];
}
#endif
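For orientation: mlir-tblgen turns this definition into a C++ interface class named ShapeInference (the string passed to OpInterface above), so generic code can call inferShapes on any op that declared the interface without knowing its concrete type. A minimal usage sketch, assuming `func` is a mlir::hello::FuncOp in scope; the ShapeInferencePass later in this commit follows the same pattern:
// Sketch: walk a function and invoke the generated interface where it is implemented.
func.walk([](mlir::Operation *op) {
  if (auto shapeOp = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
    shapeOp.inferShapes(); // dispatches to AddOp, MulOp, TransposeOp or CastOp
});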


@ -6,6 +6,49 @@
#include "hello/Dialect.cpp.inc"
#include <mlir/Interfaces/FunctionImplementation.h>
#include <mlir/Transforms/InliningUtils.h>
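/// Defines how the inliner treats hello dialect operations: everything is legal to
/// inline, hello.return terminators are rewired to the inlined values, and type
/// mismatches at call boundaries are bridged with hello.cast.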
struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
{
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
{
return true;
}
bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
mlir::IRMapping& valueMapping) const override
{
return true;
}
bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
mlir::IRMapping& valueMapping) const override
{
return true;
}
void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
{
// Only the `hello.return` op acts as the function terminator, so it is the only case handled here.
auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);
assert(returnOp.getNumOperands() == returnValues.size());
for (const auto& it : llvm::enumerate(returnOp.getOperands()))
{
returnValues[it.index()].replaceAllUsesWith(it.value());
}
}
mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
mlir::Location conversionLoc) const override
{
return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
}
};
void mlir::hello::HelloDialect::initialize()
{
@ -13,6 +56,7 @@ void mlir::hello::HelloDialect::initialize()
#define GET_OP_LIST
#include "hello/Ops.cpp.inc"
>();
addInterfaces<HelloDialectInlinerInterface>();
}
using namespace mlir;
@ -25,43 +69,43 @@ static ParseResult parseBinaryOp(OpAsmParser& parser,
OperationState& result)
{
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
llvm::SMLoc operandsLoc = parser.getCurrentLocation();
mlir::Type type;
SMLoc operandsLoc = parser.getCurrentLocation();
Type type;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return mlir::failure();
return failure();
// If the type is a function type, it contains the input and result types of
// this operation.
if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type))
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
{
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
return failure();
result.addTypes(funcType.getResults());
return mlir::success();
return success();
}
// Otherwise, the parsed type is the type of both operands and results.
if (parser.resolveOperands(operands, type, result.operands))
return mlir::failure();
return failure();
result.addTypes(type);
return mlir::success();
return success();
}
/// A generalized printer for binary operations. It prints in two different
/// forms depending on if all of the types match.
static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
{
printer << " " << op->getOperands();
printer.printOptionalAttrDict(op->getAttrs());
printer << " : ";
// If all of the types are the same, print the type directly.
mlir::Type resultType = *op->result_type_begin();
Type resultType = *op->result_type_begin();
if (llvm::all_of(op->getOperandTypes(),
[=](mlir::Type type) { return type == resultType; }))
[=](Type type) { return type == resultType; }))
{
printer << resultType;
return;
@ -78,8 +122,8 @@ static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
double value)
void ConstantOp::build(OpBuilder& builder, OperationState& state,
double value)
{
auto dataType = RankedTensorType::get({}, builder.getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
@ -93,10 +137,10 @@ void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
mlir::OperationState& result)
ParseResult ConstantOp::parse(OpAsmParser& parser,
OperationState& result)
{
mlir::DenseElementsAttr value;
DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
return failure();
@ -107,7 +151,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
void ConstantOp::print(mlir::OpAsmPrinter& printer)
void ConstantOp::print(OpAsmPrinter& printer)
{
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
@ -116,17 +160,17 @@ void ConstantOp::print(mlir::OpAsmPrinter& printer)
/// Verifier for the constant operation. This corresponds to the
/// `let hasVerifier = 1` in the op definition.
mlir::LogicalResult ConstantOp::verify()
LogicalResult ConstantOp::verify()
{
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank())
{
return emitOpError("return type must match the one of the attached value "
@ -145,57 +189,82 @@ mlir::LogicalResult ConstantOp::verify()
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
return success();
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
mlir::Value lhs, mlir::Value rhs)
void AddOp::build(OpBuilder& builder, OperationState& state,
Value lhs, Value rhs)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs});
}
mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser,
mlir::OperationState& result)
ParseResult AddOp::parse(OpAsmParser& parser,
OperationState& result)
{
return parseBinaryOp(parser, result);
}
void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
void AddOp::inferShapes()
{
getResult().setType(getLhs().getType());
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
StringRef callee, ArrayRef<mlir::Value> arguments)
void GenericCallOp::build(OpBuilder& builder, OperationState& state,
StringRef callee, ArrayRef<Value> arguments)
{
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee",
mlir::SymbolRefAttr::get(builder.getContext(), callee));
SymbolRefAttr::get(builder.getContext(), callee));
}
CallInterfaceCallable GenericCallOp::getCallableForCallee()
{
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
{
(*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
}
Operation::operand_range GenericCallOp::getArgOperands()
{
return getInputs();
}
MutableOperandRange GenericCallOp::getArgOperandsMutable()
{
return getInputsMutable();
}
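These four accessors are what `DeclareOpInterfaceMethods<CallOpInterface>` asks for; once they exist, generic infrastructure such as the inliner can treat `hello.generic_call` like any other call site. A small illustrative sketch (hypothetical snippet, `op` assumed to be a `mlir::Operation*`):
// Sketch: generic code can now discover the callee and arguments of a hello.generic_call.
if (auto call = mlir::dyn_cast<mlir::CallOpInterface>(op)) {
  mlir::CallInterfaceCallable callee = call.getCallableForCallee();
  mlir::Operation::operand_range args = call.getArgOperands();
  (void)callee; (void)args; // e.g. resolve the symbol and compare argument types
}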
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs)
void FuncOp::build(OpBuilder& builder, OperationState& state,
StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
{
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
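For reference, a hedged sketch of how this builder is typically driven from the MLIR generator (the function name and types are illustrative; `builder` and `loc` are assumed to be in scope, and the explicit empty attribute list matches the three-argument build above):
// Sketch: create a hello.func that takes and returns an unranked f64 tensor.
mlir::Type tensorType = mlir::UnrankedTensorType::get(builder.getF64Type());
auto funcType = builder.getFunctionType({tensorType}, {tensorType});
auto func = builder.create<mlir::hello::FuncOp>(
    loc, "transpose_twice", funcType, llvm::ArrayRef<mlir::NamedAttribute>());
// The entry block (one block argument per input) already exists at this point, so the
// generator can start emitting the body into func.front() right away.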
mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
OperationState& result)
ParseResult FuncOp::parse(OpAsmParser& parser,
OperationState& result)
{
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
@ -208,17 +277,17 @@ mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
return builder.getFunctionType(argTypes, results);
};
return mlir::function_interface_impl::parseFunctionOp(
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter& p)
void FuncOp::print(OpAsmPrinter& p)
{
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}
@ -227,26 +296,31 @@ void FuncOp::print(mlir::OpAsmPrinter& p)
// MulOp
//===----------------------------------------------------------------------===//
void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
mlir::Value lhs, mlir::Value rhs)
void MulOp::build(OpBuilder& builder, OperationState& state,
Value lhs, Value rhs)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs});
}
mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser,
mlir::OperationState& result)
ParseResult MulOp::parse(OpAsmParser& parser,
OperationState& result)
{
return parseBinaryOp(parser, result);
}
void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
void MulOp::inferShapes()
{
getResult().setType(getLhs().getType());
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult ReturnOp::verify()
LogicalResult ReturnOp::verify()
{
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -265,15 +339,15 @@ mlir::LogicalResult ReturnOp::verify()
// If the operation does not have an input, we are done.
if (!hasOperand())
return mlir::success();
return success();
auto inputType = *operand_type_begin();
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
llvm::isa<UnrankedTensorType>(resultType))
return success();
return emitError() << "type of return operand (" << inputType
<< ") doesn't match function result type (" << resultType
@ -284,19 +358,19 @@ mlir::LogicalResult ReturnOp::verify()
// TransposeOp
//===----------------------------------------------------------------------===//
void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
mlir::Value value)
void TransposeOp::build(OpBuilder& builder, OperationState& state,
Value value)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(value);
}
mlir::LogicalResult TransposeOp::verify()
LogicalResult TransposeOp::verify()
{
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
return success();
auto inputShape = inputType.getShape();
if (!std::equal(inputShape.begin(), inputShape.end(),
@ -305,7 +379,43 @@ mlir::LogicalResult TransposeOp::verify()
return emitError()
<< "expected result shape to be a transpose of the input";
}
return mlir::success();
return success();
}
void TransposeOp::inferShapes()
{
// Transpose reverses the shape of the input tensor.
auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
// Assume that transpose is only applied to rank-2 tensors (matrices).
SmallVector<int64_t, 2> dimensions(llvm::reverse(tensorType.getShape()));
getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
}
// CastOp
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
{
if (inputs.size() != 1 || outputs.size() != 1)
{
return false;
}
const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
{
return false;
}
// If both types are ranked, their shapes must match exactly.
// A ranked tensor can always be cast to or from an unranked one.
return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
}
void CastOp::inferShapes()
{
getResult().setType(getInput().getType());
}


@ -44,7 +44,9 @@ namespace
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (Function& f : moduleAST)
{
mlirGen(f);
}
// Verify the module after we have finished constructing it; this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -160,6 +162,13 @@ namespace
function.getFunctionType().getInputs(), getType(ValueType{})));
}
// Just set all functions except 'main' to private visibility,
// so that the inliner can inline them and discard them afterwards.
if (funcAST.getPrototype()->getName() != "main")
{
function.setPrivate();
}
return function;
}


@ -0,0 +1,95 @@
//
// Created by ricardo on 02/06/25.
//
#include <llvm/Support/Debug.h>
#include <mlir/Pass/Pass.h>
#include "Dialect.h"
#include "Passes.h"
namespace mlir::hello
{
#include "hello/ShapeInferenceInterface.cpp.inc"
}
#define DEBUG_TYPE "ShapeInference"
namespace
{
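/// Function-level pass that propagates static tensor shapes: it collects the ops that
/// still return dynamic (unranked) tensors, then repeatedly infers the shape of any op
/// whose operands are all ranked, until no further progress can be made.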
struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
{
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
void runOnOperation() override
{
mlir::hello::FuncOp operation = getOperation();
llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;
operation.walk([&](mlir::Operation* op)
{
if (isDynamicShapes(op))
{
opWorkList.insert(op);
}
});
while (!opWorkList.empty())
{
auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
if (nextOperation == opWorkList.end())
{
break;
}
mlir::Operation* op = *nextOperation;
opWorkList.erase(op);
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
{
shapeInference.inferShapes();
}
else
{
op->emitError(
std::string("Failed to inference shape for operation '") + op->getName().getIdentifier().str() +
"' without shape inference interface.");
signalPassFailure();
return;
}
}
if (!opWorkList.empty())
{
operation.emitError("Failed to infer shapes: ") << opWorkList.size() <<
" operations could not be inferred.\n";
signalPassFailure();
}
}
static bool isOperationInferred(mlir::Operation* op)
{
return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
{
return llvm::isa<mlir::RankedTensorType>(operandType);
});
}
static bool isDynamicShapes(mlir::Operation* op)
{
return llvm::any_of(op->getResultTypes(), [](mlir::Type resultType)
{
return !llvm::isa<mlir::RankedTensorType>(resultType);
});
}
};
}
std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
{
return std::make_unique<ShapeInferencePass>();
}


@ -12,6 +12,7 @@
#include "Dialect.h"
#include "MLIRGen.h"
#include "Passes.h"
namespace mlir
{
@ -119,7 +120,13 @@ int dumpMLIR()
return 1;
}
manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass());
// Inline all functions except 'main'.
manager.addPass(mlir::createInlinerPass());
// Run the per-function passes: canonicalization (which applies the transpose/reshape
// rewrite patterns), CSE, and shape inference.
mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
functionPassManager.addPass(mlir::createCanonicalizerPass());
functionPassManager.addPass(mlir::createCSEPass());
functionPassManager.addPass(mlir::hello::createShapeInferencePass());
if (mlir::failed(manager.run(*module)))
{