feat: toy tutorial chapter 4.

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
jackfiled 2025-06-03 16:03:17 +08:00
parent eacf20fe3c
commit 902915a57b
Signed by: jackfiled
GPG Key ID: 5F7234760472A46A
12 changed files with 380 additions and 68 deletions

View File

@ -35,23 +35,30 @@ mlir_tablegen(HelloCombine.inc -gen-rewriters)
include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR})
add_public_tablegen_target(HelloCombineIncGen) add_public_tablegen_target(HelloCombineIncGen)
add_library(SyntaxNode STATIC add_library(HelloDialect STATIC
lib/SyntaxNode.cpp lib/SyntaxNode.cpp
lib/Dialect.cpp lib/Dialect.cpp
lib/MLIRGen.cpp lib/MLIRGen.cpp
lib/HelloCombine.cpp lib/HelloCombine.cpp
lib/ShapeInferencePass.cpp
include/SyntaxNode.h include/SyntaxNode.h
include/Parser.h include/Parser.h
include/Lexer.h include/Lexer.h
include/Dialect.h
include/MLIRGen.h
include/Passes.h
) )
add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen) add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
target_link_libraries(SyntaxNode target_link_libraries(HelloDialect
PRIVATE PRIVATE
MLIRSupport MLIRSupport
MLIRAnalysis MLIRAnalysis
MLIRFunctionInterfaces MLIRFunctionInterfaces
MLIRCallInterfaces
MLIRCastInterfaces
MLIRIR MLIRIR
MLIRParser MLIRParser
MLIRSideEffectInterfaces MLIRSideEffectInterfaces
@ -61,6 +68,6 @@ add_executable(hello-mlir main.cpp)
target_link_libraries(hello-mlir target_link_libraries(hello-mlir
PRIVATE PRIVATE
SyntaxNode HelloDialect
LLVMSupport LLVMSupport
LLVMCore) LLVMCore)

View File

@ -16,11 +16,5 @@ def main() {
# reuse the previously specialized and inferred version and return <3, 2>. # reuse the previously specialized and inferred version and return <3, 2>.
var d = multiply_transpose(b, a); var d = multiply_transpose(b, a);
# A new call with <3, 2> (instead of <2, 3>) for both dimensions will print(d);
# trigger another specialization of `multiply_transpose`.
var e = multiply_transpose(c, d);
# Finally, calling into `multiply_transpose` with incompatible shapes
# (<2, 3> and <3, 2>) will trigger a shape inference error.
var f = multiply_transpose(a, c);
} }

View File

@ -9,8 +9,10 @@
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "hello/ShapeInferenceInterface.h"
/// Include the auto-generated header file containing the declaration of the toy /// Include the auto-generated header file containing the declaration of the toy
/// dialect. /// dialect.

20
include/Passes.h Normal file
View File

@ -0,0 +1,20 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef PASSES_H
#define PASSES_H
#include <memory>
namespace mlir
{
class Pass;
namespace hello
{
std::unique_ptr<Pass> createShapeInferencePass();
}
}
#endif //PASSES_H

View File

@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls) mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(HelloOpsIncGen) add_public_tablegen_target(HelloOpsIncGen)
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(HelloInterfaceIncGen)

View File

@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "hello/ShapeInferenceInterface.td"
def Hello_Dialect : Dialect { def Hello_Dialect : Dialect {
let name = "hello"; let name = "hello";
let cppNamespace = "::mlir::hello"; let cppNamespace = "::mlir::hello";
} }
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>; class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
// AddOp // AddOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def AddOp : Hello_Op<"add"> { def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation"; let summary = "element-wise addition operation";
let description = [{ let description = [{
The "add" operation performs element-wise addition between two tensors. The "add" operation performs element-wise addition between two tensors.
@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
// GenericCallOp // GenericCallOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def GenericCallOp : Hello_Op<"generic_call"> { def GenericCallOp : Hello_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
let summary = "generic call operation"; let summary = "generic call operation";
let description = [{ let description = [{
Generic calls represent calls to a user defined function that needs to Generic calls represent calls to a user defined function that needs to
@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
// MulOp // MulOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def MulOp : Hello_Op<"mul"> { def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation"; let summary = "element-wise multiplication operation";
let description = [{ let description = [{
The "mul" operation performs element-wise multiplication between two The "mul" operation performs element-wise multiplication between two
@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
// TransposeOp // TransposeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def TransposeOp : Hello_Op<"transpose", [Pure]> { def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation"; let summary = "transpose operation";
let arguments = (ins F64Tensor:$input); let arguments = (ins F64Tensor:$input);
@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
def CastOp : Hello_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
Pure,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked,
then shape is required to match. The operation is invalid if converting
to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
#endif #endif

View File

@ -0,0 +1,18 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef SHAPEINFERENCEINTERFACE_H
#define SHAPEINFERENCEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
namespace mlir
{
namespace hello
{
#include "hello/ShapeInferenceInterface.h.inc"
}
}
#endif //SHAPEINFERENCEINTERFACE_H

View File

@ -0,0 +1,18 @@
#ifndef SHAPE_INFERENCE_INTERFACE
#define SHAPE_INFERENCE_INTERFACE
include "mlir/IR/OpBase.td"
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
];
}
#endif

View File

@ -6,6 +6,49 @@
#include "hello/Dialect.cpp.inc" #include "hello/Dialect.cpp.inc"
#include <mlir/Interfaces/FunctionImplementation.h> #include <mlir/Interfaces/FunctionImplementation.h>
#include <mlir/Transforms/InliningUtils.h>
#include <oneapi/tbb/detail/_template_helpers.h>
struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
{
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
{
return true;
}
bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
mlir::IRMapping& valueMapping) const override
{
return true;
}
bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
mlir::IRMapping& valueMapping) const override
{
return true;
}
void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
{
// Only the `hello.returnOp` is the function terminator
auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);
assert(returnOp.getNumOperands() == returnValues.size());
for (const auto& it : llvm::enumerate(returnOp.getOperands()))
{
returnValues[it.index()].replaceAllUsesWith(it.value());
}
}
mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
mlir::Location conversionLoc) const override
{
return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
}
};
void mlir::hello::HelloDialect::initialize() void mlir::hello::HelloDialect::initialize()
{ {
@ -13,6 +56,7 @@ void mlir::hello::HelloDialect::initialize()
#define GET_OP_LIST #define GET_OP_LIST
#include "hello/Ops.cpp.inc" #include "hello/Ops.cpp.inc"
>(); >();
addInterfaces<HelloDialectInlinerInterface>();
} }
using namespace mlir; using namespace mlir;
@ -25,43 +69,43 @@ static ParseResult parseBinaryOp(OpAsmParser& parser,
OperationState& result) OperationState& result)
{ {
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
llvm::SMLoc operandsLoc = parser.getCurrentLocation(); SMLoc operandsLoc = parser.getCurrentLocation();
mlir::Type type; Type type;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type)) parser.parseColonType(type))
return mlir::failure(); return failure();
// If the type is a function type, it contains the input and result types of // If the type is a function type, it contains the input and result types of
// this operation. // this operation.
if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type)) if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
{ {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands)) result.operands))
return mlir::failure(); return failure();
result.addTypes(funcType.getResults()); result.addTypes(funcType.getResults());
return mlir::success(); return success();
} }
// Otherwise, the parsed type is the type of both operands and results. // Otherwise, the parsed type is the type of both operands and results.
if (parser.resolveOperands(operands, type, result.operands)) if (parser.resolveOperands(operands, type, result.operands))
return mlir::failure(); return failure();
result.addTypes(type); result.addTypes(type);
return mlir::success(); return success();
} }
/// A generalized printer for binary operations. It prints in two different /// A generalized printer for binary operations. It prints in two different
/// forms depending on if all of the types match. /// forms depending on if all of the types match.
static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op) static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
{ {
printer << " " << op->getOperands(); printer << " " << op->getOperands();
printer.printOptionalAttrDict(op->getAttrs()); printer.printOptionalAttrDict(op->getAttrs());
printer << " : "; printer << " : ";
// If all of the types are the same, print the type directly. // If all of the types are the same, print the type directly.
mlir::Type resultType = *op->result_type_begin(); Type resultType = *op->result_type_begin();
if (llvm::all_of(op->getOperandTypes(), if (llvm::all_of(op->getOperandTypes(),
[=](mlir::Type type) { return type == resultType; })) [=](Type type) { return type == resultType; }))
{ {
printer << resultType; printer << resultType;
return; return;
@ -78,7 +122,7 @@ static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
/// Build a constant operation. /// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is /// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation. /// expected to fill in order to build the operation.
void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state, void ConstantOp::build(OpBuilder& builder, OperationState& state,
double value) double value)
{ {
auto dataType = RankedTensorType::get({}, builder.getF64Type()); auto dataType = RankedTensorType::get({}, builder.getF64Type());
@ -93,10 +137,10 @@ void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
/// or `false` on success. This allows for easily chaining together a set of /// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState` /// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above. /// similarly to the `build` methods described above.
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser, ParseResult ConstantOp::parse(OpAsmParser& parser,
mlir::OperationState& result) OperationState& result)
{ {
mlir::DenseElementsAttr value; DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) || if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes)) parser.parseAttribute(value, "value", result.attributes))
return failure(); return failure();
@ -107,7 +151,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
/// The 'OpAsmPrinter' class is a stream that allows for formatting /// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc. /// strings, attributes, operands, types, etc.
void ConstantOp::print(mlir::OpAsmPrinter& printer) void ConstantOp::print(OpAsmPrinter& printer)
{ {
printer << " "; printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
@ -116,17 +160,17 @@ void ConstantOp::print(mlir::OpAsmPrinter& printer)
/// Verifier for the constant operation. This corresponds to the /// Verifier for the constant operation. This corresponds to the
/// `let hasVerifier = 1` in the op definition. /// `let hasVerifier = 1` in the op definition.
mlir::LogicalResult ConstantOp::verify() LogicalResult ConstantOp::verify()
{ {
// If the return type of the constant is not an unranked tensor, the shape // If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data. // must match the shape of the attribute holding the data.
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
if (!resultType) if (!resultType)
return success(); return success();
// Check that the rank of the attribute type matches the rank of the constant // Check that the rank of the attribute type matches the rank of the constant
// result type. // result type.
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType()); auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) if (attrType.getRank() != resultType.getRank())
{ {
return emitOpError("return type must match the one of the attached value " return emitOpError("return type must match the one of the attached value "
@ -145,56 +189,81 @@ mlir::LogicalResult ConstantOp::verify()
<< " != " << resultType.getShape()[dim]; << " != " << resultType.getShape()[dim];
} }
} }
return mlir::success(); return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AddOp // AddOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, void AddOp::build(OpBuilder& builder, OperationState& state,
mlir::Value lhs, mlir::Value rhs) Value lhs, Value rhs)
{ {
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs}); state.addOperands({lhs, rhs});
} }
mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser, ParseResult AddOp::parse(OpAsmParser& parser,
mlir::OperationState& result) OperationState& result)
{ {
return parseBinaryOp(parser, result); return parseBinaryOp(parser, result);
} }
void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); } void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
void AddOp::inferShapes()
{
getResult().setType(getLhs().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GenericCallOp // GenericCallOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, void GenericCallOp::build(OpBuilder& builder, OperationState& state,
StringRef callee, ArrayRef<mlir::Value> arguments) StringRef callee, ArrayRef<Value> arguments)
{ {
// Generic call always returns an unranked Tensor initially. // Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(arguments); state.addOperands(arguments);
state.addAttribute("callee", state.addAttribute("callee",
mlir::SymbolRefAttr::get(builder.getContext(), callee)); SymbolRefAttr::get(builder.getContext(), callee));
}
CallInterfaceCallable GenericCallOp::getCallableForCallee()
{
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
{
(*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
}
Operation::operand_range GenericCallOp::getArgOperands()
{
return getInputs();
}
MutableOperandRange GenericCallOp::getArgOperandsMutable()
{
return getInputsMutable();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// FuncOp // FuncOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, void FuncOp::build(OpBuilder& builder, OperationState& state,
llvm::StringRef name, mlir::FunctionType type, StringRef name, FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) ArrayRef<NamedAttribute> attrs)
{ {
// FunctionOpInterface provides a convenient `build` method that will populate // FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block. // the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
} }
mlir::ParseResult FuncOp::parse(OpAsmParser& parser, ParseResult FuncOp::parse(OpAsmParser& parser,
OperationState& result) OperationState& result)
{ {
// Dispatch to the FunctionOpInterface provided utility method that parses the // Dispatch to the FunctionOpInterface provided utility method that parses the
@ -208,17 +277,17 @@ mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
return builder.getFunctionType(argTypes, results); return builder.getFunctionType(argTypes, results);
}; };
return mlir::function_interface_impl::parseFunctionOp( return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType, getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
} }
void FuncOp::print(mlir::OpAsmPrinter& p) void FuncOp::print(OpAsmPrinter& p)
{ {
// Dispatch to the FunctionOpInterface provided utility method that prints the // Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation. // function operation.
mlir::function_interface_impl::printFunctionOp( function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName()); getArgAttrsAttrName(), getResAttrsAttrName());
} }
@ -227,26 +296,31 @@ void FuncOp::print(mlir::OpAsmPrinter& p)
// MulOp // MulOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, void MulOp::build(OpBuilder& builder, OperationState& state,
mlir::Value lhs, mlir::Value rhs) Value lhs, Value rhs)
{ {
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs}); state.addOperands({lhs, rhs});
} }
mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser, ParseResult MulOp::parse(OpAsmParser& parser,
mlir::OperationState& result) OperationState& result)
{ {
return parseBinaryOp(parser, result); return parseBinaryOp(parser, result);
} }
void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); } void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
void MulOp::inferShapes()
{
getResult().setType(getLhs().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReturnOp // ReturnOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
mlir::LogicalResult ReturnOp::verify() LogicalResult ReturnOp::verify()
{ {
// We know that the parent operation is a function, because of the 'HasParent' // We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition. // trait attached to the operation definition.
@ -265,15 +339,15 @@ mlir::LogicalResult ReturnOp::verify()
// If the operation does not have an input, we are done. // If the operation does not have an input, we are done.
if (!hasOperand()) if (!hasOperand())
return mlir::success(); return success();
auto inputType = *operand_type_begin(); auto inputType = *operand_type_begin();
auto resultType = results.front(); auto resultType = results.front();
// Check that the result type of the function matches the operand type. // Check that the result type of the function matches the operand type.
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType)) llvm::isa<UnrankedTensorType>(resultType))
return mlir::success(); return success();
return emitError() << "type of return operand (" << inputType return emitError() << "type of return operand (" << inputType
<< ") doesn't match function result type (" << resultType << ") doesn't match function result type (" << resultType
@ -284,19 +358,19 @@ mlir::LogicalResult ReturnOp::verify()
// TransposeOp // TransposeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, void TransposeOp::build(OpBuilder& builder, OperationState& state,
mlir::Value value) Value value)
{ {
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(value); state.addOperands(value);
} }
mlir::LogicalResult TransposeOp::verify() LogicalResult TransposeOp::verify()
{ {
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType()); auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType()); auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType) if (!inputType || !resultType)
return mlir::success(); return success();
auto inputShape = inputType.getShape(); auto inputShape = inputType.getShape();
if (!std::equal(inputShape.begin(), inputShape.end(), if (!std::equal(inputShape.begin(), inputShape.end(),
@ -305,7 +379,43 @@ mlir::LogicalResult TransposeOp::verify()
return emitError() return emitError()
<< "expected result shape to be a transpose of the input"; << "expected result shape to be a transpose of the input";
} }
return mlir::success(); return success();
}
void TransposeOp::inferShapes()
{
// Transpose will reverse the shape of tensor.
auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
// And assume that transpose only applies for matrix.
SmallVector<int64_t, 2> dimensions(llvm::reverse(tensorType.getShape()));
getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
}
// CastOp
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
{
if (inputs.size() != 1 || outputs.size() != 1)
{
return false;
}
const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
{
return false;
}
// If both have rank, they must be to the size.
// And the known size can be cast into unknown size.
return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
}
void CastOp::inferShapes()
{
getResult().setType(getInput().getType());
} }

View File

@ -44,7 +44,9 @@ namespace
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (Function& f : moduleAST) for (Function& f : moduleAST)
{
mlirGen(f); mlirGen(f);
}
// Verify the module after we have finished constructing it, this will check // Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we // the structural properties of the IR and invoke any specific verifiers we
@ -160,6 +162,13 @@ namespace
function.getFunctionType().getInputs(), getType(ValueType{}))); function.getFunctionType().getInputs(), getType(ValueType{})));
} }
// Jus set all functions except 'main' to private
// which is used to inline the other functions.
if (funcAST.getPrototype()->getName() != "main")
{
function.setPrivate();
}
return function; return function;
} }

View File

@ -0,0 +1,95 @@
//
// Created by ricardo on 02/06/25.
//
#include <llvm/Support/Debug.h>
#include <mlir/Pass/Pass.h>
#include "Dialect.h"
#include "Passes.h"
namespace mlir::hello
{
#include "hello/ShapeInferenceInterface.cpp.inc"
}
#define DEBUG_TYPE "ShapeInference"
namespace
{
struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
{
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
void runOnOperation() override
{
mlir::hello::FuncOp operation = getOperation();
llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;
operation.walk([&](mlir::Operation* op)
{
if (isDynamicShapes(op))
{
opWorkList.insert(op);
}
});
while (!opWorkList.empty())
{
auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
if (nextOperation == opWorkList.end())
{
break;
}
mlir::Operation* op = *nextOperation;
opWorkList.erase(op);
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
{
shapeInference.inferShapes();
}
else
{
op->emitError(
std::string("Failed to inference shape for operation '") + op->getName().getIdentifier().str() +
"' without shape inference interface.");
signalPassFailure();
return;
}
}
if (!opWorkList.empty())
{
operation.emitError("Failed to inference shape, ") << opWorkList.size() <<
" operations failed to inference.\n";
signalPassFailure();
}
}
static bool isOperationInferred(mlir::Operation* op)
{
return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
{
return llvm::isa<mlir::RankedTensorType>(operandType);
});
}
static bool isDynamicShapes(mlir::Operation* op)
{
return llvm::any_of(op->getResultTypes(), [](mlir::Type operandType)
{
return !llvm::isa<mlir::RankedTensorType>(operandType);
});
}
};
}
std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
{
return std::make_unique<ShapeInferencePass>();
}

View File

@ -12,6 +12,7 @@
#include "Dialect.h" #include "Dialect.h"
#include "MLIRGen.h" #include "MLIRGen.h"
#include "Passes.h"
namespace mlir namespace mlir
{ {
@ -119,7 +120,13 @@ int dumpMLIR()
return 1; return 1;
} }
manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass()); // To inline all functions except 'main' function.
manager.addPass(mlir::createInlinerPass());
// In the canonicalizer pass, we add Transpose Pass and Reshape Pass.
mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
functionPassManager.addPass(mlir::createCanonicalizerPass());
functionPassManager.addPass(mlir::createCSEPass());
functionPassManager.addPass(mlir::hello::createShapeInferencePass());
if (mlir::failed(manager.run(*module))) if (mlir::failed(manager.run(*module)))
{ {