#include #include "Lexer.h" #include "Parser.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "Dialect.h" #include "MLIRGen.h" #include "Passes.h" namespace mlir { class ModuleOp; } static llvm::cl::opt inputFilename(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::value_desc("filename")); namespace { enum Action { None, DumpSyntaxNode, DumpMLIR, DumpAffineMLIR, DumpLLVMMLIR, DumpLLVM, RunJit }; enum InputType { Hello, MLIR }; } /// The input file type. static llvm::cl::opt inputType("x", llvm::cl::init(Hello), llvm::cl::desc("Decided the kind of input desired."), llvm::cl::values( clEnumValN(Hello, "hello", "load the input file as a hello source.")), llvm::cl::values( clEnumValN(MLIR, "mlir", "load the input file as a mlir source."))); /// What is the action the compiler will do. static llvm::cl::opt emitAction("emit", llvm::cl::desc("Select the kind of output desired"), llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")), llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")), llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir", "Dump mlir code after lowering to affine loops")), llvm::cl::values(clEnumValN(DumpLLVMMLIR, "llvm-mlir", "Dump mlir code after lowering to llvm.")), llvm::cl::values(clEnumValN(DumpLLVM, "llvm", "Dump llvm code.")), llvm::cl::values(clEnumValN(RunJit, "jit", "Run the input by jitter."))); /// Whether to enable the optimization. static llvm::cl::opt enableOpt("opt", llvm::cl::desc("Enable optimizations")); std::unique_ptr parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (const std::error_code ec = fileOrErr.getError()) { llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return nullptr; } auto buffer = fileOrErr.get()->getBuffer(); hello::LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); hello::Parser parser(lexer); return parser.parseModule(); } int loadAndProcessMLIR(mlir::MLIRContext& context, mlir::OwningOpRef& module) { if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { // The input file is hello language. auto syntaxNode = parseInputFile(inputFilename); if (syntaxNode == nullptr) { return 1; } module = hello::mlirGen(context, *syntaxNode); if (!module) { llvm::errs() << "Failed to convert hello syntax tree to MLIR.\n"; return 1; } } else { // The the input file is mlir. llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); if (std::error_code ec = fileOrErr.getError()) { llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return 1; } // Parse the input mlir. llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return 1; } } const bool isLoweringToAffine = emitAction >= DumpAffineMLIR; const bool isLoweringToLLVM = emitAction >= DumpLLVMMLIR; mlir::PassManager manager(module.get()->getName()); if (mlir::failed(mlir::applyPassManagerCLOptions(manager))) { return 1; } if (enableOpt || isLoweringToAffine) { // To inline all functions except 'main' function. manager.addPass(mlir::createInlinerPass()); // In the canonicalizer pass, we add Transpose Pass and Reshape Pass. 
    // Nested on the dialect's function op (assumed here to be
    // mlir::hello::FuncOp, following the MLIR toy tutorial structure).
    mlir::OpPassManager &functionPassManager = manager.nest<mlir::hello::FuncOp>();
    functionPassManager.addPass(mlir::createCanonicalizerPass());
    functionPassManager.addPass(mlir::createCSEPass());
    functionPassManager.addPass(mlir::hello::createShapeInferencePass());
  }

  if (isLoweringToAffine) {
    manager.addPass(mlir::hello::createLowerToAffineLoopsPass());

    // After lowering to affine loops, functions are func.func operations.
    mlir::OpPassManager &functionPassManager = manager.nest<mlir::func::FuncOp>();

    // Add some optimizations from the affine dialect.
    if (enableOpt) {
      manager.addPass(mlir::affine::createLoopFusionPass());
      functionPassManager.addPass(mlir::affine::createAffineScalarReplacementPass());
    }
  }

  if (isLoweringToLLVM) {
    manager.addPass(mlir::hello::createLowerToLLVMPass());
  }

  if (mlir::failed(manager.run(*module))) {
    return 1;
  }
  return 0;
}

int dumpLLVMIR(mlir::ModuleOp module) {
  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // Translate the MLIR module to LLVM IR.
  llvm::LLVMContext context;
  const auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
  if (llvmModule == nullptr) {
    llvm::errs() << "Failed to translate the MLIR module to LLVM IR\n";
    return 1;
  }

  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  auto targetMachineBuilderOrError =
      llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!targetMachineBuilderOrError) {
    llvm::errs() << "Could not detect the host machine\n";
    return 1;
  }

  auto targetMachineOrError = targetMachineBuilderOrError->createTargetMachine();
  if (!targetMachineOrError) {
    llvm::errs() << "Could not create the target machine\n";
    return 1;
  }

  mlir::ExecutionEngine::setupTargetTripleAndDataLayout(
      llvmModule.get(), targetMachineOrError.get().get());

  const auto optimizationPipeline = mlir::makeOptimizingTransformer(
      enableOpt ? 3 : 0, /*sizeLevel=*/0, /*targetMachine=*/nullptr);
  if (auto err = optimizationPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to run the optimization pipeline for LLVM IR: "
                 << err << "\n";
    return 1;
  }

  llvm::outs() << *llvmModule << "\n";
  return 0;
}

int dumpSyntaxNode() {
  if (inputType == MLIR) {
    llvm::errs() << "Cannot dump the hello syntax node when the input type is MLIR.\n";
    return 1;
  }

  auto syntaxNode = parseInputFile(inputFilename);
  if (syntaxNode == nullptr) {
    return 1;
  }

  syntaxNode->dump();
  return 0;
}

int runJitter(mlir::ModuleOp module) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());

  auto optimizationPipeline = mlir::makeOptimizingTransformer(
      enableOpt ? 3 : 0, /*sizeLevel=*/0, /*targetMachine=*/nullptr);
  mlir::ExecutionEngineOptions options;
  options.transformer = optimizationPipeline;

  auto engineOrError = mlir::ExecutionEngine::create(module, options);
  if (!engineOrError) {
    llvm::errs() << "Failed to create the execution engine\n";
    return 1;
  }

  const auto &engine = engineOrError.get();
  if (auto invocationResult = engine->invokePacked("main")) {
    llvm::errs() << "Failed to run the main function in the JIT: "
                 << invocationResult << "\n";
    return 1;
  }
  return 0;
}

int main(int argc, char **argv) {
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();

  llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");

  if (emitAction == DumpSyntaxNode) {
    return dumpSyntaxNode();
  }

  mlir::DialectRegistry registry;
  mlir::func::registerAllExtensions(registry);
  mlir::LLVM::registerInlinerInterface(registry);

  mlir::MLIRContext context(registry);
  // Load the hello dialect (the class name is assumed from the project's Dialect.h).
  context.getOrLoadDialect<mlir::hello::HelloDialect>();

  mlir::OwningOpRef<mlir::ModuleOp> module;
  if (int error = loadAndProcessMLIR(context, module)) {
    return error;
  }

  if (emitAction <= DumpLLVMMLIR) {
    llvm::outs() << *module << "\n";
    return 0;
  }

  if (emitAction == DumpLLVM) {
    return dumpLLVMIR(*module);
  }

  if (emitAction == RunJit) {
    return runJitter(*module);
  }

  llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
  return -1;
}
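
// Example invocations, as a sketch of how the options above combine. The
// binary name `hello-compiler` and the input file names are placeholders,
// not taken from this project's build files:
//
//   hello-compiler test.hello -emit=ast           # dump the parsed syntax node
//   hello-compiler test.hello -emit=mlir          # dump hello-dialect MLIR
//   hello-compiler test.hello -emit=affine-mlir   # MLIR after lowering to affine loops
//   hello-compiler test.hello -emit=llvm-mlir     # MLIR after lowering to the LLVM dialect
//   hello-compiler test.hello -emit=llvm -opt     # optimized LLVM IR
//   hello-compiler test.hello -emit=jit           # JIT-compile and run `main`
//   hello-compiler input.mlir -x mlir -emit=jit   # treat the input as MLIR instead of Hello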