feat: toy tutorial chapter 5.

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
jackfiled 2025-06-05 23:46:52 +08:00
parent 902915a57b
commit c5ab1a6bc0
Signed by: jackfiled
GPG Key ID: DEF448811AE0286D
5 changed files with 393 additions and 18 deletions

View File

@ -41,6 +41,7 @@ add_library(HelloDialect STATIC
lib/MLIRGen.cpp
lib/HelloCombine.cpp
lib/ShapeInferencePass.cpp
lib/LowerToAffineLoopsPass.cpp
include/SyntaxNode.h
include/Parser.h
@ -52,8 +53,13 @@ add_library(HelloDialect STATIC
add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(HelloDialect
PRIVATE
${dialect_libs}
${extension_libs}
MLIRSupport
MLIRAnalysis
MLIRFunctionInterfaces

View File

@ -14,6 +14,8 @@ namespace mlir
namespace hello
{
std::unique_ptr<Pass> createShapeInferencePass();
std::unique_ptr<Pass> createLowerToAffineLoopsPass();
}
}

View File

@ -224,7 +224,7 @@ def PrintOp : Hello_Op<"print"> {
}];
// The print operation takes an input tensor to print.
let arguments = (ins F64Tensor:$input);
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
let assemblyFormat = "$input attr-dict `:` type($input)";
}

View File

@ -0,0 +1,346 @@
//
// Created by ricardo on 05/06/25.
//
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Transforms/DialectConversion.h>
#include "Passes.h"
#include "Dialect.h"
/// Build the memref type that mirrors a ranked tensor type:
/// the result has the same shape and the same element type as `tensorType`.
static mlir::MemRefType convertTensorTypeTOMemRefType(const mlir::RankedTensorType& tensorType)
{
    const auto shape = tensorType.getShape();
    const auto elementType = tensorType.getElementType();
    return mlir::MemRefType::get(shape, elementType);
}
/// Create a buffer of the given memref type together with its matching release.
///
/// The allocation is created at the rewriter's current insertion point and then
/// hoisted to the very beginning of its block; the deallocation is moved right
/// before the last operation of the same block.
///
/// @param type     The memref type of the buffer to allocate.
/// @param location The location attached to both created operations.
/// @param rewriter The rewriter used to create the operations.
/// @return The value of the allocated buffer.
static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
                                         mlir::PatternRewriter& rewriter)
{
    auto alloc = rewriter.create<mlir::memref::AllocOp>(location, type);
    auto* const enclosingBlock = alloc->getBlock();

    // Hoist the allocation in front of the first operation of the block.
    alloc->moveBefore(&enclosingBlock->front());

    auto dealloc = rewriter.create<mlir::memref::DeallocOp>(location, alloc);

    // Sink the release in front of the last operation of the block,
    // which is presumably the return operation of the function.
    dealloc->moveBefore(&enclosingBlock->back());

    return alloc;
}
/// The function type used to process one iteration of a lowered loop nest.
/// `rewriter` is the builder used to create operations inside the loop body.
/// `memRefOperands` are the operands of the operation being lowered.
/// `loopIterators` are the induction variables of the current iteration.
/// The return value is the value to store at the current index of the iteration.
using LoopIterationFunction = mlir::function_ref<mlir::Value(mlir::OpBuilder& rewriter, mlir::ValueRange memRefOperands,
mlir::ValueRange loopIterators)>;
/// Lower an operation with a ranked tensor result into an affine loop nest.
///
/// A buffer matching the first result type of `op` is allocated, one loop is
/// built per dimension of the result shape (from 0 up to the dimension size,
/// step 1), and inside the innermost body the value returned by `function` is
/// stored into the buffer at the current induction variables. Finally `op` is
/// replaced by the buffer.
///
/// @param op       The operation to lower.
/// @param operands The operands forwarded to `function`.
/// @param rewriter The rewriter used to create and replace operations.
/// @param function Callback producing the value to store for one iteration.
static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands, mlir::PatternRewriter& rewriter,
                                  LoopIterationFunction function)
{
    const auto location = op->getLoc();
    const auto resultTensorType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());

    // The buffer which receives the result of the lowered operation.
    mlir::Value resultBuffer =
        insertAllocAndDealloc(convertTensorTypeTOMemRefType(resultTensorType), location, rewriter);

    // Every loop starts at zero and advances by one.
    const auto rank = resultTensorType.getRank();
    llvm::SmallVector<int64_t, 4> lowerBounds(rank, 0);
    llvm::SmallVector<int64_t, 4> steps(rank, 1);

    // Body of the innermost loop: compute one element and store it.
    auto emitLoopBody = [&](mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation,
                            mlir::ValueRange iterators)
    {
        mlir::Value element = function(nestedBuilder, operands, iterators);
        nestedBuilder.create<mlir::affine::AffineStoreOp>(nestedLocation, element, resultBuffer, iterators);
    };

    // The upper bounds of the loops are the dimensions of the result shape.
    mlir::affine::buildAffineLoopNest(rewriter, location, lowerBounds, resultTensorType.getShape(), steps,
                                      emitLoopBody);

    // The buffer now stands in for the result of the original operation.
    rewriter.replaceOp(op, resultBuffer);
}
namespace
{
    /// Generic pattern lowering an element-wise hello binary operation to loops.
    ///
    /// `BinaryOperation` is the hello operation matched by the pattern and
    /// `LoweredBinaryOperation` is the operation created for every element.
    template <typename BinaryOperation, typename LoweredBinaryOperation>
    struct BinaryOperationLoweringPattern : mlir::ConversionPattern
    {
        /// Register the pattern on the operation name of `BinaryOperation` with benefit 1.
        explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context)
            : ConversionPattern(BinaryOperation::getOperationName(), 1, context)
        {
        }

        /// Replace `op` by a loop nest which loads both operands at the current
        /// indices and combines them with `LoweredBinaryOperation`.
        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            const mlir::Location location = op->getLoc();

            auto computeElement = [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                                             mlir::ValueRange loopIterators)
            {
                // The generated adaptor gives named access (lhs/rhs) to the operand range.
                typename BinaryOperation::Adaptor adaptor(memRefOperands);

                auto lhs = builder.create<mlir::affine::AffineLoadOp>(location, adaptor.getLhs(), loopIterators);
                auto rhs = builder.create<mlir::affine::AffineLoadOp>(location, adaptor.getRhs(), loopIterators);

                return builder.create<LoweredBinaryOperation>(location, lhs, rhs);
            };

            lowerOperationToLoops(op, operands, rewriter, computeElement);
            return mlir::success();
        }
    };
}
/// Lowering of hello::AddOp: each element is computed with arith::AddFOp.
using AddOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;
/// Lowering of hello::MulOp: each element is computed with arith::MulFOp.
using MulOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;
namespace
{
/// Lower a hello constant into a buffer filled by explicit stores.
///
/// A buffer with the shape of the constant is allocated and every element of
/// the constant is written to it with one affine store per element.
struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
{
    /// Inherit the constructors of the base pattern.
    using OpRewritePattern::OpRewritePattern;

    mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
    {
        const mlir::Location location = op->getLoc();
        mlir::DenseElementsAttr elements = op.getValue();

        // The buffer which will hold the constant values.
        const auto bufferType = convertTensorTypeTOMemRefType(llvm::cast<mlir::RankedTensorType>(op.getType()));
        mlir::Value buffer = insertAllocAndDealloc(bufferType, location, rewriter);

        // Create each index constant only once, up to the largest dimension, and
        // share them between all stores. A rank 0 tensor is a single number and
        // still needs the index 0.
        const auto shape = bufferType.getShape();
        const int64_t indexCount = shape.empty() ? 1 : *llvm::max_element(shape);
        llvm::SmallVector<mlir::Value, 8> indexConstants;
        for (const auto i : llvm::seq<int64_t>(0, indexCount))
        {
            indexConstants.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, i));
        }

        // Indices of the element currently being stored.
        llvm::SmallVector<mlir::Value, 2> currentIndices;
        // Walks over the constant values in the order in which they are stored.
        auto nextElement = elements.value_begin<mlir::FloatAttr>();

        // Recursively enumerate the indices, one dimension per recursion level.
        std::function<void(uint64_t)> storeAt = [&](const uint64_t depth)
        {
            if (depth != shape.size())
            {
                // Not all dimensions are fixed yet: iterate over this one.
                for (const auto i : llvm::seq<int64_t>(0, shape[depth]))
                {
                    currentIndices.push_back(indexConstants[i]);
                    storeAt(depth + 1);
                    currentIndices.pop_back();
                }
                return;
            }

            // All dimensions are fixed: store the next value at these indices.
            auto scalar = rewriter.create<mlir::arith::ConstantOp>(location, *nextElement);
            rewriter.create<mlir::affine::AffineStoreOp>(location, scalar, buffer, llvm::ArrayRef(currentIndices));
            ++nextElement;
        };

        // Start the recursion from the first dimension.
        storeAt(0);

        rewriter.replaceOp(op, buffer);
        return mlir::success();
    }
};
/// Lower hello::FuncOp to func::FuncOp.
///
/// Only the function named "main" is lowered, and it must take no arguments
/// and return no results; any other function makes the pattern fail.
struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
{
    using OpConversionPattern::OpConversionPattern;

    mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        // Other functions are assumed to have been inlined before this lowering.
        const bool isMain = op.getName() == "main";
        if (!isMain)
        {
            return mlir::failure();
        }

        // The main function must have an empty signature.
        const bool hasEmptySignature = op.getNumArguments() == 0 && op.getNumResults() == 0;
        if (!hasEmptySignature)
        {
            return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
            {
                diagnostic << "Expect the 'main' function to have no arguments and no result.";
            });
        }

        // Create the replacement function and move the body of the original into it.
        auto lowered = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
        rewriter.inlineRegionBefore(op.getRegion(), lowered.getBody(), lowered.end());

        rewriter.eraseOp(op);
        return mlir::success();
    }
};
/// Lower hello::PrintOp by keeping the operation and only swapping its
/// operands for the already converted ones supplied by the adaptor.
struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
{
    using OpConversionPattern::OpConversionPattern;

    mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        auto convertedOperands = adaptor.getOperands();

        // The operation is updated in place, so the rewriter has to be notified.
        rewriter.modifyOpInPlace(op, [&]
        {
            op->setOperands(convertedOperands);
        });

        return mlir::success();
    }
};
/// Lower hello::ReturnOp to func::ReturnOp.
///
/// Only a return without operand is lowered; a return carrying a value makes
/// the pattern fail.
struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
{
    using OpRewritePattern::OpRewritePattern;

    mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
    {
        // After inlining, the only remaining return is expected to be the one
        // of the main function, which returns nothing.
        if (!op.hasOperand())
        {
            rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
            return mlir::success();
        }

        return mlir::failure();
    }
};
/// Lower hello::TransposeOp to an affine loop nest which loads the input
/// with the loop indices in reversed order.
struct TransposeOpLoweringPattern : mlir::ConversionPattern
{
    /// Register the pattern on the transpose operation name with benefit 1.
    explicit TransposeOpLoweringPattern(mlir::MLIRContext* context)
        : ConversionPattern(mlir::hello::TransposeOp::getOperationName(), 1, context)
    {
    }

    mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        const mlir::Location location = op->getLoc();

        auto loadTransposedElement = [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                                                mlir::ValueRange loopIterators)
        {
            mlir::hello::TransposeOpAdaptor adaptor(memRefOperands);

            // Reading at the reversed indices performs the transposition: (x, y) -> (y, x).
            llvm::SmallVector<mlir::Value, 2> swappedIndices(llvm::reverse(loopIterators));
            mlir::Value source = adaptor.getInput();
            return builder.create<mlir::affine::AffineLoadOp>(location, source, swappedIndices);
        };

        lowerOperationToLoops(op, operands, rewriter, loadTransposedElement);
        return mlir::success();
    }
};
}
namespace
{
    /// Module pass lowering the hello dialect to the affine, func, arith and
    /// memref dialects.
    struct HelloToAffineLoopsLoweringPass
        : mlir::PassWrapper<HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
    {
        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);

        /// The name of the pass as used by command line options.
        llvm::StringRef getArgument() const override
        {
            return "hello-to-affine-loops";
        }

        /// Declare the dialects whose operations are created by this pass.
        void getDependentDialects(mlir::DialectRegistry& registry) const override
        {
            registry.insert<mlir::affine::AffineDialect>();
            registry.insert<mlir::func::FuncDialect>();
            registry.insert<mlir::arith::ArithDialect>();
            registry.insert<mlir::memref::MemRefDialect>();
        }

        /// Run the lowering on the module; defined out of line.
        void runOnOperation() final;
    };
}
/// Run the partial conversion from the hello dialect to affine loops.
///
/// The affine, func, arith and memref dialects are legal targets and the hello
/// dialect is illegal, with the exception of hello::PrintOp which stays legal
/// as long as none of its operands has a tensor type. The pass is marked as
/// failed when the partial conversion fails.
void HelloToAffineLoopsLoweringPass::runOnOperation()
{
    mlir::ConversionTarget target(getContext());

    // Only allow the affine, func, arith and memref dialects.
    target.addLegalDialect<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
                           mlir::memref::MemRefDialect>();

    // Every operation of the hello dialect has to be converted...
    target.addIllegalDialect<mlir::hello::HelloDialect>();

    // ...except the print operation, which is kept by this pass and only needs
    // its operands to be converted.
    target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
    {
        // The operation is legal once none of its operand types is a tensor.
        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
    });

    // Register every lowering pattern exactly once.
    mlir::RewritePatternSet set(&getContext());
    set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
            PrintOpLoweringPattern, FuncOpLoweringPattern, ReturnOpLoweringPattern, TransposeOpLoweringPattern>(
        &getContext());

    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
    {
        signalPassFailure();
    }
}
namespace mlir::hello
{
std::unique_ptr<Pass> createLowerToAffineLoopsPass()
{
return std::make_unique<HelloToAffineLoopsLoweringPass>();
}
}

View File

@ -9,6 +9,9 @@
#include <mlir/Parser/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include <mlir/Dialect/Func/Extensions/AllExtensions.h>
#include <mlir/Dialect/Affine/Passes.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include "Dialect.h"
#include "MLIRGen.h"
@ -26,12 +29,12 @@ static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
namespace
{
enum Action { None, DumpSyntaxNode, DumpMLIR };
enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR };
enum InputType { Hello, MLIR };
}
/// The input file type
/// The input file type.
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
llvm::cl::desc("Decided the kind of input desired."),
llvm::cl::values(
@ -39,18 +42,21 @@ static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
llvm::cl::values(
clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
/// What is the action the compiler will do
/// What is the action the compiler will do.
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")),
llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir",
"Dump mlir code after lowering to affine loops")));
/// Whether to enable the optimization.
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
{
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError())
if (const std::error_code ec = fileOrErr.getError())
{
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return nullptr;
@ -100,7 +106,10 @@ int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::O
int dumpMLIR()
{
mlir::MLIRContext context;
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context(registry);
context.getOrLoadDialect<mlir::hello::HelloDialect>();
mlir::OwningOpRef<mlir::ModuleOp> module;
@ -111,8 +120,7 @@ int dumpMLIR()
return error;
}
if (enableOpt)
{
const bool isLoweringToAffine = emitAction >= DumpAffineMLIR;
mlir::PassManager manager(module.get()->getName());
if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
@ -120,6 +128,8 @@ int dumpMLIR()
return 1;
}
if (enableOpt || isLoweringToAffine)
{
// To inline all functions except 'main' function.
manager.addPass(mlir::createInlinerPass());
// In the canonicalizer pass, we add Transpose Pass and Reshape Pass.
@ -127,16 +137,26 @@ int dumpMLIR()
functionPassManager.addPass(mlir::createCanonicalizerPass());
functionPassManager.addPass(mlir::createCSEPass());
functionPassManager.addPass(mlir::hello::createShapeInferencePass());
}
if (isLoweringToAffine)
{
manager.addPass(mlir::hello::createLowerToAffineLoopsPass());
mlir::OpPassManager& functionPassManager = manager.nest<mlir::func::FuncOp>();
// Add some optimization from the affine dialect.
if (enableOpt)
{
manager.addPass(mlir::affine::createLoopFusionPass());
functionPassManager.addPass(mlir::affine::createAffineScalarReplacementPass());
}
}
if (mlir::failed(manager.run(*module)))
{
return 1;
}
module->print(llvm::outs());
return 0;
}
module->print(llvm::outs());
return 0;
}
@ -172,6 +192,7 @@ int main(int argc, char** argv)
case DumpSyntaxNode:
return dumpSyntaxNode();
case DumpMLIR:
case DumpAffineMLIR:
return dumpMLIR();
default:
llvm::errs() << "Unrecognized action\n";