feat: toy tutorial chapter 6.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
parent
c5ab1a6bc0
commit
14a2b4c558
|
@ -42,6 +42,7 @@ add_library(HelloDialect STATIC
|
||||||
lib/HelloCombine.cpp
|
lib/HelloCombine.cpp
|
||||||
lib/ShapeInferencePass.cpp
|
lib/ShapeInferencePass.cpp
|
||||||
lib/LowerToAffineLoopsPass.cpp
|
lib/LowerToAffineLoopsPass.cpp
|
||||||
|
lib/LowerToLLVMPass.cpp
|
||||||
|
|
||||||
include/SyntaxNode.h
|
include/SyntaxNode.h
|
||||||
include/Parser.h
|
include/Parser.h
|
||||||
|
@ -54,21 +55,30 @@ add_library(HelloDialect STATIC
|
||||||
add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
|
add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen)
|
||||||
|
|
||||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||||
|
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||||
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
|
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
|
||||||
|
|
||||||
target_link_libraries(HelloDialect
|
target_link_libraries(HelloDialect
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${dialect_libs}
|
${dialect_libs}
|
||||||
|
${conversion_libs}
|
||||||
${extension_libs}
|
${extension_libs}
|
||||||
MLIRSupport
|
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
MLIRFunctionInterfaces
|
MLIRBuiltinToLLVMIRTranslation
|
||||||
MLIRCallInterfaces
|
MLIRCallInterfaces
|
||||||
MLIRCastInterfaces
|
MLIRCastInterfaces
|
||||||
|
MLIRExecutionEngine
|
||||||
|
MLIRFunctionInterfaces
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRLLVMCommonConversion
|
||||||
|
MLIRLLVMToLLVMIRTranslation
|
||||||
|
MLIRMemRefDialect
|
||||||
MLIRParser
|
MLIRParser
|
||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
MLIRTransforms)
|
MLIRSupport
|
||||||
|
MLIRTargetLLVMIRExport
|
||||||
|
MLIRTransforms
|
||||||
|
)
|
||||||
|
|
||||||
add_executable(hello-mlir main.cpp)
|
add_executable(hello-mlir main.cpp)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@ namespace mlir
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createLowerToAffineLoopsPass();
|
std::unique_ptr<Pass> createLowerToAffineLoopsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createLowerToLLVMPass();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
209
lib/LowerToLLVMPass.cpp
Normal file
209
lib/LowerToLLVMPass.cpp
Normal file
|
@ -0,0 +1,209 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 07/06/25.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||||
|
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||||
|
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||||
|
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||||
|
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||||
|
#include <mlir/Pass/Pass.h>
|
||||||
|
#include <mlir/Transforms/DialectConversion.h>
|
||||||
|
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||||
|
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
|
||||||
|
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
|
||||||
|
#include <mlir/Conversion/ArithToLLVM/ArithToLLVM.h>
|
||||||
|
#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
|
||||||
|
#include <mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h>
|
||||||
|
#include <mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h>
|
||||||
|
#include <mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h>
|
||||||
|
#include <mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h>
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "Passes.h"
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
class PrintOpLoweringPattern : public mlir::ConversionPattern
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit PrintOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
|
||||||
|
mlir::hello::PrintOp::getOperationName(), 1, context)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter) const final
|
||||||
|
{
|
||||||
|
auto* context = rewriter.getContext();
|
||||||
|
auto memRefType = llvm::cast<mlir::MemRefType>(op->getOperandTypes().front());
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
auto location = op->getLoc();
|
||||||
|
|
||||||
|
auto parentModule = op->getParentOfType<mlir::ModuleOp>();
|
||||||
|
|
||||||
|
// Get the `printf` function declaration.
|
||||||
|
auto printfRef = getOrInsertPrintf(rewriter, parentModule);
|
||||||
|
|
||||||
|
// Create the format string in C format.
|
||||||
|
mlir::Value formatSpecifierString = getOrCreateGlobalString(location, rewriter, "format_specifier",
|
||||||
|
"%f \0",
|
||||||
|
parentModule);
|
||||||
|
// Create the LF format string in C format.
|
||||||
|
mlir::Value newLineString = getOrCreateGlobalString(location, rewriter, "new_line", "\n\0", parentModule);
|
||||||
|
|
||||||
|
// Create a loop to print the value in the tensor.
|
||||||
|
llvm::SmallVector<mlir::Value, 4> loopIterators;
|
||||||
|
for (const auto i : llvm::seq<size_t>(0, memRefShape.size()))
|
||||||
|
{
|
||||||
|
auto lowerBound = rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
|
||||||
|
auto upperBound = rewriter.create<mlir::arith::ConstantIndexOp>(location, memRefShape[i]);
|
||||||
|
auto step = rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
|
||||||
|
|
||||||
|
auto loop = rewriter.create<mlir::scf::ForOp>(location, lowerBound, upperBound, step);
|
||||||
|
// FIXME: Remove the nested operation in loop, Why?
|
||||||
|
for (mlir::Operation& nestedOperation : *loop.getBody())
|
||||||
|
{
|
||||||
|
rewriter.eraseOp(&nestedOperation);
|
||||||
|
}
|
||||||
|
|
||||||
|
loopIterators.push_back(loop.getInductionVar());
|
||||||
|
|
||||||
|
// Place the new line output and terminator in the end of loop.
|
||||||
|
rewriter.setInsertionPointToEnd(loop.getBody());
|
||||||
|
|
||||||
|
if (i != memRefShape.size() - 1)
|
||||||
|
{
|
||||||
|
// Add change line when in a row.
|
||||||
|
rewriter.create<mlir::LLVM::CallOp>(location, getPrintfFunctionType(context), printfRef,
|
||||||
|
newLineString);
|
||||||
|
}
|
||||||
|
rewriter.create<mlir::scf::YieldOp>(location);
|
||||||
|
|
||||||
|
// Then place the rewriter at the start of the loop, so when finishing once,
|
||||||
|
// ths rewriter will at the start of newly created loop.
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto printOperation = llvm::cast<mlir::hello::PrintOp>(op);
|
||||||
|
auto loadedElement = rewriter.create<mlir::memref::LoadOp>(location, printOperation.getInput(),
|
||||||
|
loopIterators);
|
||||||
|
rewriter.create<mlir::LLVM::CallOp>(location, getPrintfFunctionType(context), printfRef,
|
||||||
|
llvm::ArrayRef<mlir::Value>({formatSpecifierString, loadedElement}));
|
||||||
|
|
||||||
|
// At last remove the print operation.
|
||||||
|
rewriter.eraseOp(printOperation);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static mlir::LLVM::LLVMFunctionType getPrintfFunctionType(mlir::MLIRContext* context)
|
||||||
|
{
|
||||||
|
const auto llvmInteger32Type = mlir::IntegerType::get(context, 32);
|
||||||
|
const auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(context);
|
||||||
|
// The `printf` is `int printf(char *, ...)`.
|
||||||
|
const auto llvmFunctionType = mlir::LLVM::LLVMFunctionType::get(llvmInteger32Type, llvmPointerType, true);
|
||||||
|
|
||||||
|
return llvmFunctionType;
|
||||||
|
}
|
||||||
|
|
||||||
|
static mlir::FlatSymbolRefAttr getOrInsertPrintf(mlir::PatternRewriter& rewriter, mlir::ModuleOp& module)
|
||||||
|
{
|
||||||
|
auto* context = module.getContext();
|
||||||
|
if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("printf"))
|
||||||
|
{
|
||||||
|
return mlir::SymbolRefAttr::get(context, "printf");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the `printf` declarations in the body of the parent module.
|
||||||
|
mlir::PatternRewriter::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "printf", getPrintfFunctionType(context));
|
||||||
|
return mlir::SymbolRefAttr::get(context, "printf");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create or get a global string used for print.
|
||||||
|
static mlir::Value getOrCreateGlobalString(mlir::Location& loc, mlir::OpBuilder& builder, llvm::StringRef name,
|
||||||
|
llvm::StringRef value, mlir::ModuleOp& module)
|
||||||
|
{
|
||||||
|
auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(name);
|
||||||
|
if (!global)
|
||||||
|
{
|
||||||
|
// Failed to find the global value, create it.
|
||||||
|
mlir::OpBuilder::InsertionGuard guard(builder);
|
||||||
|
builder.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
auto stringType = mlir::LLVM::LLVMArrayType::get(
|
||||||
|
mlir::IntegerType::get(builder.getContext(), 8), value.size());
|
||||||
|
global = builder.create<mlir::LLVM::GlobalOp>(loc, stringType, true, mlir::LLVM::Linkage::Internal,
|
||||||
|
name, builder.getStringAttr(value), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the pointer to the first character in the global string.
|
||||||
|
mlir::Value globalPointer = builder.create<mlir::LLVM::AddressOfOp>(loc, global);
|
||||||
|
mlir::Value constantZero = builder.create<mlir::LLVM::ConstantOp>(
|
||||||
|
loc, builder.getI64Type(), builder.getIndexAttr(0));
|
||||||
|
|
||||||
|
return builder.create<mlir::LLVM::GEPOp>(loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
|
||||||
|
global.getType(),
|
||||||
|
globalPointer, llvm::ArrayRef({
|
||||||
|
constantZero, constantZero
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HelloToLLVMLoweringPass : mlir::PassWrapper<HelloToLLVMLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
|
||||||
|
{
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToLLVMLoweringPass)
|
||||||
|
|
||||||
|
llvm::StringRef getArgument() const final
|
||||||
|
{
|
||||||
|
return "hello-to-llvm";
|
||||||
|
}
|
||||||
|
|
||||||
|
void getDependentDialects(mlir::DialectRegistry& registry) const final
|
||||||
|
{
|
||||||
|
registry.insert<mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() final;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void HelloToLLVMLoweringPass::runOnOperation()
|
||||||
|
{
|
||||||
|
mlir::LLVMConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||||
|
target.addLegalOp<mlir::ModuleOp>();
|
||||||
|
|
||||||
|
mlir::LLVMTypeConverter typeConverter(&getContext());
|
||||||
|
|
||||||
|
mlir::RewritePatternSet patterns(&getContext());
|
||||||
|
|
||||||
|
// The lower path is a little bit complicated.
|
||||||
|
// Convert affine dialect to standard dialect.
|
||||||
|
mlir::populateAffineToStdConversionPatterns(patterns);
|
||||||
|
// Convert scf dialect to cf dialect.
|
||||||
|
mlir::populateSCFToControlFlowConversionPatterns(patterns);
|
||||||
|
// Convert arith to llvm dialect.
|
||||||
|
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
|
patterns.add<PrintOpLoweringPattern>(&getContext());
|
||||||
|
|
||||||
|
// As we need to convert all operations to LLVM dialect, so
|
||||||
|
// we perform a full convertion, which will permit the legal opertions in result.
|
||||||
|
if (const auto module = getOperation();
|
||||||
|
mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns))))
|
||||||
|
{
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::hello::createLowerToLLVMPass()
|
||||||
|
{
|
||||||
|
return std::make_unique<HelloToLLVMLoweringPass>();
|
||||||
|
}
|
175
main.cpp
175
main.cpp
|
@ -1,3 +1,5 @@
|
||||||
|
#include <llvm/IR/LLVMContext.h>
|
||||||
|
|
||||||
#include "Lexer.h"
|
#include "Lexer.h"
|
||||||
#include "Parser.h"
|
#include "Parser.h"
|
||||||
|
|
||||||
|
@ -12,6 +14,15 @@
|
||||||
#include <mlir/Dialect/Func/Extensions/AllExtensions.h>
|
#include <mlir/Dialect/Func/Extensions/AllExtensions.h>
|
||||||
#include <mlir/Dialect/Affine/Passes.h>
|
#include <mlir/Dialect/Affine/Passes.h>
|
||||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||||
|
#include <mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h>
|
||||||
|
#include <mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h>
|
||||||
|
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||||
|
#include <mlir/Target/LLVMIR/Export.h>
|
||||||
|
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||||
|
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||||
|
#include <llvm/IR/Module.h>
|
||||||
|
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
|
||||||
|
#include <llvm/Support/TargetSelect.h>
|
||||||
|
|
||||||
#include "Dialect.h"
|
#include "Dialect.h"
|
||||||
#include "MLIRGen.h"
|
#include "MLIRGen.h"
|
||||||
|
@ -29,7 +40,7 @@ static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR };
|
enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR, DumpLLVMMLIR, DumpLLVM, RunJit };
|
||||||
|
|
||||||
enum InputType { Hello, MLIR };
|
enum InputType { Hello, MLIR };
|
||||||
}
|
}
|
||||||
|
@ -47,7 +58,12 @@ static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind
|
||||||
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
|
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
|
||||||
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")),
|
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")),
|
||||||
llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir",
|
llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir",
|
||||||
"Dump mlir code after lowering to affine loops")));
|
"Dump mlir code after lowering to affine loops")),
|
||||||
|
llvm::cl::values(clEnumValN(DumpLLVMMLIR, "llvm-mlir",
|
||||||
|
"Dump mlir code after lowering to llvm.")),
|
||||||
|
llvm::cl::values(clEnumValN(DumpLLVM, "llvm",
|
||||||
|
"Dump llvm code.")),
|
||||||
|
llvm::cl::values(clEnumValN(RunJit, "jit", "Run the input by jitter.")));
|
||||||
|
|
||||||
/// Whether to enable the optimization.
|
/// Whether to enable the optimization.
|
||||||
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
|
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
|
||||||
|
@ -67,10 +83,11 @@ std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
||||||
return parser.parseModule();
|
return parser.parseModule();
|
||||||
}
|
}
|
||||||
|
|
||||||
int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
|
int loadAndProcessMLIR(mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
|
||||||
{
|
{
|
||||||
if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
|
if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
|
||||||
{
|
{
|
||||||
|
// The input file is hello language.
|
||||||
auto syntaxNode = parseInputFile(inputFilename);
|
auto syntaxNode = parseInputFile(inputFilename);
|
||||||
if (syntaxNode == nullptr)
|
if (syntaxNode == nullptr)
|
||||||
{
|
{
|
||||||
|
@ -78,10 +95,15 @@ int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::O
|
||||||
}
|
}
|
||||||
module = hello::mlirGen(context, *syntaxNode);
|
module = hello::mlirGen(context, *syntaxNode);
|
||||||
|
|
||||||
return module ? 0 : 1;
|
if (!module)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Failed to convert hello syntax tree to MLIR.\n";
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// Then the input file is mlir
|
else
|
||||||
|
{
|
||||||
|
// The the input file is mlir.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code ec = fileOrErr.getError())
|
if (std::error_code ec = fileOrErr.getError())
|
||||||
|
@ -100,27 +122,10 @@ int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::O
|
||||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int dumpMLIR()
|
|
||||||
{
|
|
||||||
mlir::DialectRegistry registry;
|
|
||||||
mlir::func::registerAllExtensions(registry);
|
|
||||||
|
|
||||||
mlir::MLIRContext context(registry);
|
|
||||||
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
|
||||||
|
|
||||||
mlir::OwningOpRef<mlir::ModuleOp> module;
|
|
||||||
llvm::SourceMgr sourceManager;
|
|
||||||
|
|
||||||
if (int error = loadMLIR(sourceManager, context, module))
|
|
||||||
{
|
|
||||||
return error;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool isLoweringToAffine = emitAction >= DumpAffineMLIR;
|
const bool isLoweringToAffine = emitAction >= DumpAffineMLIR;
|
||||||
|
const bool isLoweringToLLVM = emitAction >= DumpLLVMMLIR;
|
||||||
mlir::PassManager manager(module.get()->getName());
|
mlir::PassManager manager(module.get()->getName());
|
||||||
|
|
||||||
if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
|
if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
|
||||||
|
@ -152,12 +157,60 @@ int dumpMLIR()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isLoweringToLLVM)
|
||||||
|
{
|
||||||
|
manager.addPass(mlir::hello::createLowerToLLVMPass());
|
||||||
|
}
|
||||||
|
|
||||||
if (mlir::failed(manager.run(*module)))
|
if (mlir::failed(manager.run(*module)))
|
||||||
{
|
{
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
module->print(llvm::outs());
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int dumpLLVMIR(mlir::ModuleOp module)
|
||||||
|
{
|
||||||
|
mlir::registerBuiltinDialectTranslation(*module->getContext());
|
||||||
|
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||||
|
|
||||||
|
// Convert the mlir IR to llvm IR.
|
||||||
|
llvm::LLVMContext context;
|
||||||
|
const auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
|
||||||
|
if (llvmModule == nullptr)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Could not translate LLVM IR to LLVM IR\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::InitializeNativeTarget();
|
||||||
|
llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
|
||||||
|
auto targetMachineBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||||
|
if (!targetMachineBuilderOrError)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Could not detect host machine\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto targetMachineOrError = targetMachineBuilderOrError->createTargetMachine();
|
||||||
|
if (!targetMachineOrError)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Could not create target machine\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), targetMachineOrError.get().get());
|
||||||
|
|
||||||
|
const auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr);
|
||||||
|
if (auto err = optimizationPipeline(llvmModule.get()))
|
||||||
|
{
|
||||||
|
llvm::errs() << "Failed to run optimization pipeline for LLVM IR:" << err << "\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::outs() << *llvmModule << "\n";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,6 +232,37 @@ int dumpSyntaxNode()
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int runJitter(mlir::ModuleOp module)
|
||||||
|
{
|
||||||
|
llvm::InitializeNativeTarget();
|
||||||
|
llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
|
||||||
|
mlir::registerBuiltinDialectTranslation(*module->getContext());
|
||||||
|
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||||
|
|
||||||
|
auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr);
|
||||||
|
|
||||||
|
mlir::ExecutionEngineOptions options;
|
||||||
|
options.transformer = optimizationPipeline;
|
||||||
|
|
||||||
|
auto engineOrError = mlir::ExecutionEngine::create(module, options);
|
||||||
|
if (!engineOrError)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Failed to create execution engine\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& engine = engineOrError.get();
|
||||||
|
|
||||||
|
if (auto invocationResult = engine->invokePacked("main"))
|
||||||
|
{
|
||||||
|
llvm::errs() << "Failed to run main function by jitter?:" << invocationResult << "\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
mlir::registerAsmPrinterCLOptions();
|
mlir::registerAsmPrinterCLOptions();
|
||||||
|
@ -187,15 +271,40 @@ int main(int argc, char** argv)
|
||||||
|
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
||||||
|
|
||||||
switch (emitAction)
|
if (emitAction == DumpSyntaxNode)
|
||||||
{
|
{
|
||||||
case DumpSyntaxNode:
|
|
||||||
return dumpSyntaxNode();
|
return dumpSyntaxNode();
|
||||||
case DumpMLIR:
|
|
||||||
case DumpAffineMLIR:
|
|
||||||
return dumpMLIR();
|
|
||||||
default:
|
|
||||||
llvm::errs() << "Unrecognized action\n";
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mlir::DialectRegistry registry;
|
||||||
|
mlir::func::registerAllExtensions(registry);
|
||||||
|
mlir::LLVM::registerInlinerInterface(registry);
|
||||||
|
|
||||||
|
mlir::MLIRContext context(registry);
|
||||||
|
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
||||||
|
|
||||||
|
mlir::OwningOpRef<mlir::ModuleOp> module;
|
||||||
|
if (int error = loadAndProcessMLIR(context, module))
|
||||||
|
{
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (emitAction <= DumpLLVMMLIR)
|
||||||
|
{
|
||||||
|
llvm::outs() << *module << "\n";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (emitAction == DumpLLVM)
|
||||||
|
{
|
||||||
|
return dumpLLVMIR(*module);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (emitAction == RunJit)
|
||||||
|
{
|
||||||
|
return runJitter(*module);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user