feat: toy tutorial chapter 4 (function inlining and shape inference)

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
2025-06-03 16:03:17 +08:00
parent eacf20fe3c
commit 902915a57b
12 changed files with 380 additions and 68 deletions

View File

@@ -6,6 +6,49 @@
#include "hello/Dialect.cpp.inc"
#include <mlir/Interfaces/FunctionImplementation.h>
+#include <mlir/Transforms/InliningUtils.h>
+struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
+{
+using DialectInlinerInterface::DialectInlinerInterface;
+bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
+{
+return true;
+}
+bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
+mlir::IRMapping& valueMapping) const override
+{
+return true;
+}
+bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
+mlir::IRMapping& valueMapping) const override
+{
+return true;
+}
+void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
+{
+// Only hello.return terminates a function body, so replace the values the
+// call produced with the operands of the inlined return.
+auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);
+assert(returnOp.getNumOperands() == returnValues.size());
+for (const auto& it : llvm::enumerate(returnOp.getOperands()))
+{
+returnValues[it.index()].replaceAllUsesWith(it.value());
+}
+}
+mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
+mlir::Location conversionLoc) const override
+{
+return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
+}
+};
void mlir::hello::HelloDialect::initialize()
{
@@ -13,6 +56,7 @@ void mlir::hello::HelloDialect::initialize()
#define GET_OP_LIST
#include "hello/Ops.cpp.inc"
>();
+addInterfaces<HelloDialectInlinerInterface>();
}
using namespace mlir;
@@ -25,43 +69,43 @@ static ParseResult parseBinaryOp(OpAsmParser& parser,
OperationState& result)
{
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-llvm::SMLoc operandsLoc = parser.getCurrentLocation();
-mlir::Type type;
+SMLoc operandsLoc = parser.getCurrentLocation();
+Type type;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
-return mlir::failure();
+return failure();
// If the type is a function type, it contains the input and result types of
// this operation.
-if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type))
+if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
{
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
-return mlir::failure();
+return failure();
result.addTypes(funcType.getResults());
-return mlir::success();
+return success();
}
// Otherwise, the parsed type is the type of both operands and results.
if (parser.resolveOperands(operands, type, result.operands))
-return mlir::failure();
+return failure();
result.addTypes(type);
-return mlir::success();
+return success();
}
/// A generalized printer for binary operations. It prints in two different
/// forms depending on if all of the types match.
-static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
+static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
{
printer << " " << op->getOperands();
printer.printOptionalAttrDict(op->getAttrs());
printer << " : ";
// If all of the types are the same, print the type directly.
-mlir::Type resultType = *op->result_type_begin();
+Type resultType = *op->result_type_begin();
if (llvm::all_of(op->getOperandTypes(),
-[=](mlir::Type type) { return type == resultType; }))
+[=](Type type) { return type == resultType; }))
{
printer << resultType;
return;
@@ -78,8 +122,8 @@ static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
-void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
-double value)
+void ConstantOp::build(OpBuilder& builder, OperationState& state,
+double value)
{
auto dataType = RankedTensorType::get({}, builder.getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
@@ -93,10 +137,10 @@ void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
-mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
-mlir::OperationState& result)
+ParseResult ConstantOp::parse(OpAsmParser& parser,
+OperationState& result)
{
-mlir::DenseElementsAttr value;
+DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
return failure();
@@ -107,7 +151,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
-void ConstantOp::print(mlir::OpAsmPrinter& printer)
+void ConstantOp::print(OpAsmPrinter& printer)
{
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
@@ -116,17 +160,17 @@ void ConstantOp::print(mlir::OpAsmPrinter& printer)
/// Verifier for the constant operation. This corresponds to the
/// `let hasVerifier = 1` in the op definition.
-mlir::LogicalResult ConstantOp::verify()
+LogicalResult ConstantOp::verify()
{
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
-auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
+auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
-auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
+auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank())
{
return emitOpError("return type must match the one of the attached value "
@@ -145,57 +189,82 @@ mlir::LogicalResult ConstantOp::verify()
<< " != " << resultType.getShape()[dim];
}
}
-return mlir::success();
+return success();
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
-void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-mlir::Value lhs, mlir::Value rhs)
+void AddOp::build(OpBuilder& builder, OperationState& state,
+Value lhs, Value rhs)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs});
}
-mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser,
-mlir::OperationState& result)
+ParseResult AddOp::parse(OpAsmParser& parser,
+OperationState& result)
{
return parseBinaryOp(parser, result);
}
-void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void AddOp::inferShapes()
+{
+getResult().setType(getLhs().getType());
+}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//
-void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-StringRef callee, ArrayRef<mlir::Value> arguments)
+void GenericCallOp::build(OpBuilder& builder, OperationState& state,
+StringRef callee, ArrayRef<Value> arguments)
{
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee",
-mlir::SymbolRefAttr::get(builder.getContext(), callee));
+SymbolRefAttr::get(builder.getContext(), callee));
}
+CallInterfaceCallable GenericCallOp::getCallableForCallee()
+{
+return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+}
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
+{
+(*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
+}
+Operation::operand_range GenericCallOp::getArgOperands()
+{
+return getInputs();
+}
+MutableOperandRange GenericCallOp::getArgOperandsMutable()
+{
+return getInputsMutable();
+}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
-void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-llvm::StringRef name, mlir::FunctionType type,
-llvm::ArrayRef<mlir::NamedAttribute> attrs)
+void FuncOp::build(OpBuilder& builder, OperationState& state,
+StringRef name, FunctionType type,
+ArrayRef<NamedAttribute> attrs)
{
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
-mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
-OperationState& result)
+ParseResult FuncOp::parse(OpAsmParser& parser,
+OperationState& result)
{
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
@@ -208,17 +277,17 @@ mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
return builder.getFunctionType(argTypes, results);
};
-return mlir::function_interface_impl::parseFunctionOp(
+return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
-void FuncOp::print(mlir::OpAsmPrinter& p)
+void FuncOp::print(OpAsmPrinter& p)
{
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
-mlir::function_interface_impl::printFunctionOp(
+function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}
@@ -227,26 +296,31 @@ void FuncOp::print(mlir::OpAsmPrinter& p)
// MulOp
//===----------------------------------------------------------------------===//
-void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-mlir::Value lhs, mlir::Value rhs)
+void MulOp::build(OpBuilder& builder, OperationState& state,
+Value lhs, Value rhs)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands({lhs, rhs});
}
-mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser,
-mlir::OperationState& result)
+ParseResult MulOp::parse(OpAsmParser& parser,
+OperationState& result)
{
return parseBinaryOp(parser, result);
}
-void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
+void MulOp::inferShapes()
+{
+getResult().setType(getLhs().getType());
+}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
-mlir::LogicalResult ReturnOp::verify()
+LogicalResult ReturnOp::verify()
{
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@@ -265,15 +339,15 @@ mlir::LogicalResult ReturnOp::verify()
// If the operation does not have an input, we are done.
if (!hasOperand())
-return mlir::success();
+return success();
auto inputType = *operand_type_begin();
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
-if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
-llvm::isa<mlir::UnrankedTensorType>(resultType))
-return mlir::success();
+if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
+llvm::isa<UnrankedTensorType>(resultType))
+return success();
return emitError() << "type of return operand (" << inputType
<< ") doesn't match function result type (" << resultType
@@ -284,19 +358,19 @@ mlir::LogicalResult ReturnOp::verify()
// TransposeOp
//===----------------------------------------------------------------------===//
-void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
-mlir::Value value)
+void TransposeOp::build(OpBuilder& builder, OperationState& state,
+Value value)
{
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(value);
}
-mlir::LogicalResult TransposeOp::verify()
+LogicalResult TransposeOp::verify()
{
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
-return mlir::success();
+return success();
auto inputShape = inputType.getShape();
if (!std::equal(inputShape.begin(), inputShape.end(),
@@ -305,7 +379,43 @@ mlir::LogicalResult TransposeOp::verify()
return emitError()
<< "expected result shape to be a transpose of the input";
}
-return mlir::success();
+return success();
}
+void TransposeOp::inferShapes()
+{
+// Transposing reverses the shape of the tensor.
+auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
+// Assume transpose is only applied to matrices (rank-2 tensors).
+SmallVector<int64_t, 2> dimensions(llvm::reverse(tensorType.getShape()));
+getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
+}
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
+{
+if (inputs.size() != 1 || outputs.size() != 1)
+{
+return false;
+}
+const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
+const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
+if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
+{
+return false;
+}
+// If both types are ranked, they must be identical; otherwise one side is
+// unranked, and a known shape may be cast to or from an unknown one.
+return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
+}
+void CastOp::inferShapes()
+{
+getResult().setType(getInput().getType());
+}
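
For reference, the cast-compatibility rules above can be exercised in isolation. A minimal sketch, not part of this commit (the "Dialect.h" include and the standalone main harness are assumptions for illustration):

#include <cassert>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>
#include "Dialect.h"

int main()
{
    mlir::MLIRContext context;
    // Builtin tensor types are always available on a fresh context.
    auto f64 = mlir::Float64Type::get(&context);
    mlir::Type ranked = mlir::RankedTensorType::get({2, 3}, f64);
    mlir::Type unranked = mlir::UnrankedTensorType::get(f64);
    llvm::ArrayRef<mlir::Type> in = ranked, out = unranked;
    // A ranked tensor casts to an unranked tensor of the same element type.
    assert(mlir::hello::CastOp::areCastCompatible(in, out));
    // Two ranked tensors with different shapes are not cast-compatible.
    mlir::Type reshaped = mlir::RankedTensorType::get({3, 2}, f64);
    llvm::ArrayRef<mlir::Type> other = reshaped;
    assert(!mlir::hello::CastOp::areCastCompatible(in, other));
}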

View File

@@ -44,7 +44,9 @@ namespace
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (Function& f : moduleAST)
+{
mlirGen(f);
+}
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@@ -160,6 +162,13 @@ namespace
function.getFunctionType().getInputs(), getType(ValueType{})));
}
+// Set every function except 'main' to private, so that the inliner can
+// inline these functions and then discard their unused definitions.
+if (funcAST.getPrototype()->getName() != "main")
+{
+function.setPrivate();
+}
return function;
}
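
Marking every function except main as private is what lets inlining pay off: once all call sites have been inlined, a private hello function has no remaining uses and can be erased, while main stays externally visible. A sketch of that cleanup, for illustration only (MLIR's inliner already performs the equivalent symbol-based dead code elimination; eraseDeadPrivateFunctions is a hypothetical helper, not part of this commit):

#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/SymbolTable.h>
#include "Dialect.h"

static void eraseDeadPrivateFunctions(mlir::ModuleOp module)
{
    llvm::SmallVector<mlir::hello::FuncOp> deadFunctions;
    module.walk([&](mlir::hello::FuncOp function)
    {
        // A private symbol with no known uses left in the module is dead.
        if (function.isPrivate() &&
            mlir::SymbolTable::symbolKnownUseEmpty(function, module))
        {
            deadFunctions.push_back(function);
        }
    });
    // Erase outside the walk so the traversal is not invalidated.
    for (mlir::hello::FuncOp function : deadFunctions)
    {
        function.erase();
    }
}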

View File

@@ -0,0 +1,95 @@
//
// Created by ricardo on 02/06/25.
//
#include <llvm/Support/Debug.h>
#include <mlir/Pass/Pass.h>
#include "Dialect.h"
#include "Passes.h"
namespace mlir::hello
{
#include "hello/ShapeInferenceInterface.cpp.inc"
}
#define DEBUG_TYPE "ShapeInference"
namespace
{
struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
{
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
void runOnOperation() override
{
mlir::hello::FuncOp operation = getOperation();
llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;
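// Collect all operations that still produce dynamically shaped (unranked) results.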
operation.walk([&](mlir::Operation* op)
{
if (isDynamicShapes(op))
{
opWorkList.insert(op);
}
});
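// Fixpoint loop: repeatedly pick an operation whose operands all have ranked
// types, infer its result shape, and stop once the work list empties or no
// candidate remains.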
while (!opWorkList.empty())
{
auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
if (nextOperation == opWorkList.end())
{
break;
}
mlir::Operation* op = *nextOperation;
opWorkList.erase(op);
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
{
shapeInference.inferShapes();
}
else
{
op->emitError(
std::string("Failed to infer shape for operation '") + op->getName().getIdentifier().str() +
"' that does not implement the shape inference interface.");
signalPassFailure();
return;
}
}
if (!opWorkList.empty())
{
operation.emitError("Failed to infer shapes, ") << opWorkList.size() <<
" operations could not be inferred.";
signalPassFailure();
}
}
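// Returns true when every operand already has a ranked tensor type, i.e. the
// operation is ready to have its own result shape inferred.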
static bool isOperationInferred(mlir::Operation* op)
{
return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
{
return llvm::isa<mlir::RankedTensorType>(operandType);
});
}
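// Returns true when any result type is still unranked, i.e. the operation
// still needs shape inference.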
static bool isDynamicShapes(mlir::Operation* op)
{
return llvm::any_of(op->getResultTypes(), [](mlir::Type operandType)
{
return !llvm::isa<mlir::RankedTensorType>(operandType);
});
}
};
}
std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
{
return std::make_unique<ShapeInferencePass>();
}
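
For context, a chapter-4 style driver would wire the inliner and this pass together roughly as follows. This is a sketch rather than part of the commit: runChapter4Pipeline is a hypothetical name, the generic passes come from mlir/Transforms/Passes.h, and createShapeInferencePass is assumed to be declared in the Passes.h included above.

#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include "Dialect.h"
#include "Passes.h"

mlir::LogicalResult runChapter4Pipeline(mlir::ModuleOp module)
{
    mlir::PassManager pm(module.getContext());
    // Inline everything into main. The inliner consults
    // HelloDialectInlinerInterface for legality and inserts hello.cast ops
    // through materializeCallConversion when argument types disagree.
    pm.addPass(mlir::createInlinerPass());
    // Then infer shapes function by function; canonicalization and CSE can
    // clean up casts that inference has made redundant.
    mlir::OpPassManager& functionPM = pm.nest<mlir::hello::FuncOp>();
    functionPM.addPass(mlir::hello::createShapeInferencePass());
    functionPM.addPass(mlir::createCanonicalizerPass());
    functionPM.addPass(mlir::createCSEPass());
    return pm.run(module);
}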