From c5ab1a6bc0268b6288e45a51287dd47cf9525864 Mon Sep 17 00:00:00 2001
From: jackfiled
Date: Thu, 5 Jun 2025 23:46:52 +0800
Subject: [PATCH] feat: toy tutorial chapter 5.

Signed-off-by: jackfiled
---
 CMakeLists.txt                 |   6 +
 include/Passes.h               |   2 +
 include/hello/Ops.td           |   2 +-
 lib/LowerToAffineLoopsPass.cpp | 346 +++++++++++++++++++++++++++++++++
 main.cpp                       |  55 ++++--
 5 files changed, 393 insertions(+), 18 deletions(-)
 create mode 100644 lib/LowerToAffineLoopsPass.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 754304f..44d929f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -41,6 +41,7 @@ add_library(HelloDialect STATIC
         lib/MLIRGen.cpp
         lib/HelloCombine.cpp
         lib/ShapeInferencePass.cpp
+        lib/LowerToAffineLoopsPass.cpp
 
         include/SyntaxNode.h
         include/Parser.h
@@ -52,8 +53,13 @@ add_library(HelloDialect STATIC
 
 add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
 
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
 target_link_libraries(HelloDialect
         PRIVATE
+        ${dialect_libs}
+        ${extension_libs}
         MLIRSupport
         MLIRAnalysis
         MLIRFunctionInterfaces
diff --git a/include/Passes.h b/include/Passes.h
index a92cbdb..b10a43c 100644
--- a/include/Passes.h
+++ b/include/Passes.h
@@ -14,6 +14,8 @@ namespace mlir
     namespace hello
     {
         std::unique_ptr<mlir::Pass> createShapeInferencePass();
+
+        std::unique_ptr<mlir::Pass> createLowerToAffineLoopsPass();
     }
 }
 
diff --git a/include/hello/Ops.td b/include/hello/Ops.td
index 77b3dc5..1843e4a 100644
--- a/include/hello/Ops.td
+++ b/include/hello/Ops.td
@@ -224,7 +224,7 @@ def PrintOp : Hello_Op<"print"> {
   }];
 
   // The print operation takes an input tensor to print.
-  let arguments = (ins F64Tensor:$input);
+  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
 
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }
diff --git a/lib/LowerToAffineLoopsPass.cpp b/lib/LowerToAffineLoopsPass.cpp
new file mode 100644
index 0000000..c591243
--- /dev/null
+++ b/lib/LowerToAffineLoopsPass.cpp
@@ -0,0 +1,346 @@
+//
+// Created by ricardo on 05/06/25.
+//
+
+#include <llvm/ADT/Sequence.h>
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+#include "Passes.h"
+#include "Dialect.h"
+
+static mlir::MemRefType convertTensorTypeToMemRefType(const mlir::RankedTensorType& tensorType)
+{
+    return mlir::MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+}
+
+static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
+                                         mlir::PatternRewriter& rewriter)
+{
+    auto allocateOperation = rewriter.create<mlir::memref::AllocOp>(location, type);
+    auto* parentBlock = allocateOperation->getBlock();
+
+    // Move the allocate operation before the first operation in this block.
+    allocateOperation->moveBefore(&parentBlock->front());
+
+    auto deallocateOperation = rewriter.create<mlir::memref::DeallocOp>(location, allocateOperation);
+    // Move the release operation before the last operation in this block, as the last operation
+    // usually is the return operation.
+    deallocateOperation->moveBefore(&parentBlock->back());
+
+    return allocateOperation;
+}
+
+/// The function type used to process one iteration of a lowered loop.
+/// The `memRefOperands` is the operands of the input operation.
+/// The `loopIterators` is the induction variables of the iteration.
+/// The return value is the value to store at the current index of the iteration.
+using LoopIterationFunction = mlir::function_ref<mlir::Value(mlir::OpBuilder& builder,
+                                                             mlir::ValueRange memRefOperands,
+                                                             mlir::ValueRange loopIterators)>;
+
+/// Helper function to lower an operation to affine loops.
+///
+static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands, mlir::PatternRewriter& rewriter,
+                                  LoopIterationFunction function)
+{
+    auto tensorType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());
+    auto location = op->getLoc();
+
+    auto memRefType = convertTensorTypeToMemRefType(tensorType);
+    auto allocation = insertAllocAndDealloc(memRefType, location, rewriter);
+
+    // Create a loop for every dimension of the shape.
+    // This vector stores the lower bound of each loop, which is always zero.
+    llvm::SmallVector<int64_t> lowerBounds(tensorType.getRank(), 0);
+    // This vector stores the step of each loop, which is always one.
+    llvm::SmallVector<int64_t> steps(tensorType.getRank(), 1);
+
+    mlir::affine::buildAffineLoopNest(rewriter, location, lowerBounds, tensorType.getShape(), steps,
+                                      [&](mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation,
+                                          mlir::ValueRange iterators)
+                                      {
+                                          mlir::Value storedValue = function(nestedBuilder, operands, iterators);
+                                          nestedBuilder.create<mlir::affine::AffineStoreOp>(
+                                              nestedLocation, storedValue, allocation, iterators);
+                                      });
+
+    // Use the created buffer to replace the operation.
+    rewriter.replaceOp(op, allocation);
+}
+
+namespace
+{
+    /// A generic pattern to convert a hello binary operation into an arith operation.
+    template <typename BinaryOperation, typename LoweredBinaryOperation>
+    struct BinaryOperationLoweringPattern : mlir::ConversionPattern
+    {
+        explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
+            BinaryOperation::getOperationName(), 1, context)
+        {
+        }
+
+        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
+                                            mlir::ConversionPatternRewriter& rewriter) const final
+        {
+            auto location = op->getLoc();
+            lowerOperationToLoops(op, operands, rewriter, [location](mlir::OpBuilder& builder,
+                                                                     mlir::ValueRange memRefOperands,
+                                                                     mlir::ValueRange loopIterators)
+            {
+                // The adaptor is generated automatically by the framework.
+                // This adaptor provides the named accessors for the remapped memref operands.
+                typename BinaryOperation::Adaptor binaryOperationAdaptor(memRefOperands);
+
+                auto loadLeftOperand = builder.create<mlir::affine::AffineLoadOp>(
+                    location, binaryOperationAdaptor.getLhs(), loopIterators);
+                auto loadRightOperand = builder.create<mlir::affine::AffineLoadOp>(
+                    location, binaryOperationAdaptor.getRhs(), loopIterators);
+
+                return builder.create<LoweredBinaryOperation>(location, loadLeftOperand, loadRightOperand);
+            });
+
+            return mlir::success();
+        }
+    };
+}
+
+using AddOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;
+using MulOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;
+
+namespace
+{
+    /// Lower the constant values into the buffer.
+    struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
+    {
+        /// Constructor.
+        using OpRewritePattern::OpRewritePattern;
+
+        mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
+        {
+            mlir::DenseElementsAttr constantValues = op.getValue();
+            mlir::Location location = op->getLoc();
+
+            // Allocate a buffer for these constant values.
+            mlir::RankedTensorType tensorType = llvm::cast<mlir::RankedTensorType>(op.getType());
+            mlir::MemRefType memRefTypes = convertTensorTypeToMemRefType(tensorType);
+            mlir::Value allocationBuffer = insertAllocAndDealloc(memRefTypes, location, rewriter);
+
+            // Generate the constant indices up to the largest dimension,
+            // so that we can avoid creating a large amount of redundant index operations.
+            auto valueShape = memRefTypes.getShape();
+            llvm::SmallVector<mlir::Value> constantIndices;
+
+            if (!valueShape.empty())
+            {
+                for (const auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape)))
+                {
+                    constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, i));
+                }
+            }
+            else
+            {
+                // If the value shape is empty, the tensor has rank 0, so this is just a single number.
+                constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, 0));
+            }
+
+            // Then store the constant values into the buffer.
+            // Define a recursive function to store them.
+
+            // Vector to store the indices.
+            llvm::SmallVector<mlir::Value> indices;
+            // Iterator to iterate over the constant values.
+            auto valueIterator = constantValues.value_begin<mlir::FloatAttr>();
+
+            std::function<void(uint64_t)> storeElementFunc = [&](const uint64_t dimension)
+            {
+                // When the dimension reaches the size of the shape, we reach the end of the recursion,
+                // where we store the value.
+                if (dimension == valueShape.size())
+                {
+                    rewriter.create<mlir::affine::AffineStoreOp>(
+                        location, rewriter.create<mlir::arith::ConstantOp>(location, *valueIterator),
+                        allocationBuffer, llvm::ArrayRef(indices));
+                    ++valueIterator;
+                    return;
+                }
+
+                // Otherwise, build the indices of this dimension and recurse into the next one.
+                for (const auto i : llvm::seq<int64_t>(0, valueShape[dimension]))
+                {
+                    indices.push_back(constantIndices[i]);
+                    storeElementFunc(dimension + 1);
+                    indices.pop_back();
+                }
+            };
+
+            // Start the recursion from dimension 0.
+            storeElementFunc(0);
+
+            rewriter.replaceOp(op, allocationBuffer);
+            return mlir::success();
+        }
+    };
+
+    /// Lower hello::func to func::func.
+    struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
+    {
+        using OpConversionPattern::OpConversionPattern;
+
+        mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
+                                            mlir::ConversionPatternRewriter& rewriter) const final
+        {
+            // Only lower the main function, assuming that all other functions have been inlined.
+            if (op.getName() != "main")
+            {
+                return mlir::failure();
+            }
+
+            // Validate that the main function has no arguments and no return values.
+            if (op.getNumArguments() != 0 || op.getNumResults() != 0)
+            {
+                return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
+                {
+                    diagnostic << "Expect the 'main' function to have no arguments and no result.";
+                });
+            }
+
+            // Create a new function with the same region.
+            auto function = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
+            rewriter.inlineRegionBefore(op.getRegion(), function.getBody(), function.end());
+            rewriter.eraseOp(op);
+            return mlir::success();
+        }
+    };
+
+    /// Lower hello::print by just replacing its operands.
+    struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
+    {
+        using OpConversionPattern::OpConversionPattern;
+
+        mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
+                                            mlir::ConversionPatternRewriter& rewriter) const final
+        {
+            rewriter.modifyOpInPlace(op, [&]
+            {
+                op->setOperands(adaptor.getOperands());
+            });
+
+            return mlir::success();
+        }
+    };
+
+    /// Lower hello::return to func::return.
+    struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
+    {
+        using OpRewritePattern::OpRewritePattern;
+
+        mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
+        {
+            // As all function calls have been inlined, there is only one return
+            // in the main function and it returns no value.
+            if (op.hasOperand())
+            {
+                return mlir::failure();
+            }
+
+            rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
+            return mlir::success();
+        }
+    };
+
+    /// Lower the transpose operation to affine loops.
+    struct TransposeOpLoweringPattern : mlir::ConversionPattern
+    {
+        explicit TransposeOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
+            mlir::hello::TransposeOp::getOperationName(), 1, context)
+        {
+        }
+
+        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
+                                            mlir::ConversionPatternRewriter& rewriter) const final
+        {
+            auto location = op->getLoc();
+
+            lowerOperationToLoops(op, operands, rewriter, [location](mlir::OpBuilder& builder,
+                                                                     mlir::ValueRange memRefOperands,
+                                                                     mlir::ValueRange loopIterators)
+            {
+                mlir::hello::TransposeOpAdaptor adaptor(memRefOperands);
+                mlir::Value input = adaptor.getInput();
+
+                // The transpose operation just reverses the index array, as (x, y) -> (y, x).
+                llvm::SmallVector<mlir::Value> reversedIterators(llvm::reverse(loopIterators));
+                return builder.create<mlir::affine::AffineLoadOp>(location, input, reversedIterators);
+            });
+
+            return mlir::success();
+        }
+    };
+}
+
+namespace
+{
+    struct HelloToAffineLoopsLoweringPass : mlir::PassWrapper<
+            HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
+    {
+        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);
+
+        /// Argument used by the command line options.
+        llvm::StringRef getArgument() const override
+        {
+            return "hello-to-affine-loops";
+        }
+
+        void getDependentDialects(mlir::DialectRegistry& registry) const override
+        {
+            registry.insert<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
+                            mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
+        }
+
+        void runOnOperation() final;
+    };
+}
+
+void HelloToAffineLoopsLoweringPass::runOnOperation()
+{
+    mlir::ConversionTarget target(getContext());
+
+    // Only allow the affine, func, arith and memref dialects.
+    target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
+    // Disable the hello dialect.
+    target.addIllegalDialect<mlir::hello::HelloDialect>();
+    // Manually allow the print op, which will be lowered to the LLVM dialect later.
+    target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
+    {
+        // Make sure the operands have been lowered to memrefs, i.e. none of the operand types is a tensor.
+        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
+    });
+
+    // Construct the patterns.
+    mlir::RewritePatternSet set(&getContext());
+    set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
+            FuncOpLoweringPattern, PrintOpLoweringPattern, ReturnOpLoweringPattern,
+            TransposeOpLoweringPattern>(&getContext());
+
+    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
+    {
+        signalPassFailure();
+    }
+}
+
+namespace mlir::hello
+{
+    std::unique_ptr<mlir::Pass> createLowerToAffineLoopsPass()
+    {
+        return std::make_unique<HelloToAffineLoopsLoweringPass>();
+    }
+}
diff --git a/main.cpp b/main.cpp
index 0ae2363..203f31e 100644
--- a/main.cpp
+++ b/main.cpp
@@ -9,6 +9,9 @@
 #include
 #include
 #include
+#include <mlir/Dialect/Affine/Passes.h>
+#include <mlir/Dialect/Func/Extensions/AllExtensions.h>
+#include <mlir/IR/DialectRegistry.h>
 
 #include "Dialect.h"
 #include "MLIRGen.h"
@@ -26,12 +29,12 @@ static llvm::cl::opt inputFilename(llvm::cl::Positional,
 
 namespace
 {
-    enum Action { None, DumpSyntaxNode, DumpMLIR };
+    enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR };
 
     enum InputType { Hello, MLIR };
 }
 
-/// The input file type
+/// The input file type.
 static llvm::cl::opt<enum InputType> inputType("x", llvm::cl::init(Hello),
                                                llvm::cl::desc("Decided the kind of input desired."),
                                                llvm::cl::values(
@@ -39,18 +42,21 @@ static llvm::cl::opt inputType("x", llvm::cl::init(Hello), llvm::cl::values(
     clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
 
-/// What is the action the compiler will do
+/// What is the action the compiler will do.
 static llvm::cl::opt<enum Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
                                              llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
-                                             llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
+                                             llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")),
+                                             llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir",
+                                                                         "Dump mlir code after lowering to affine loops")));
 
+/// Whether to enable the optimization.
 static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
 
 std::unique_ptr parseInputFile(llvm::StringRef filename)
 {
     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
 
-    if (std::error_code ec = fileOrErr.getError())
+    if (const std::error_code ec = fileOrErr.getError())
     {
         llvm::errs() << "Could not open input file: " << ec.message() << "\n";
         return nullptr;
@@ -100,7 +106,10 @@ int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::O
 
 int dumpMLIR()
 {
-    mlir::MLIRContext context;
+    mlir::DialectRegistry registry;
+    mlir::func::registerAllExtensions(registry);
+
+    mlir::MLIRContext context(registry);
     context.getOrLoadDialect<mlir::hello::HelloDialect>();
 
     mlir::OwningOpRef<mlir::ModuleOp> module;
@@ -111,15 +120,16 @@ int dumpMLIR()
         return error;
     }
 
-    if (enableOpt)
+    const bool isLoweringToAffine = emitAction >= DumpAffineMLIR;
+    mlir::PassManager manager(module.get()->getName());
+
+    if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
     {
-        mlir::PassManager manager(module.get()->getName());
-
-        if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
-        {
-            return 1;
-        }
+        return 1;
+    }
 
+    if (enableOpt || isLoweringToAffine)
+    {
         // To inline all functions except 'main' function.
         manager.addPass(mlir::createInlinerPass());
         // In the canonicalizer pass, we add Transpose Pass and Reshape Pass.
         mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
         functionPassManager.addPass(mlir::createCanonicalizerPass());
         functionPassManager.addPass(mlir::createCSEPass());
         functionPassManager.addPass(mlir::hello::createShapeInferencePass());
+    }
 
-        if (mlir::failed(manager.run(*module)))
+    if (isLoweringToAffine)
+    {
+        manager.addPass(mlir::hello::createLowerToAffineLoopsPass());
+        mlir::OpPassManager& functionPassManager = manager.nest<mlir::func::FuncOp>();
+
+        // Add some optimizations from the affine dialect.
+        if (enableOpt)
         {
-            return 1;
+            manager.addPass(mlir::affine::createLoopFusionPass());
+            functionPassManager.addPass(mlir::affine::createAffineScalarReplacementPass());
         }
+    }
 
-        module->print(llvm::outs());
-        return 0;
+    if (mlir::failed(manager.run(*module)))
+    {
+        return 1;
     }
 
     module->print(llvm::outs());
@@ -172,6 +192,7 @@ int main(int argc, char** argv)
     case DumpSyntaxNode:
         return dumpSyntaxNode();
     case DumpMLIR:
+    case DumpAffineMLIR:
         return dumpMLIR();
     default:
         llvm::errs() << "Unrecognized action\n";
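
Usage sketch for the new action (the executable name `hello-compiler` and the input file name are placeholders, not defined by this patch; only the -x, -emit and -opt flags come from main.cpp):

    ./hello-compiler -x hello -emit=affine-mlir -opt example.hello

With -emit=affine-mlir the driver runs the inliner, canonicalizer, CSE and shape-inference passes, then the new hello-to-affine-loops lowering; adding -opt additionally schedules the affine loop-fusion and scalar-replacement passes. In the printed IR the hello operations are replaced by affine, arith and memref operations, and hello.print now takes a memref operand, which is why its argument type in Ops.td was relaxed to AnyTypeOf<[F64Tensor, F64MemRef]>.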