// Command-line driver for the "hello" compiler. Depending on -emit it either
// dumps the syntax tree parsed from a hello source file, or produces an MLIR
// module (from hello source or from textual MLIR), optionally runs an
// optimization pipeline over it (-opt), and prints the module to stdout.
//
// NOTE(review): in this copy the operands of several #include directives and
// every <...> template argument list are missing (e.g. `llvm::cl::opt
// inputFilename`, `llvm::ErrorOr> fileOrErr`). They are kept verbatim below
// rather than guessed — restore them from version control.

#include "Lexer.h"
#include "Parser.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include "Dialect.h"
#include "MLIRGen.h"
#include "Passes.h"

namespace mlir {
class ModuleOp;
}

/// Positional input file name. The default "-" makes
/// llvm::MemoryBuffer::getFileOrSTDIN() read from standard input.
static llvm::cl::opt inputFilename(llvm::cl::Positional,
                                   llvm::cl::desc(""),
                                   llvm::cl::init("-"),
                                   llvm::cl::value_desc("filename"));

namespace {
/// Output the driver should produce (selected with -emit).
enum Action { None, DumpSyntaxNode, DumpMLIR };
/// How the input file should be interpreted (selected with -x).
enum InputType { Hello, MLIR };
} // namespace

/// The input file type (-x=hello | -x=mlir); defaults to hello.
static llvm::cl::opt inputType(
    "x", llvm::cl::init(Hello),
    llvm::cl::desc("Decided the kind of input desired."),
    llvm::cl::values(
        clEnumValN(Hello, "hello", "load the input file as a hello source.")),
    llvm::cl::values(
        clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));

/// What the compiler will do (-emit=ast | -emit=mlir).
static llvm::cl::opt emitAction(
    "emit", llvm::cl::desc("Select the kind of output desired"),
    llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
    llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));

/// When set, dumpMLIR() runs the optimization pipeline before printing.
static llvm::cl::opt enableOpt("opt", llvm::cl::desc("Enable optimizations"));

/// Returns true when the input must be treated as textual MLIR: either it was
/// requested explicitly with -x=mlir, or the file name ends in ".mlir".
/// Shared by loadMLIR() and dumpSyntaxNode() so both classify input the same
/// way.
static bool isMLIRInput() {
  return inputType == MLIR ||
         llvm::StringRef(inputFilename).ends_with(".mlir");
}

/// Reads `filename` (stdin for "-") and parses it as a hello module.
/// Prints a diagnostic and returns nullptr if the file cannot be opened;
/// otherwise returns the result of hello::Parser::parseModule().
std::unique_ptr parseInputFile(llvm::StringRef filename) {
  llvm::ErrorOr> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(filename);
  if (std::error_code ec = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return nullptr;
  }
  // The lexer reads from the buffer's character range; the buffer is owned by
  // fileOrErr, which stays alive until parsing below has finished.
  auto buffer = fileOrErr.get()->getBuffer();
  hello::LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename));
  hello::Parser parser(lexer);
  return parser.parseModule();
}

/// Populates `module` from the input file.
///
/// Hello sources are parsed and lowered with hello::mlirGen(). Textual MLIR
/// is parsed directly; its file buffer is registered with `sourceManager`,
/// which is owned by the caller and therefore outlives this call.
///
/// @param sourceManager caller-owned source manager that receives the MLIR
///        input buffer.
/// @param context       context used for parsing / lowering.
/// @param module        out-parameter receiving the loaded module.
/// @return 0 on success, 1 on failure.
int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context,
             mlir::OwningOpRef& module) {
  if (!isMLIRInput()) {
    auto syntaxNode = parseInputFile(inputFilename);
    if (syntaxNode == nullptr) {
      return 1;
    }
    module = hello::mlirGen(context, *syntaxNode);
    return module ? 0 : 1;
  }

  // Otherwise the input is textual MLIR.
  llvm::ErrorOr> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (std::error_code ec = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return 1;
  }

  // Parse the input MLIR. The buffer is handed to the caller's SourceMgr (not
  // a function-local one) so it lives as long as the caller needs it.
  sourceManager.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
  module = mlir::parseSourceFile(sourceManager, &context);
  if (!module) {
    llvm::errs() << "Error can't load file " << inputFilename << "\n";
    return 1;
  }
  return 0;
}

/// Loads the input as an MLIR module, optionally runs the optimization
/// pipeline (-opt), and prints the resulting module to stdout.
/// @return 0 on success, 1 (or loadMLIR's error code) on failure.
int dumpMLIR() {
  mlir::MLIRContext context;
  // Presumably loads the hello dialect declared in Dialect.h (template
  // argument not visible in this copy).
  context.getOrLoadDialect();

  mlir::OwningOpRef module;
  // Receives ownership of the MLIR input buffer inside loadMLIR().
  llvm::SourceMgr sourceManager;
  if (int error = loadMLIR(sourceManager, context, module)) {
    return error;
  }

  if (enableOpt) {
    mlir::PassManager manager(module.get()->getName());
    // Apply the pass-manager options registered on the command line in main().
    if (mlir::failed(mlir::applyPassManagerCLOptions(manager))) {
      return 1;
    }

    // Inline function calls (intended to leave only the 'main' function).
    manager.addPass(mlir::createInlinerPass());

    // Nested, per-function pipeline. The canonicalizer applies the ops'
    // registered canonicalization patterns — for this dialect presumably the
    // transpose/reshape simplifications.
    mlir::OpPassManager& functionPassManager = manager.nest();
    functionPassManager.addPass(mlir::createCanonicalizerPass());
    functionPassManager.addPass(mlir::createCSEPass());
    functionPassManager.addPass(mlir::hello::createShapeInferencePass());

    if (mlir::failed(manager.run(*module))) {
      return 1;
    }
  }

  module->print(llvm::outs());
  return 0;
}

/// Parses the input as hello source and dumps its syntax tree.
/// MLIR input (by -x=mlir or a ".mlir" file name) has no hello syntax tree
/// and is rejected with an error.
/// @return 0 on success, 1 on failure.
int dumpSyntaxNode() {
  if (isMLIRInput()) {
    llvm::errs()
        << "Failed to dump hello syntax node when input type is MLIR.\n";
    return 1;
  }

  auto syntaxNode = parseInputFile(inputFilename);
  if (syntaxNode == nullptr) {
    return 1;
  }
  syntaxNode->dump();
  return 0;
}

int main(int argc, char** argv) {
  // MLIR's command-line options must be registered before parsing so that
  // flags consumed later (e.g. by applyPassManagerCLOptions) are recognized.
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");

  switch (emitAction) {
  case DumpSyntaxNode:
    return dumpSyntaxNode();
  case DumpMLIR:
    return dumpMLIR();
  default:
    // NOTE(review): emitAction has no cl::init, so this is presumably also
    // reached when -emit is omitted.
    llvm::errs() << "Unrecognized action\n";
    return 1;
  }
}