feat: toy tutorial chapter 4.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
parent eacf20fe3c
commit 902915a57b
CMakeLists.txt
@@ -35,23 +35,30 @@ mlir_tablegen(HelloCombine.inc -gen-rewriters)
 include_directories(${CMAKE_BINARY_DIR})
 add_public_tablegen_target(HelloCombineIncGen)
 
-add_library(SyntaxNode STATIC
+add_library(HelloDialect STATIC
     lib/SyntaxNode.cpp
     lib/Dialect.cpp
    lib/MLIRGen.cpp
     lib/HelloCombine.cpp
+    lib/ShapeInferencePass.cpp
     include/SyntaxNode.h
     include/Parser.h
     include/Lexer.h
+    include/Dialect.h
+    include/MLIRGen.h
+    include/Passes.h
 )
 
-add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen)
+add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
 
-target_link_libraries(SyntaxNode
+target_link_libraries(HelloDialect
     PRIVATE
     MLIRSupport
     MLIRAnalysis
     MLIRFunctionInterfaces
+    MLIRCallInterfaces
+    MLIRCastInterfaces
     MLIRIR
     MLIRParser
     MLIRSideEffectInterfaces

@@ -61,6 +68,6 @@ add_executable(hello-mlir main.cpp)
 
 target_link_libraries(hello-mlir
     PRIVATE
-    SyntaxNode
+    HelloDialect
     LLVMSupport
     LLVMCore)

@@ -16,11 +16,5 @@ def main() {
   # reuse the previously specialized and inferred version and return <3, 2>.
   var d = multiply_transpose(b, a);
 
-  # A new call with <3, 2> (instead of <2, 3>) for both dimensions will
-  # trigger another specialization of `multiply_transpose`.
-  var e = multiply_transpose(c, d);
-
-  # Finally, calling into `multiply_transpose` with incompatible shapes
-  # (<2, 3> and <3, 2>) will trigger a shape inference error.
-  var f = multiply_transpose(a, c);
+  print(d);
 }

include/Dialect.h
@@ -9,8 +9,10 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "hello/ShapeInferenceInterface.h"
 
 /// Include the auto-generated header file containing the declaration of the toy
 /// dialect.

include/Passes.h (new file)
@@ -0,0 +1,20 @@
+//
+// Created by ricardo on 02/06/25.
+//
+
+#ifndef PASSES_H
+#define PASSES_H
+
+#include <memory>
+
+namespace mlir
+{
+    class Pass;
+
+    namespace hello
+    {
+        std::unique_ptr<Pass> createShapeInferencePass();
+    }
+}
+
+#endif //PASSES_H

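A quick usage sketch for this header (not part of the commit's diff; `addHelloPasses` is a hypothetical helper name): the factory is meant to be handed to a pass manager nested on `hello.func`, which is exactly what the main.cpp hunk at the end of this commit does.

    // Hypothetical wiring example; mirrors the main.cpp change below.
    #include "Dialect.h" // for mlir::hello::FuncOp
    #include "Passes.h"

    #include <mlir/Pass/PassManager.h>

    void addHelloPasses(mlir::PassManager& manager)
    {
        // Shape inference rewrites function bodies, so it is registered on the
        // nested per-function pass manager rather than on the module one.
        mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
        functionPassManager.addPass(mlir::hello::createShapeInferencePass());
    }
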
include/hello/CMakeLists.txt
@@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
 mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
 mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
 add_public_tablegen_target(HelloOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(HelloInterfaceIncGen)

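For context, `-gen-op-interface-decls`/`-gen-op-interface-defs` emit a C++ interface class named after the `OpInterface` argument (here `ShapeInference`), and the generated `.inc` files are spliced into the dialect's namespace with plain `#include`s. A sketch of the consuming side, matching what ShapeInferenceInterface.h and ShapeInferencePass.cpp do later in this commit:

    // Declarations, in a header (cf. include/hello/ShapeInferenceInterface.h):
    namespace mlir
    {
        namespace hello
        {
    #include "hello/ShapeInferenceInterface.h.inc"
        }
    }

    // Definitions, in exactly one translation unit (cf. lib/ShapeInferencePass.cpp):
    namespace mlir::hello
    {
    #include "hello/ShapeInferenceInterface.cpp.inc"
    }
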
@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/FunctionInterfaces.td"
|
include "mlir/Interfaces/FunctionInterfaces.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/Interfaces/CallInterfaces.td"
|
||||||
|
include "mlir/Interfaces/CastInterfaces.td"
|
||||||
|
include "hello/ShapeInferenceInterface.td"
|
||||||
|
|
||||||
def Hello_Dialect : Dialect {
|
def Hello_Dialect : Dialect {
|
||||||
let name = "hello";
|
let name = "hello";
|
||||||
let cppNamespace = "::mlir::hello";
|
let cppNamespace = "::mlir::hello";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
|
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
|
||||||
// AddOp
|
// AddOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def AddOp : Hello_Op<"add"> {
|
def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "element-wise addition operation";
|
let summary = "element-wise addition operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
The "add" operation performs element-wise addition between two tensors.
|
The "add" operation performs element-wise addition between two tensors.
|
||||||
|
@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
|
||||||
// GenericCallOp
|
// GenericCallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def GenericCallOp : Hello_Op<"generic_call"> {
|
def GenericCallOp : Hello_Op<"generic_call",
|
||||||
|
[DeclareOpInterfaceMethods<CallOpInterface>]> {
|
||||||
let summary = "generic call operation";
|
let summary = "generic call operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
Generic calls represent calls to a user defined function that needs to
|
Generic calls represent calls to a user defined function that needs to
|
||||||
|
@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
|
||||||
// MulOp
|
// MulOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def MulOp : Hello_Op<"mul"> {
|
def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "element-wise multiplication operation";
|
let summary = "element-wise multiplication operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
The "mul" operation performs element-wise multiplication between two
|
The "mul" operation performs element-wise multiplication between two
|
||||||
|
@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
|
||||||
// TransposeOp
|
// TransposeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def TransposeOp : Hello_Op<"transpose", [Pure]> {
|
def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "transpose operation";
|
let summary = "transpose operation";
|
||||||
|
|
||||||
let arguments = (ins F64Tensor:$input);
|
let arguments = (ins F64Tensor:$input);
|
||||||
|
@@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
   let hasCanonicalizer = 1;
 }
 
+def CastOp : Hello_Op<"cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+    Pure,
+    SameOperandsAndResultShape
+  ]> {
+  let summary = "shape cast operation";
+
+  let description = [{
+    The "cast" operation converts a tensor from one type to an equivalent type
+    without changing any data elements. The source and destination types
+    must both be tensor types with the same element type. If both are ranked,
+    then shape is required to match. The operation is invalid if converting
+    to a mismatching constant dimension.
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+}
+
 #endif

include/hello/ShapeInferenceInterface.h (new file)
@@ -0,0 +1,18 @@
+//
+// Created by ricardo on 02/06/25.
+//
+
+#ifndef SHAPEINFERENCEINTERFACE_H
+#define SHAPEINFERENCEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir
+{
+    namespace hello
+    {
+#include "hello/ShapeInferenceInterface.h.inc"
+    }
+}
+
+#endif //SHAPEINFERENCEINTERFACE_H
include/hello/ShapeInferenceInterface.td (new file)
@@ -0,0 +1,18 @@
+#ifndef SHAPE_INFERENCE_INTERFACE
+#define SHAPE_INFERENCE_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let description = [{
+    Interface to access a registered method to infer the return types for an
+    operation that can be used during type inference.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Infer and set the output shape for the current operation.",
+                    "void", "inferShapes">
+  ];
+}
+
+#endif
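Ops that list `DeclareOpInterfaceMethods<ShapeInferenceOpInterface>` in their trait list must implement the `inferShapes()` member, and any `mlir::Operation*` can be tested for the interface with a `dyn_cast`. A minimal sketch (`tryInferShapes` is a hypothetical helper; the real worklist loop lives in ShapeInferencePass.cpp below):

    #include "Dialect.h" // pulls in hello/ShapeInferenceInterface.h in this repo
    #include <mlir/IR/Operation.h>

    // Dispatch shape inference through the tablegen-generated interface class.
    bool tryInferShapes(mlir::Operation* op)
    {
        if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
        {
            shapeInference.inferShapes(); // e.g. resolves to AddOp::inferShapes()
            return true;
        }
        return false;
    }
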
lib/Dialect.cpp
@@ -6,6 +6,49 @@
 #include "hello/Dialect.cpp.inc"
 
 #include <mlir/Interfaces/FunctionImplementation.h>
+#include <mlir/Transforms/InliningUtils.h>
+#include <oneapi/tbb/detail/_template_helpers.h>
+
+struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
+{
+    using DialectInlinerInterface::DialectInlinerInterface;
+
+    bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
+    {
+        return true;
+    }
+
+    bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
+                         mlir::IRMapping& valueMapping) const override
+    {
+        return true;
+    }
+
+    bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
+                         mlir::IRMapping& valueMapping) const override
+    {
+        return true;
+    }
+
+    void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
+    {
+        // Only `hello.return` can terminate a function body.
+        auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);
+
+        assert(returnOp.getNumOperands() == returnValues.size());
+
+        for (const auto& it : llvm::enumerate(returnOp.getOperands()))
+        {
+            returnValues[it.index()].replaceAllUsesWith(it.value());
+        }
+    }
+
+    mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
+                                               mlir::Location conversionLoc) const override
+    {
+        return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
+    }
+};
+
 void mlir::hello::HelloDialect::initialize()
 {
@@ -13,6 +56,7 @@ void mlir::hello::HelloDialect::initialize()
 #define GET_OP_LIST
 #include "hello/Ops.cpp.inc"
     >();
+    addInterfaces<HelloDialectInlinerInterface>();
 }
 
 using namespace mlir;
@@ -25,43 +69,43 @@ static ParseResult parseBinaryOp(OpAsmParser& parser,
                                  OperationState& result)
 {
     SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-    llvm::SMLoc operandsLoc = parser.getCurrentLocation();
-    mlir::Type type;
+    SMLoc operandsLoc = parser.getCurrentLocation();
+    Type type;
     if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
         parser.parseOptionalAttrDict(result.attributes) ||
         parser.parseColonType(type))
-        return mlir::failure();
+        return failure();
 
     // If the type is a function type, it contains the input and result types of
     // this operation.
-    if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type))
+    if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
     {
         if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                    result.operands))
-            return mlir::failure();
+            return failure();
         result.addTypes(funcType.getResults());
-        return mlir::success();
+        return success();
     }
 
     // Otherwise, the parsed type is the type of both operands and results.
     if (parser.resolveOperands(operands, type, result.operands))
-        return mlir::failure();
+        return failure();
     result.addTypes(type);
-    return mlir::success();
+    return success();
 }
 
 /// A generalized printer for binary operations. It prints in two different
 /// forms depending on if all of the types match.
-static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
+static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
 {
     printer << " " << op->getOperands();
     printer.printOptionalAttrDict(op->getAttrs());
     printer << " : ";
 
     // If all of the types are the same, print the type directly.
-    mlir::Type resultType = *op->result_type_begin();
+    Type resultType = *op->result_type_begin();
     if (llvm::all_of(op->getOperandTypes(),
-                     [=](mlir::Type type) { return type == resultType; }))
+                     [=](Type type) { return type == resultType; }))
     {
         printer << resultType;
         return;
@@ -78,7 +122,7 @@ static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
 /// Build a constant operation.
 /// The builder is passed as an argument, so is the state that this method is
 /// expected to fill in order to build the operation.
-void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
+void ConstantOp::build(OpBuilder& builder, OperationState& state,
                                     double value)
 {
     auto dataType = RankedTensorType::get({}, builder.getF64Type());
@@ -93,10 +137,10 @@ void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
 /// or `false` on success. This allows for easily chaining together a set of
 /// parser rules. These rules are used to populate an `mlir::OperationState`
 /// similarly to the `build` methods described above.
-mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
-                                    mlir::OperationState& result)
+ParseResult ConstantOp::parse(OpAsmParser& parser,
+                              OperationState& result)
 {
-    mlir::DenseElementsAttr value;
+    DenseElementsAttr value;
     if (parser.parseOptionalAttrDict(result.attributes) ||
         parser.parseAttribute(value, "value", result.attributes))
         return failure();
|
||||||
|
|
||||||
/// The 'OpAsmPrinter' class is a stream that allows for formatting
|
/// The 'OpAsmPrinter' class is a stream that allows for formatting
|
||||||
/// strings, attributes, operands, types, etc.
|
/// strings, attributes, operands, types, etc.
|
||||||
void ConstantOp::print(mlir::OpAsmPrinter& printer)
|
void ConstantOp::print(OpAsmPrinter& printer)
|
||||||
{
|
{
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
|
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
|
||||||
|
@@ -116,17 +160,17 @@ void ConstantOp::print(mlir::OpAsmPrinter& printer)
 
 /// Verifier for the constant operation. This corresponds to the
 /// `let hasVerifier = 1` in the op definition.
-mlir::LogicalResult ConstantOp::verify()
+LogicalResult ConstantOp::verify()
 {
     // If the return type of the constant is not an unranked tensor, the shape
     // must match the shape of the attribute holding the data.
-    auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
+    auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
     if (!resultType)
         return success();
 
     // Check that the rank of the attribute type matches the rank of the constant
     // result type.
-    auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
+    auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
     if (attrType.getRank() != resultType.getRank())
     {
         return emitOpError("return type must match the one of the attached value "
@@ -145,56 +189,81 @@ mlir::LogicalResult ConstantOp::verify()
                            << " != " << resultType.getShape()[dim];
         }
     }
-    return mlir::success();
+    return success();
 }
 
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//
 
-void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                  mlir::Value lhs, mlir::Value rhs)
+void AddOp::build(OpBuilder& builder, OperationState& state,
+                  Value lhs, Value rhs)
 {
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands({lhs, rhs});
 }
 
-mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser,
-                               mlir::OperationState& result)
+ParseResult AddOp::parse(OpAsmParser& parser,
+                         OperationState& result)
 {
     return parseBinaryOp(parser, result);
 }
 
-void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
+
+void AddOp::inferShapes()
+{
+    getResult().setType(getLhs().getType());
+}
 
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//
 
-void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                          StringRef callee, ArrayRef<mlir::Value> arguments)
+void GenericCallOp::build(OpBuilder& builder, OperationState& state,
+                          StringRef callee, ArrayRef<Value> arguments)
 {
     // Generic call always returns an unranked Tensor initially.
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands(arguments);
     state.addAttribute("callee",
-                       mlir::SymbolRefAttr::get(builder.getContext(), callee));
+                       SymbolRefAttr::get(builder.getContext(), callee));
+}
+
+CallInterfaceCallable GenericCallOp::getCallableForCallee()
+{
+    return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+}
+
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
+{
+    (*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
+}
+
+Operation::operand_range GenericCallOp::getArgOperands()
+{
+    return getInputs();
+}
+
+MutableOperandRange GenericCallOp::getArgOperandsMutable()
+{
+    return getInputsMutable();
 }
 
 //===----------------------------------------------------------------------===//
 // FuncOp
 //===----------------------------------------------------------------------===//
 
-void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                   llvm::StringRef name, mlir::FunctionType type,
-                   llvm::ArrayRef<mlir::NamedAttribute> attrs)
+void FuncOp::build(OpBuilder& builder, OperationState& state,
+                   StringRef name, FunctionType type,
+                   ArrayRef<NamedAttribute> attrs)
 {
     // FunctionOpInterface provides a convenient `build` method that will populate
     // the state of our FuncOp, and create an entry block.
     buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
 }
 
-mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
+ParseResult FuncOp::parse(OpAsmParser& parser,
                           OperationState& result)
 {
     // Dispatch to the FunctionOpInterface provided utility method that parses the
@@ -208,17 +277,17 @@ mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
         return builder.getFunctionType(argTypes, results);
     };
 
-    return mlir::function_interface_impl::parseFunctionOp(
+    return function_interface_impl::parseFunctionOp(
         parser, result, /*allowVariadic=*/false,
         getFunctionTypeAttrName(result.name), buildFuncType,
         getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
-void FuncOp::print(mlir::OpAsmPrinter& p)
+void FuncOp::print(OpAsmPrinter& p)
 {
     // Dispatch to the FunctionOpInterface provided utility method that prints the
     // function operation.
-    mlir::function_interface_impl::printFunctionOp(
+    function_interface_impl::printFunctionOp(
         p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
         getArgAttrsAttrName(), getResAttrsAttrName());
 }
@@ -227,26 +296,31 @@ void FuncOp::print(mlir::OpAsmPrinter& p)
 // MulOp
 //===----------------------------------------------------------------------===//
 
-void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                  mlir::Value lhs, mlir::Value rhs)
+void MulOp::build(OpBuilder& builder, OperationState& state,
+                  Value lhs, Value rhs)
 {
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands({lhs, rhs});
 }
 
-mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser,
-                               mlir::OperationState& result)
+ParseResult MulOp::parse(OpAsmParser& parser,
+                         OperationState& result)
 {
     return parseBinaryOp(parser, result);
 }
 
-void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
+
+void MulOp::inferShapes()
+{
+    getResult().setType(getLhs().getType());
+}
 
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
-mlir::LogicalResult ReturnOp::verify()
+LogicalResult ReturnOp::verify()
 {
     // We know that the parent operation is a function, because of the 'HasParent'
     // trait attached to the operation definition.
@@ -265,15 +339,15 @@ mlir::LogicalResult ReturnOp::verify()
 
     // If the operation does not have an input, we are done.
     if (!hasOperand())
-        return mlir::success();
+        return success();
 
     auto inputType = *operand_type_begin();
     auto resultType = results.front();
 
     // Check that the result type of the function matches the operand type.
-    if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
-        llvm::isa<mlir::UnrankedTensorType>(resultType))
-        return mlir::success();
+    if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
+        llvm::isa<UnrankedTensorType>(resultType))
+        return success();
 
     return emitError() << "type of return operand (" << inputType
                        << ") doesn't match function result type (" << resultType
@@ -284,19 +358,19 @@
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
-void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-                        mlir::Value value)
+void TransposeOp::build(OpBuilder& builder, OperationState& state,
+                        Value value)
 {
     state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
     state.addOperands(value);
 }
 
-mlir::LogicalResult TransposeOp::verify()
+LogicalResult TransposeOp::verify()
 {
     auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
     auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
     if (!inputType || !resultType)
-        return mlir::success();
+        return success();
 
     auto inputShape = inputType.getShape();
     if (!std::equal(inputShape.begin(), inputShape.end(),
@@ -305,7 +379,43 @@ mlir::LogicalResult TransposeOp::verify()
         return emitError()
             << "expected result shape to be a transpose of the input";
     }
-    return mlir::success();
+    return success();
+}
+
+void TransposeOp::inferShapes()
+{
+    // Transpose reverses the shape of the tensor.
+    auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
+    // We assume that transpose only applies to matrices.
+    SmallVector<int64_t, 2> dimensions(llvm::reverse(tensorType.getShape()));
+    getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
+}
+
+// CastOp
+
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
+{
+    if (inputs.size() != 1 || outputs.size() != 1)
+    {
+        return false;
+    }
+
+    const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
+    const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
+
+    if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
+    {
+        return false;
+    }
+
+    // If both types are ranked, they must be identical; otherwise a ranked
+    // tensor can be cast to or from an unranked one.
+    return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
+}
+
+void CastOp::inferShapes()
+{
+    getResult().setType(getInput().getType());
 }

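To make the `areCastCompatible` rules above concrete, a small illustrative check (the `demoCastRules` name and setup are not part of the commit): a ranked tensor may be cast to or from an unranked tensor with the same element type, while two ranked tensors must match exactly.

    #include "Dialect.h"
    #include <mlir/IR/Builders.h>
    #include <mlir/IR/BuiltinTypes.h>
    #include <mlir/IR/MLIRContext.h>

    // Illustrative check of the cast-compatibility rules implemented above.
    bool demoCastRules(mlir::MLIRContext& context)
    {
        mlir::OpBuilder builder(&context);
        auto ranked = mlir::RankedTensorType::get({2, 3}, builder.getF64Type());
        auto reversed = mlir::RankedTensorType::get({3, 2}, builder.getF64Type());
        auto unranked = mlir::UnrankedTensorType::get(builder.getF64Type());

        // Ranked <-> unranked with a matching element type is accepted...
        bool toUnranked = mlir::hello::CastOp::areCastCompatible({ranked}, {unranked}); // true
        // ...but two different ranked shapes are rejected.
        bool toReversed = mlir::hello::CastOp::areCastCompatible({ranked}, {reversed}); // false
        return toUnranked && !toReversed;
    }
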
lib/MLIRGen.cpp
@@ -44,7 +44,9 @@ namespace
     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
     for (Function& f : moduleAST)
+    {
         mlirGen(f);
+    }
 
     // Verify the module after we have finished constructing it, this will check
     // the structural properties of the IR and invoke any specific verifiers we
@@ -160,6 +162,13 @@ namespace
         function.getFunctionType().getInputs(), getType(ValueType{})));
     }
 
+    // Set all functions except 'main' to private,
+    // so that the inliner can inline them and then discard them.
+    if (funcAST.getPrototype()->getName() != "main")
+    {
+        function.setPrivate();
+    }
+
     return function;
 }

lib/ShapeInferencePass.cpp (new file)
@@ -0,0 +1,95 @@
+//
+// Created by ricardo on 02/06/25.
+//
+
+#include <llvm/Support/Debug.h>
+#include <mlir/Pass/Pass.h>
+
+#include "Dialect.h"
+#include "Passes.h"
+
+
+namespace mlir::hello
+{
+#include "hello/ShapeInferenceInterface.cpp.inc"
+}
+
+
+#define DEBUG_TYPE "ShapeInference"
+
+namespace
+{
+    struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
+    {
+        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
+
+
+        void runOnOperation() override
+        {
+            mlir::hello::FuncOp operation = getOperation();
+            llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;
+
+            operation.walk([&](mlir::Operation* op)
+            {
+                if (isDynamicShapes(op))
+                {
+                    opWorkList.insert(op);
+                }
+            });
+
+            while (!opWorkList.empty())
+            {
+                auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
+                if (nextOperation == opWorkList.end())
+                {
+                    break;
+                }
+
+                mlir::Operation* op = *nextOperation;
+                opWorkList.erase(op);
+                LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+
+                if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
+                {
+                    shapeInference.inferShapes();
+                }
+                else
+                {
+                    op->emitError(
+                        std::string("Failed to infer shape for operation '") + op->getName().getIdentifier().str() +
+                        "' without a shape inference interface.");
+                    signalPassFailure();
+                    return;
+                }
+            }
+
+            if (!opWorkList.empty())
+            {
+                operation.emitError("Failed to infer shapes, ") << opWorkList.size() <<
+                    " operations could not be inferred.\n";
+                signalPassFailure();
+            }
+        }
+
+        static bool isOperationInferred(mlir::Operation* op)
+        {
+            return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
+            {
+                return llvm::isa<mlir::RankedTensorType>(operandType);
+            });
+        }
+
+        static bool isDynamicShapes(mlir::Operation* op)
+        {
+            return llvm::any_of(op->getResultTypes(), [](mlir::Type resultType)
+            {
+                return !llvm::isa<mlir::RankedTensorType>(resultType);
+            });
+        }
+    };
+}
+
+std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
+{
+    return std::make_unique<ShapeInferencePass>();
+}
main.cpp
@@ -12,6 +12,7 @@
 
 #include "Dialect.h"
 #include "MLIRGen.h"
+#include "Passes.h"
 
 namespace mlir
 {

@@ -119,7 +120,13 @@ int dumpMLIR()
         return 1;
     }
 
-    manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass());
+    // Inline all functions except the 'main' function.
+    manager.addPass(mlir::createInlinerPass());
+    // Per function: canonicalization (Transpose/Reshape rewrite patterns), CSE, shape inference.
+    mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
+    functionPassManager.addPass(mlir::createCanonicalizerPass());
+    functionPassManager.addPass(mlir::createCSEPass());
+    functionPassManager.addPass(mlir::hello::createShapeInferencePass());
 
     if (mlir::failed(manager.run(*module)))
     {