hello-mlir/main.cpp
jackfiled 902915a57b
feat: toy tutorial chapter 4.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
2025-06-03 16:03:17 +08:00

181 lines
5.3 KiB
C++

#include "Lexer.h"
#include "Parser.h"
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ErrorOr.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/OwningOpRef.h>
#include <mlir/Parser/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include "Dialect.h"
#include "MLIRGen.h"
#include "Passes.h"
// Forward declaration of the MLIR module operation type referenced by
// loadMLIR() and dumpMLIR() below.
namespace mlir { class ModuleOp; } // namespace mlir
/// Path of the file to compile. Defaults to "-", which
/// llvm::MemoryBuffer::getFileOrSTDIN interprets as standard input.
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                llvm::cl::desc("<input hello file>"),
                                                llvm::cl::init("-"),
                                                llvm::cl::value_desc("filename"));

namespace
{
    /// What the compiler should emit (selected with `-emit`).
    enum Action { None, DumpSyntaxNode, DumpMLIR };

    /// How the input file should be interpreted (selected with `-x`).
    enum InputType { Hello, MLIR };
}

/// The input file type: `-x=hello` (default) or `-x=mlir`.
/// All accepted values are listed in a single cl::values() modifier.
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
                                          llvm::cl::desc("Select the kind of input desired."),
                                          llvm::cl::values(
                                              clEnumValN(Hello, "hello", "load the input file as a hello source."),
                                              clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));

/// What the compiler will emit: `-emit=ast` or `-emit=mlir`.
/// The default is `None`, which main() reports as an unrecognized action.
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
                                        llvm::cl::init(None),
                                        llvm::cl::values(
                                            clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node"),
                                            clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));

/// When set, dumpMLIR() runs the optimization pass pipeline before printing.
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
/// Read `filename` (or standard input when it is "-") and parse it as a
/// hello module.
///
/// @param filename Path of the hello source file.
/// @return The module produced by hello::Parser::parseModule(), or nullptr if
///         the file could not be opened. (The parser's result is returned
///         unchanged; presumably it is also null on a parse error — confirm
///         against Parser.h.)
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
{
    auto bufferOrError = llvm::MemoryBuffer::getFileOrSTDIN(filename);
    if (!bufferOrError)
    {
        llvm::errs() << "Could not open input file: " << bufferOrError.getError().message() << "\n";
        return nullptr;
    }

    // The MemoryBuffer stays owned by `bufferOrError` until we return, so the
    // character range handed to the lexer remains valid while parsing.
    const llvm::StringRef contents = (*bufferOrError)->getBuffer();
    hello::LexerBuffer lexer(contents.begin(), contents.end(), filename.str());
    hello::Parser parser(lexer);
    return parser.parseModule();
}
/// Load the input file into `module`.
///
/// The input is treated as MLIR when `-x=mlir` is given or the file name ends
/// with ".mlir". MLIR input is parsed directly. Any other input is parsed as
/// hello source and lowered with hello::mlirGen.
///
/// @param sourceManager Caller-owned source manager. For MLIR input it takes
///        ownership of the file buffer, so the buffer outlives this call and is
///        visible to any diagnostic handler the caller attached to it.
/// @param context       Context the module is created in.
/// @param module        Output: the loaded module on success.
/// @return 0 on success, 1 on failure.
int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
{
    if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
    {
        auto syntaxNode = parseInputFile(inputFilename);
        if (syntaxNode == nullptr)
        {
            return 1;
        }
        module = hello::mlirGen(context, *syntaxNode);
        return module ? 0 : 1;
    }

    // Otherwise the input is MLIR.
    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
    if (std::error_code ec = fileOrErr.getError())
    {
        llvm::errs() << "Could not open input file: " << ec.message() << "\n";
        return 1;
    }

    // Parse with the caller's source manager rather than a function-local one.
    // Otherwise the parameter would go unused, and the buffer would be freed
    // when this function returns.
    sourceManager.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
    module = mlir::parseSourceFile<mlir::ModuleOp>(sourceManager, &context);
    if (!module)
    {
        llvm::errs() << "Error can't load file " << inputFilename << "\n";
        return 1;
    }
    return 0;
}
/// Load the input (hello source or MLIR), optionally run the optimization
/// pipeline when `-opt` is given, and print the resulting module to stdout.
///
/// @return 0 on success. Otherwise the non-zero code from loadMLIR(), or 1 if
///         pass-manager configuration or the pass pipeline fails.
int dumpMLIR()
{
    mlir::MLIRContext context;
    context.getOrLoadDialect<mlir::hello::HelloDialect>();

    mlir::OwningOpRef<mlir::ModuleOp> module;
    llvm::SourceMgr sourceManager;
    // Route diagnostics through the source manager so errors are reported with
    // the offending source line. It is registered before loading so that
    // diagnostics emitted while parsing/lowering the input go through it too.
    mlir::SourceMgrDiagnosticHandler diagnosticHandler(sourceManager, &context);
    if (int error = loadMLIR(sourceManager, context, module))
    {
        return error;
    }

    if (enableOpt)
    {
        mlir::PassManager manager(module.get()->getName());
        if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
        {
            return 1;
        }
        // Inline function calls (the intent is to inline every function
        // except 'main').
        manager.addPass(mlir::createInlinerPass());
        // Per-function cleanup on hello::FuncOp. Canonicalization applies the
        // dialect's canonicalization patterns (per the original author, these
        // include the transpose and reshape simplifications), followed by CSE
        // and shape inference.
        mlir::OpPassManager& functionPassManager = manager.nest<mlir::hello::FuncOp>();
        functionPassManager.addPass(mlir::createCanonicalizerPass());
        functionPassManager.addPass(mlir::createCSEPass());
        functionPassManager.addPass(mlir::hello::createShapeInferencePass());
        if (mlir::failed(manager.run(*module)))
        {
            return 1;
        }
    }

    // Single print path for both the optimized and the unoptimized module.
    module->print(llvm::outs());
    return 0;
}
/// Parse the input as hello source and dump its syntax tree.
///
/// The input is considered MLIR when `-x=mlir` is given or the file name ends
/// with ".mlir". This is the same rule loadMLIR() uses to pick the MLIR parser.
/// Such input cannot be dumped as a hello syntax tree.
///
/// @return 0 on success; 1 if the input is MLIR or parsing fails.
int dumpSyntaxNode()
{
    if (inputType == MLIR || llvm::StringRef(inputFilename).ends_with(".mlir"))
    {
        llvm::errs() << "Failed to dump hello syntax node when input type is MLIR.\n";
        return 1;
    }
    auto syntaxNode = parseInputFile(inputFilename);
    if (syntaxNode == nullptr)
    {
        return 1;
    }
    syntaxNode->dump();
    return 0;
}
int main(int argc, char** argv)
{
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
switch (emitAction)
{
case DumpSyntaxNode:
return dumpSyntaxNode();
case DumpMLIR:
return dumpMLIR();
default:
llvm::errs() << "Unrecognized action\n";
return 1;
}
}