feat: toy tutorial chapter 6.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
209
lib/LowerToLLVMPass.cpp
Normal file
209
lib/LowerToLLVMPass.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
//
|
||||
// Created by ricardo on 07/06/25.
|
||||
//
|
||||
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
|
||||
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
|
||||
#include <mlir/Conversion/ArithToLLVM/ArithToLLVM.h>
|
||||
#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
|
||||
#include <mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h>
|
||||
#include <mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h>
|
||||
#include <mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h>
|
||||
#include <mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h>
|
||||
|
||||
#include "Dialect.h"
|
||||
#include "Passes.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
class PrintOpLoweringPattern : public mlir::ConversionPattern
|
||||
{
|
||||
public:
|
||||
explicit PrintOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
|
||||
mlir::hello::PrintOp::getOperationName(), 1, context)
|
||||
{
|
||||
}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
|
||||
mlir::ConversionPatternRewriter& rewriter) const final
|
||||
{
|
||||
auto* context = rewriter.getContext();
|
||||
auto memRefType = llvm::cast<mlir::MemRefType>(op->getOperandTypes().front());
|
||||
auto memRefShape = memRefType.getShape();
|
||||
auto location = op->getLoc();
|
||||
|
||||
auto parentModule = op->getParentOfType<mlir::ModuleOp>();
|
||||
|
||||
// Get the `printf` function declaration.
|
||||
auto printfRef = getOrInsertPrintf(rewriter, parentModule);
|
||||
|
||||
// Create the format string in C format.
|
||||
mlir::Value formatSpecifierString = getOrCreateGlobalString(location, rewriter, "format_specifier",
|
||||
"%f \0",
|
||||
parentModule);
|
||||
// Create the LF format string in C format.
|
||||
mlir::Value newLineString = getOrCreateGlobalString(location, rewriter, "new_line", "\n\0", parentModule);
|
||||
|
||||
// Create a loop to print the value in the tensor.
|
||||
llvm::SmallVector<mlir::Value, 4> loopIterators;
|
||||
for (const auto i : llvm::seq<size_t>(0, memRefShape.size()))
|
||||
{
|
||||
auto lowerBound = rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
|
||||
auto upperBound = rewriter.create<mlir::arith::ConstantIndexOp>(location, memRefShape[i]);
|
||||
auto step = rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
|
||||
|
||||
auto loop = rewriter.create<mlir::scf::ForOp>(location, lowerBound, upperBound, step);
|
||||
// FIXME: Remove the nested operation in loop, Why?
|
||||
for (mlir::Operation& nestedOperation : *loop.getBody())
|
||||
{
|
||||
rewriter.eraseOp(&nestedOperation);
|
||||
}
|
||||
|
||||
loopIterators.push_back(loop.getInductionVar());
|
||||
|
||||
// Place the new line output and terminator in the end of loop.
|
||||
rewriter.setInsertionPointToEnd(loop.getBody());
|
||||
|
||||
if (i != memRefShape.size() - 1)
|
||||
{
|
||||
// Add change line when in a row.
|
||||
rewriter.create<mlir::LLVM::CallOp>(location, getPrintfFunctionType(context), printfRef,
|
||||
newLineString);
|
||||
}
|
||||
rewriter.create<mlir::scf::YieldOp>(location);
|
||||
|
||||
// Then place the rewriter at the start of the loop, so when finishing once,
|
||||
// ths rewriter will at the start of newly created loop.
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
auto printOperation = llvm::cast<mlir::hello::PrintOp>(op);
|
||||
auto loadedElement = rewriter.create<mlir::memref::LoadOp>(location, printOperation.getInput(),
|
||||
loopIterators);
|
||||
rewriter.create<mlir::LLVM::CallOp>(location, getPrintfFunctionType(context), printfRef,
|
||||
llvm::ArrayRef<mlir::Value>({formatSpecifierString, loadedElement}));
|
||||
|
||||
// At last remove the print operation.
|
||||
rewriter.eraseOp(printOperation);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
static mlir::LLVM::LLVMFunctionType getPrintfFunctionType(mlir::MLIRContext* context)
|
||||
{
|
||||
const auto llvmInteger32Type = mlir::IntegerType::get(context, 32);
|
||||
const auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(context);
|
||||
// The `printf` is `int printf(char *, ...)`.
|
||||
const auto llvmFunctionType = mlir::LLVM::LLVMFunctionType::get(llvmInteger32Type, llvmPointerType, true);
|
||||
|
||||
return llvmFunctionType;
|
||||
}
|
||||
|
||||
static mlir::FlatSymbolRefAttr getOrInsertPrintf(mlir::PatternRewriter& rewriter, mlir::ModuleOp& module)
|
||||
{
|
||||
auto* context = module.getContext();
|
||||
if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("printf"))
|
||||
{
|
||||
return mlir::SymbolRefAttr::get(context, "printf");
|
||||
}
|
||||
|
||||
// Insert the `printf` declarations in the body of the parent module.
|
||||
mlir::PatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
|
||||
rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "printf", getPrintfFunctionType(context));
|
||||
return mlir::SymbolRefAttr::get(context, "printf");
|
||||
}
|
||||
|
||||
/// Create or get a global string used for print.
|
||||
static mlir::Value getOrCreateGlobalString(mlir::Location& loc, mlir::OpBuilder& builder, llvm::StringRef name,
|
||||
llvm::StringRef value, mlir::ModuleOp& module)
|
||||
{
|
||||
auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(name);
|
||||
if (!global)
|
||||
{
|
||||
// Failed to find the global value, create it.
|
||||
mlir::OpBuilder::InsertionGuard guard(builder);
|
||||
builder.setInsertionPointToStart(module.getBody());
|
||||
|
||||
auto stringType = mlir::LLVM::LLVMArrayType::get(
|
||||
mlir::IntegerType::get(builder.getContext(), 8), value.size());
|
||||
global = builder.create<mlir::LLVM::GlobalOp>(loc, stringType, true, mlir::LLVM::Linkage::Internal,
|
||||
name, builder.getStringAttr(value), 0);
|
||||
}
|
||||
|
||||
// Get the pointer to the first character in the global string.
|
||||
mlir::Value globalPointer = builder.create<mlir::LLVM::AddressOfOp>(loc, global);
|
||||
mlir::Value constantZero = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, builder.getI64Type(), builder.getIndexAttr(0));
|
||||
|
||||
return builder.create<mlir::LLVM::GEPOp>(loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
|
||||
global.getType(),
|
||||
globalPointer, llvm::ArrayRef({
|
||||
constantZero, constantZero
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
struct HelloToLLVMLoweringPass : mlir::PassWrapper<HelloToLLVMLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
|
||||
{
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToLLVMLoweringPass)
|
||||
|
||||
llvm::StringRef getArgument() const final
|
||||
{
|
||||
return "hello-to-llvm";
|
||||
}
|
||||
|
||||
void getDependentDialects(mlir::DialectRegistry& registry) const final
|
||||
{
|
||||
registry.insert<mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() final;
|
||||
};
|
||||
}
|
||||
|
||||
void HelloToLLVMLoweringPass::runOnOperation()
|
||||
{
|
||||
mlir::LLVMConversionTarget target(getContext());
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
target.addLegalOp<mlir::ModuleOp>();
|
||||
|
||||
mlir::LLVMTypeConverter typeConverter(&getContext());
|
||||
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// The lower path is a little bit complicated.
|
||||
// Convert affine dialect to standard dialect.
|
||||
mlir::populateAffineToStdConversionPatterns(patterns);
|
||||
// Convert scf dialect to cf dialect.
|
||||
mlir::populateSCFToControlFlowConversionPatterns(patterns);
|
||||
// Convert arith to llvm dialect.
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
patterns.add<PrintOpLoweringPattern>(&getContext());
|
||||
|
||||
// As we need to convert all operations to LLVM dialect, so
|
||||
// we perform a full convertion, which will permit the legal opertions in result.
|
||||
if (const auto module = getOperation();
|
||||
mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns))))
|
||||
{
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> mlir::hello::createLowerToLLVMPass()
|
||||
{
|
||||
return std::make_unique<HelloToLLVMLoweringPass>();
|
||||
}
|
Reference in New Issue
Block a user