diff --git a/CMakeLists.txt b/CMakeLists.txt index 44d929f..591bcce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,7 @@ add_library(HelloDialect STATIC lib/HelloCombine.cpp lib/ShapeInferencePass.cpp lib/LowerToAffineLoopsPass.cpp + lib/LowerToLLVMPass.cpp include/SyntaxNode.h include/Parser.h @@ -54,21 +55,30 @@ add_library(HelloDialect STATIC add_dependencies(HelloDialect HelloOpsIncGen HelloCombineIncGen HelloInterfaceIncGen) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(HelloDialect PRIVATE ${dialect_libs} + ${conversion_libs} ${extension_libs} - MLIRSupport MLIRAnalysis - MLIRFunctionInterfaces + MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces MLIRCastInterfaces + MLIRExecutionEngine + MLIRFunctionInterfaces MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMToLLVMIRTranslation + MLIRMemRefDialect MLIRParser MLIRSideEffectInterfaces - MLIRTransforms) + MLIRSupport + MLIRTargetLLVMIRExport + MLIRTransforms +) add_executable(hello-mlir main.cpp) diff --git a/include/Passes.h b/include/Passes.h index b10a43c..881bf5e 100644 --- a/include/Passes.h +++ b/include/Passes.h @@ -16,6 +16,8 @@ namespace mlir std::unique_ptr createShapeInferencePass(); std::unique_ptr createLowerToAffineLoopsPass(); + + std::unique_ptr createLowerToLLVMPass(); } } diff --git a/lib/LowerToLLVMPass.cpp b/lib/LowerToLLVMPass.cpp new file mode 100644 index 0000000..00f769f --- /dev/null +++ b/lib/LowerToLLVMPass.cpp @@ -0,0 +1,209 @@ +// +// Created by ricardo on 07/06/25. +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Dialect.h" +#include "Passes.h" + +namespace +{ + class PrintOpLoweringPattern : public mlir::ConversionPattern + { + public: + explicit PrintOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern( + mlir::hello::PrintOp::getOperationName(), 1, context) + { + } + + mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter& rewriter) const final + { + auto* context = rewriter.getContext(); + auto memRefType = llvm::cast(op->getOperandTypes().front()); + auto memRefShape = memRefType.getShape(); + auto location = op->getLoc(); + + auto parentModule = op->getParentOfType(); + + // Get the `printf` function declaration. + auto printfRef = getOrInsertPrintf(rewriter, parentModule); + + // Create the format string in C format. + mlir::Value formatSpecifierString = getOrCreateGlobalString(location, rewriter, "format_specifier", + "%f \0", + parentModule); + // Create the LF format string in C format. + mlir::Value newLineString = getOrCreateGlobalString(location, rewriter, "new_line", "\n\0", parentModule); + + // Create a loop to print the value in the tensor. + llvm::SmallVector loopIterators; + for (const auto i : llvm::seq(0, memRefShape.size())) + { + auto lowerBound = rewriter.create(location, 0); + auto upperBound = rewriter.create(location, memRefShape[i]); + auto step = rewriter.create(location, 1); + + auto loop = rewriter.create(location, lowerBound, upperBound, step); + // FIXME: Remove the nested operation in loop, Why? + for (mlir::Operation& nestedOperation : *loop.getBody()) + { + rewriter.eraseOp(&nestedOperation); + } + + loopIterators.push_back(loop.getInductionVar()); + + // Place the new line output and terminator in the end of loop. + rewriter.setInsertionPointToEnd(loop.getBody()); + + if (i != memRefShape.size() - 1) + { + // Add change line when in a row. + rewriter.create(location, getPrintfFunctionType(context), printfRef, + newLineString); + } + rewriter.create(location); + + // Then place the rewriter at the start of the loop, so when finishing once, + // ths rewriter will at the start of newly created loop. + rewriter.setInsertionPointToStart(loop.getBody()); + } + + auto printOperation = llvm::cast(op); + auto loadedElement = rewriter.create(location, printOperation.getInput(), + loopIterators); + rewriter.create(location, getPrintfFunctionType(context), printfRef, + llvm::ArrayRef({formatSpecifierString, loadedElement})); + + // At last remove the print operation. + rewriter.eraseOp(printOperation); + return mlir::success(); + } + + private: + static mlir::LLVM::LLVMFunctionType getPrintfFunctionType(mlir::MLIRContext* context) + { + const auto llvmInteger32Type = mlir::IntegerType::get(context, 32); + const auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(context); + // The `printf` is `int printf(char *, ...)`. + const auto llvmFunctionType = mlir::LLVM::LLVMFunctionType::get(llvmInteger32Type, llvmPointerType, true); + + return llvmFunctionType; + } + + static mlir::FlatSymbolRefAttr getOrInsertPrintf(mlir::PatternRewriter& rewriter, mlir::ModuleOp& module) + { + auto* context = module.getContext(); + if (module.lookupSymbol("printf")) + { + return mlir::SymbolRefAttr::get(context, "printf"); + } + + // Insert the `printf` declarations in the body of the parent module. + mlir::PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + rewriter.create(module.getLoc(), "printf", getPrintfFunctionType(context)); + return mlir::SymbolRefAttr::get(context, "printf"); + } + + /// Create or get a global string used for print. + static mlir::Value getOrCreateGlobalString(mlir::Location& loc, mlir::OpBuilder& builder, llvm::StringRef name, + llvm::StringRef value, mlir::ModuleOp& module) + { + auto global = module.lookupSymbol(name); + if (!global) + { + // Failed to find the global value, create it. + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto stringType = mlir::LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size()); + global = builder.create(loc, stringType, true, mlir::LLVM::Linkage::Internal, + name, builder.getStringAttr(value), 0); + } + + // Get the pointer to the first character in the global string. + mlir::Value globalPointer = builder.create(loc, global); + mlir::Value constantZero = builder.create( + loc, builder.getI64Type(), builder.getIndexAttr(0)); + + return builder.create(loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), + global.getType(), + globalPointer, llvm::ArrayRef({ + constantZero, constantZero + })); + } + }; + + struct HelloToLLVMLoweringPass : mlir::PassWrapper> + { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToLLVMLoweringPass) + + llvm::StringRef getArgument() const final + { + return "hello-to-llvm"; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final + { + registry.insert(); + } + + void runOnOperation() final; + }; +} + +void HelloToLLVMLoweringPass::runOnOperation() +{ + mlir::LLVMConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + + mlir::LLVMTypeConverter typeConverter(&getContext()); + + mlir::RewritePatternSet patterns(&getContext()); + + // The lower path is a little bit complicated. + // Convert affine dialect to standard dialect. + mlir::populateAffineToStdConversionPatterns(patterns); + // Convert scf dialect to cf dialect. + mlir::populateSCFToControlFlowConversionPatterns(patterns); + // Convert arith to llvm dialect. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns); + + patterns.add(&getContext()); + + // As we need to convert all operations to LLVM dialect, so + // we perform a full convertion, which will permit the legal opertions in result. + if (const auto module = getOperation(); + mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns)))) + { + signalPassFailure(); + } +} + +std::unique_ptr mlir::hello::createLowerToLLVMPass() +{ + return std::make_unique(); +} diff --git a/main.cpp b/main.cpp index 203f31e..88d081d 100644 --- a/main.cpp +++ b/main.cpp @@ -1,3 +1,5 @@ +#include + #include "Lexer.h" #include "Parser.h" @@ -12,6 +14,15 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "Dialect.h" #include "MLIRGen.h" @@ -29,7 +40,7 @@ static llvm::cl::opt inputFilename(llvm::cl::Positional, namespace { - enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR }; + enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR, DumpLLVMMLIR, DumpLLVM, RunJit }; enum InputType { Hello, MLIR }; } @@ -47,7 +58,12 @@ static llvm::cl::opt emitAction("emit", llvm::cl::desc("Select the kind llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")), llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")), llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir", - "Dump mlir code after lowering to affine loops"))); + "Dump mlir code after lowering to affine loops")), + llvm::cl::values(clEnumValN(DumpLLVMMLIR, "llvm-mlir", + "Dump mlir code after lowering to llvm.")), + llvm::cl::values(clEnumValN(DumpLLVM, "llvm", + "Dump llvm code.")), + llvm::cl::values(clEnumValN(RunJit, "jit", "Run the input by jitter."))); /// Whether to enable the optimization. static llvm::cl::opt enableOpt("opt", llvm::cl::desc("Enable optimizations")); @@ -67,10 +83,11 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) return parser.parseModule(); } -int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef& module) +int loadAndProcessMLIR(mlir::MLIRContext& context, mlir::OwningOpRef& module) { if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { + // The input file is hello language. auto syntaxNode = parseInputFile(inputFilename); if (syntaxNode == nullptr) { @@ -78,49 +95,37 @@ int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::O } module = hello::mlirGen(context, *syntaxNode); - return module ? 0 : 1; + if (!module) + { + llvm::errs() << "Failed to convert hello syntax tree to MLIR.\n"; + return 1; + } } - - // Then the input file is mlir - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code ec = fileOrErr.getError()) + else { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return 1; - } + // The the input file is mlir. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) + { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return 1; + } - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = - mlir::parseSourceFile(sourceMgr, &context); - if (!module) - { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return 1; - } - - return 0; -} - -int dumpMLIR() -{ - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - - mlir::MLIRContext context(registry); - context.getOrLoadDialect(); - - mlir::OwningOpRef module; - llvm::SourceMgr sourceManager; - - if (int error = loadMLIR(sourceManager, context, module)) - { - return error; + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = + mlir::parseSourceFile(sourceMgr, &context); + if (!module) + { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 1; + } } const bool isLoweringToAffine = emitAction >= DumpAffineMLIR; + const bool isLoweringToLLVM = emitAction >= DumpLLVMMLIR; mlir::PassManager manager(module.get()->getName()); if (mlir::failed(mlir::applyPassManagerCLOptions(manager))) @@ -152,12 +157,60 @@ int dumpMLIR() } } + if (isLoweringToLLVM) + { + manager.addPass(mlir::hello::createLowerToLLVMPass()); + } + if (mlir::failed(manager.run(*module))) { return 1; } - module->print(llvm::outs()); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) +{ + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Convert the mlir IR to llvm IR. + llvm::LLVMContext context; + const auto llvmModule = mlir::translateModuleToLLVMIR(module, context); + if (llvmModule == nullptr) + { + llvm::errs() << "Could not translate LLVM IR to LLVM IR\n"; + return 1; + } + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + auto targetMachineBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!targetMachineBuilderOrError) + { + llvm::errs() << "Could not detect host machine\n"; + return 1; + } + + auto targetMachineOrError = targetMachineBuilderOrError->createTargetMachine(); + if (!targetMachineOrError) + { + llvm::errs() << "Could not create target machine\n"; + return 1; + } + + mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), targetMachineOrError.get().get()); + + const auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr); + if (auto err = optimizationPipeline(llvmModule.get())) + { + llvm::errs() << "Failed to run optimization pipeline for LLVM IR:" << err << "\n"; + return 1; + } + + llvm::outs() << *llvmModule << "\n"; return 0; } @@ -179,6 +232,37 @@ int dumpSyntaxNode() return 0; } +int runJitter(mlir::ModuleOp module) +{ + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr); + + mlir::ExecutionEngineOptions options; + options.transformer = optimizationPipeline; + + auto engineOrError = mlir::ExecutionEngine::create(module, options); + if (!engineOrError) + { + llvm::errs() << "Failed to create execution engine\n"; + return 1; + } + + const auto& engine = engineOrError.get(); + + if (auto invocationResult = engine->invokePacked("main")) + { + llvm::errs() << "Failed to run main function by jitter?:" << invocationResult << "\n"; + return 1; + } + + return 0; +} + int main(int argc, char** argv) { mlir::registerAsmPrinterCLOptions(); @@ -187,15 +271,40 @@ int main(int argc, char** argv) llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n"); - switch (emitAction) + if (emitAction == DumpSyntaxNode) { - case DumpSyntaxNode: return dumpSyntaxNode(); - case DumpMLIR: - case DumpAffineMLIR: - return dumpMLIR(); - default: - llvm::errs() << "Unrecognized action\n"; - return 1; } + + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + mlir::LLVM::registerInlinerInterface(registry); + + mlir::MLIRContext context(registry); + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + if (int error = loadAndProcessMLIR(context, module)) + { + return error; + } + + if (emitAction <= DumpLLVMMLIR) + { + llvm::outs() << *module << "\n"; + return 0; + } + + if (emitAction == DumpLLVM) + { + return dumpLLVMIR(*module); + } + + if (emitAction == RunJit) + { + return runJitter(*module); + } + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; }