311 lines
10 KiB
C++
311 lines
10 KiB
C++
#include <llvm/IR/LLVMContext.h>
|
|
|
|
#include "Lexer.h"
|
|
#include "Parser.h"
|
|
|
|
#include <llvm/Support/CommandLine.h>
|
|
#include <llvm/Support/ErrorOr.h>
|
|
#include <llvm/Support/MemoryBuffer.h>
|
|
#include <llvm/Support/SourceMgr.h>
|
|
#include <mlir/IR/OwningOpRef.h>
|
|
#include <mlir/Parser/Parser.h>
|
|
#include <mlir/Pass/PassManager.h>
|
|
#include <mlir/Transforms/Passes.h>
|
|
#include <mlir/Dialect/Func/Extensions/AllExtensions.h>
|
|
#include <mlir/Dialect/Affine/Passes.h>
|
|
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
|
#include <mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h>
|
|
#include <mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h>
|
|
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
|
#include <mlir/Target/LLVMIR/Export.h>
|
|
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
|
#include <mlir/ExecutionEngine/OptUtils.h>
|
|
#include <llvm/IR/Module.h>
|
|
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
|
|
#include <llvm/Support/TargetSelect.h>
|
|
|
|
#include "Dialect.h"
|
|
#include "MLIRGen.h"
|
|
#include "Passes.h"
|
|
|
|
// Forward declaration of mlir::ModuleOp. Probably redundant: the MLIR headers
// included above must already provide the full definition, since ModuleOp is
// dereferenced and printed later in this file.
namespace mlir { class ModuleOp; } // namespace mlir
|
|
|
|
/// Positional command-line argument naming the file to compile. It defaults
/// to "-", which the loaders pass to llvm::MemoryBuffer::getFileOrSTDIN
/// (i.e. read standard input).
static llvm::cl::opt<std::string> inputFilename{
    llvm::cl::Positional,
    llvm::cl::desc("<input hello file>"),
    llvm::cl::init("-"),
    llvm::cl::value_desc("filename")};
|
|
|
|
namespace
{
/// Output requested through `-emit`.
///
/// The order matters: the driver compares these values with `>=` / `<=` to
/// decide how far to lower the input, so new entries must keep the
/// "less lowered first" ordering.
enum Action
{
    None,           // Value when -emit is not given (the option has no explicit init).
    DumpSyntaxNode, // -emit=ast
    DumpMLIR,       // -emit=mlir
    DumpAffineMLIR, // -emit=affine-mlir
    DumpLLVMMLIR,   // -emit=llvm-mlir
    DumpLLVM,       // -emit=llvm
    RunJit          // -emit=jit
};

/// Input language requested through `-x`.
enum InputType
{
    Hello, // -x=hello (the default)
    MLIR   // -x=mlir
};
} // namespace
|
|
|
|
/// The input file type.
|
|
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
|
|
llvm::cl::desc("Decided the kind of input desired."),
|
|
llvm::cl::values(
|
|
clEnumValN(Hello, "hello", "load the input file as a hello source.")),
|
|
llvm::cl::values(
|
|
clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
|
|
|
|
/// What is the action the compiler will do.
|
|
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
|
|
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
|
|
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")),
|
|
llvm::cl::values(clEnumValN(DumpAffineMLIR, "affine-mlir",
|
|
"Dump mlir code after lowering to affine loops")),
|
|
llvm::cl::values(clEnumValN(DumpLLVMMLIR, "llvm-mlir",
|
|
"Dump mlir code after lowering to llvm.")),
|
|
llvm::cl::values(clEnumValN(DumpLLVM, "llvm",
|
|
"Dump llvm code.")),
|
|
llvm::cl::values(clEnumValN(RunJit, "jit", "Run the input by jitter.")));
|
|
|
|
/// Whether to enable the optimization.
|
|
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
|
|
|
|
/// Reads `filename` (via llvm::MemoryBuffer::getFileOrSTDIN, so "-" means
/// standard input) and parses its contents as a hello module.
///
/// @param filename Path of the hello source file.
/// @return The parsed module. Returns nullptr, after printing a diagnostic,
///         when the file cannot be opened; otherwise returns whatever the
///         hello parser produced (callers treat nullptr as failure).
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
{
    auto sourceOrError = llvm::MemoryBuffer::getFileOrSTDIN(filename);
    if (!sourceOrError)
    {
        llvm::errs() << "Could not open input file: " << sourceOrError.getError().message() << "\n";
        return nullptr;
    }

    const llvm::StringRef text = (*sourceOrError)->getBuffer();
    hello::LexerBuffer lexer(text.begin(), text.end(), filename.str());
    hello::Parser parser(lexer);
    return parser.parseModule();
}
|
|
|
|
/// Loads the input into `module` and runs the pass pipeline selected by the
/// command-line options.
///
/// Input: parsed as MLIR when `-x=mlir` is given or the file name ends in
/// ".mlir"; otherwise parsed as hello source and converted with hello::mlirGen.
///
/// Pipeline:
///   - with `-opt` or when lowering to affine: inline calls, then run
///     canonicalize, CSE and shape inference on every hello function;
///   - for `-emit=affine-mlir` and beyond: lower hello to affine loops, plus
///     affine loop fusion followed by scalar replacement when `-opt` is set;
///   - for `-emit=llvm-mlir` and beyond: lower to the LLVM dialect.
///
/// @param context MLIR context used for parsing / IR generation.
/// @param module  Receives the loaded (and transformed) module.
/// @return 0 on success, 1 on any failure.
int loadAndProcessMLIR(mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
{
    if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
    {
        // The input file is hello source.
        auto syntaxNode = parseInputFile(inputFilename);
        if (syntaxNode == nullptr)
        {
            return 1;
        }

        module = hello::mlirGen(context, *syntaxNode);
        if (!module)
        {
            llvm::errs() << "Failed to convert hello syntax tree to MLIR.\n";
            return 1;
        }
    }
    else
    {
        // The input file is MLIR.
        llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
            llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
        if (std::error_code ec = fileOrErr.getError())
        {
            llvm::errs() << "Could not open input file: " << ec.message() << "\n";
            return 1;
        }

        // Parse the input MLIR.
        llvm::SourceMgr sourceMgr;
        sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
        module = mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
        if (!module)
        {
            llvm::errs() << "Error can't load file " << inputFilename << "\n";
            return 1;
        }
    }

    const bool isLoweringToAffine = emitAction >= DumpAffineMLIR;
    const bool isLoweringToLLVM = emitAction >= DumpLLVMMLIR;
    mlir::PassManager manager(module.get()->getName());

    if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
    {
        return 1;
    }

    if (enableOpt || isLoweringToAffine)
    {
        // Inline calls; the intent is that every function other than 'main'
        // gets folded into its callers.
        manager.addPass(mlir::createInlinerPass());

        // Per-function cleanups. The canonicalizer applies the canonicalization
        // patterns registered for the hello ops (presumably including the
        // transpose/reshape simplifications); shape inference runs last.
        mlir::OpPassManager& helloFunctionPassManager = manager.nest<mlir::hello::FuncOp>();
        helloFunctionPassManager.addPass(mlir::createCanonicalizerPass());
        helloFunctionPassManager.addPass(mlir::createCSEPass());
        helloFunctionPassManager.addPass(mlir::hello::createShapeInferencePass());
    }

    if (isLoweringToAffine)
    {
        manager.addPass(mlir::hello::createLowerToAffineLoopsPass());

        // Affine optimizations. Both passes go into the same func.func pipeline,
        // fusion first: adding fusion to the module-level manager after the
        // nested pipeline (as before) made scalar replacement run before fusion.
        if (enableOpt)
        {
            mlir::OpPassManager& functionPassManager = manager.nest<mlir::func::FuncOp>();
            functionPassManager.addPass(mlir::affine::createLoopFusionPass());
            functionPassManager.addPass(mlir::affine::createAffineScalarReplacementPass());
        }
    }

    if (isLoweringToLLVM)
    {
        manager.addPass(mlir::hello::createLowerToLLVMPass());
    }

    if (mlir::failed(manager.run(*module)))
    {
        return 1;
    }

    return 0;
}
|
|
|
|
/// Translates `module` to LLVM IR, sets its target triple and data layout for
/// the host machine, runs the LLVM optimization pipeline (-O3 with `-opt`,
/// -O0 otherwise) and prints the result to stdout.
///
/// `main` calls this for `-emit=llvm`, after the module was lowered to the
/// LLVM dialect.
///
/// @return 0 on success, 1 on failure (a diagnostic is printed to stderr).
int dumpLLVMIR(mlir::ModuleOp module)
{
    // Register the builtin and LLVM dialect translations to LLVM IR.
    mlir::registerBuiltinDialectTranslation(*module->getContext());
    mlir::registerLLVMDialectTranslation(*module->getContext());

    // Convert the MLIR module to an LLVM IR module.
    llvm::LLVMContext context;
    const auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
    if (llvmModule == nullptr)
    {
        llvm::errs() << "Could not translate MLIR module to LLVM IR\n";
        return 1;
    }

    // A native target is needed to build a TargetMachine for the host.
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();

    // Failed llvm::Expected / llvm::Error values below are consumed through
    // llvm::toString(): destroying one unhandled aborts the process in LLVM
    // builds with ABI-breaking checks, and the message explains the failure.
    auto targetMachineBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
    if (!targetMachineBuilderOrError)
    {
        llvm::errs() << "Could not detect host machine: "
                     << llvm::toString(targetMachineBuilderOrError.takeError()) << "\n";
        return 1;
    }

    auto targetMachineOrError = targetMachineBuilderOrError->createTargetMachine();
    if (!targetMachineOrError)
    {
        llvm::errs() << "Could not create target machine: "
                     << llvm::toString(targetMachineOrError.takeError()) << "\n";
        return 1;
    }

    mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), targetMachineOrError.get().get());

    const auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr);
    if (llvm::Error err = optimizationPipeline(llvmModule.get()))
    {
        llvm::errs() << "Failed to run optimization pipeline for LLVM IR: "
                     << llvm::toString(std::move(err)) << "\n";
        return 1;
    }

    llvm::outs() << *llvmModule << "\n";
    return 0;
}
|
|
|
|
int dumpSyntaxNode()
|
|
{
|
|
if (inputType == MLIR)
|
|
{
|
|
llvm::errs() << "Failed to dump hello syntax node when input type is MLIR.";
|
|
return 1;
|
|
}
|
|
|
|
auto syntaxNode = parseInputFile(inputFilename);
|
|
if (syntaxNode == nullptr)
|
|
{
|
|
return 1;
|
|
}
|
|
|
|
syntaxNode->dump();
|
|
return 0;
|
|
}
|
|
|
|
/// JIT-compiles `module` with the MLIR ExecutionEngine and invokes its `main`
/// function. `main` calls this for `-emit=jit`, after lowering to the LLVM
/// dialect. LLVM IR is optimized at -O3 with `-opt`, -O0 otherwise.
///
/// @return 0 on success, 1 on failure (a diagnostic is printed to stderr).
int runJitter(mlir::ModuleOp module)
{
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();

    mlir::registerBuiltinDialectTranslation(*module->getContext());
    mlir::registerLLVMDialectTranslation(*module->getContext());

    // Keep the pipeline in a named local that outlives engine creation:
    // ExecutionEngineOptions::transformer is a non-owning callable reference.
    auto optimizationPipeline = mlir::makeOptimizingTransformer(enableOpt ? 3 : 0, 0, nullptr);

    mlir::ExecutionEngineOptions options;
    options.transformer = optimizationPipeline;

    // Failed llvm::Expected / llvm::Error values must be consumed (here via
    // llvm::toString); destroying one unhandled aborts in builds with
    // ABI-breaking checks.
    auto engineOrError = mlir::ExecutionEngine::create(module, options);
    if (!engineOrError)
    {
        llvm::errs() << "Failed to create execution engine: "
                     << llvm::toString(engineOrError.takeError()) << "\n";
        return 1;
    }

    const auto& engine = engineOrError.get();

    if (llvm::Error invocationError = engine->invokePacked("main"))
    {
        llvm::errs() << "Failed to run main function by jitter: "
                     << llvm::toString(std::move(invocationError)) << "\n";
        return 1;
    }

    return 0;
}
|
|
|
|
int main(int argc, char** argv)
|
|
{
|
|
mlir::registerAsmPrinterCLOptions();
|
|
mlir::registerMLIRContextCLOptions();
|
|
mlir::registerPassManagerCLOptions();
|
|
|
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
|
|
|
if (emitAction == DumpSyntaxNode)
|
|
{
|
|
return dumpSyntaxNode();
|
|
}
|
|
|
|
mlir::DialectRegistry registry;
|
|
mlir::func::registerAllExtensions(registry);
|
|
mlir::LLVM::registerInlinerInterface(registry);
|
|
|
|
mlir::MLIRContext context(registry);
|
|
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
|
|
|
mlir::OwningOpRef<mlir::ModuleOp> module;
|
|
if (int error = loadAndProcessMLIR(context, module))
|
|
{
|
|
return error;
|
|
}
|
|
|
|
if (emitAction <= DumpLLVMMLIR)
|
|
{
|
|
llvm::outs() << *module << "\n";
|
|
return 0;
|
|
}
|
|
|
|
if (emitAction == DumpLLVM)
|
|
{
|
|
return dumpLLVMIR(*module);
|
|
}
|
|
|
|
if (emitAction == RunJit)
|
|
{
|
|
return runJitter(*module);
|
|
}
|
|
|
|
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
|
|
return -1;
|
|
}
|