feat: toy tutorial chapter 3.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
parent
8d2f844e2b
commit
eacf20fe3c
|
@ -30,16 +30,22 @@ include_directories(include)
|
|||
include_directories(${CMAKE_BINARY_DIR}/include)
|
||||
add_subdirectory(include)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS lib/HelloCombine.td)
|
||||
mlir_tablegen(HelloCombine.inc -gen-rewriters)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
add_public_tablegen_target(HelloCombineIncGen)
|
||||
|
||||
add_library(SyntaxNode STATIC
|
||||
lib/SyntaxNode.cpp
|
||||
lib/Dialect.cpp
|
||||
lib/MLIRGen.cpp
|
||||
lib/HelloCombine.cpp
|
||||
include/SyntaxNode.h
|
||||
include/Parser.h
|
||||
include/Lexer.h
|
||||
)
|
||||
|
||||
add_dependencies(SyntaxNode HelloOpsIncGen)
|
||||
add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen)
|
||||
|
||||
target_link_libraries(SyntaxNode
|
||||
PRIVATE
|
||||
|
|
6
examples/reshape_reshape.hello
Normal file
6
examples/reshape_reshape.hello
Normal file
|
@ -0,0 +1,6 @@
|
|||
def main() {
|
||||
var a<2,1> = [1, 2];
|
||||
var b<2,1> = a;
|
||||
var c<2,1> = b;
|
||||
print(c);
|
||||
}
|
3
examples/transpose_transpose.hello
Normal file
3
examples/transpose_transpose.hello
Normal file
|
@ -0,0 +1,3 @@
|
|||
def transpose_transpose(x) {
|
||||
return transpose(transpose(x));
|
||||
}
|
|
@ -246,6 +246,8 @@ def ReshapeOp : Hello_Op<"reshape"> {
|
|||
let assemblyFormat = [{
|
||||
`(` $input `:` type($input) `)` attr-dict `to` type(results)
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -294,7 +296,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
|
|||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TransposeOp : Hello_Op<"transpose"> {
|
||||
def TransposeOp : Hello_Op<"transpose", [Pure]> {
|
||||
let summary = "transpose operation";
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
|
@ -311,6 +313,7 @@ def TransposeOp : Hello_Op<"transpose"> {
|
|||
|
||||
// Invoke a static verify method to verify this transpose operation.
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
41
lib/HelloCombine.cpp
Normal file
41
lib/HelloCombine.cpp
Normal file
|
@ -0,0 +1,41 @@
|
|||
//
|
||||
// Created by ricardo on 02/06/25.
|
||||
//
|
||||
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include "Dialect.h"
|
||||
#include "HelloCombine.inc"
|
||||
|
||||
|
||||
struct SimplifyRedundantTranspose final : mlir::OpRewritePattern<mlir::hello::TransposeOp>
|
||||
{
|
||||
explicit SimplifyRedundantTranspose(mlir::MLIRContext* context) : OpRewritePattern(
|
||||
context)
|
||||
{
|
||||
}
|
||||
|
||||
/// Transpose(Transpose(x)) = x
|
||||
mlir::LogicalResult matchAndRewrite(mlir::hello::TransposeOp op, mlir::PatternRewriter& rewriter) const override
|
||||
{
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
auto transposeInputOp = transposeInput.getDefiningOp<mlir::hello::TransposeOp>();
|
||||
|
||||
if (!transposeInputOp)
|
||||
{
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void mlir::hello::TransposeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
|
||||
{
|
||||
set.add<SimplifyRedundantTranspose>(context);
|
||||
}
|
||||
|
||||
void mlir::hello::ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
|
||||
{
|
||||
set.add<ReshapeReshapeOptPattern, RedundantShapeOptPattern, FoldConstantReshapeOptPattern>(context);
|
||||
}
|
23
lib/HelloCombine.td
Normal file
23
lib/HelloCombine.td
Normal file
|
@ -0,0 +1,23 @@
|
|||
#ifndef HELLO_COMBINE
|
||||
#define HELLO_COMBINE
|
||||
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "hello/Ops.td"
|
||||
|
||||
// Reshape(Reshape(x)) = Reshape(x)
|
||||
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), (ReshapeOp $arg)>;
|
||||
|
||||
// Reshape(Consant(x)) = x'
|
||||
|
||||
def ReshapeConstant : NativeCodeCall<"$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">;
|
||||
|
||||
def FoldConstantReshapeOptPattern : Pat<(ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
||||
// Reshape(x) =x , where input and output shapes are the same.
|
||||
def TypesAreSame : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def RedundantShapeOptPattern : Pat<
|
||||
(ReshapeOp: $res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreSame $res, $arg)]>;
|
||||
|
||||
#endif
|
102
main.cpp
102
main.cpp
|
@ -5,10 +5,10 @@
|
|||
#include <llvm/Support/ErrorOr.h>
|
||||
#include <llvm/Support/MemoryBuffer.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
#include <mlir/IR/BuiltinOps.h.inc>
|
||||
#include <mlir/IR/BuiltinOps.h.inc>
|
||||
#include <mlir/IR/OwningOpRef.h>
|
||||
#include <mlir/Parser/Parser.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include "Dialect.h"
|
||||
#include "MLIRGen.h"
|
||||
|
@ -30,6 +30,7 @@ namespace
|
|||
enum InputType { Hello, MLIR };
|
||||
}
|
||||
|
||||
/// The input file type
|
||||
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
|
||||
llvm::cl::desc("Decided the kind of input desired."),
|
||||
llvm::cl::values(
|
||||
|
@ -37,10 +38,13 @@ static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
|
|||
llvm::cl::values(
|
||||
clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
|
||||
|
||||
/// What is the action the compiler will do
|
||||
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
|
||||
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
|
||||
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
|
||||
|
||||
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
|
||||
|
||||
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
||||
{
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
|
@ -56,27 +60,18 @@ std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
|||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int dumpMLIR()
|
||||
int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
|
||||
{
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
||||
|
||||
if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
|
||||
{
|
||||
auto module = parseInputFile(inputFilename);
|
||||
if (module == nullptr)
|
||||
auto syntaxNode = parseInputFile(inputFilename);
|
||||
if (syntaxNode == nullptr)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
mlir::OwningOpRef<mlir::ModuleOp> mlirModule = hello::mlirGen(context, *module);
|
||||
module = hello::mlirGen(context, *syntaxNode);
|
||||
|
||||
if (!mlirModule)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
mlirModule->dump();
|
||||
return 0;
|
||||
return module ? 0 : 1;
|
||||
}
|
||||
|
||||
// Then the input file is mlir
|
||||
|
@ -91,7 +86,7 @@ int dumpMLIR()
|
|||
// Parse the input mlir.
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
mlir::OwningOpRef<mlir::ModuleOp> module =
|
||||
module =
|
||||
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
|
||||
if (!module)
|
||||
{
|
||||
|
@ -99,29 +94,78 @@ int dumpMLIR()
|
|||
return 1;
|
||||
}
|
||||
|
||||
module->dump();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpMLIR()
|
||||
{
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
||||
|
||||
mlir::OwningOpRef<mlir::ModuleOp> module;
|
||||
llvm::SourceMgr sourceManager;
|
||||
|
||||
if (int error = loadMLIR(sourceManager, context, module))
|
||||
{
|
||||
return error;
|
||||
}
|
||||
|
||||
if (enableOpt)
|
||||
{
|
||||
mlir::PassManager manager(module.get()->getName());
|
||||
|
||||
if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
||||
if (mlir::failed(manager.run(*module)))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
module->print(llvm::outs());
|
||||
return 0;
|
||||
}
|
||||
|
||||
module->print(llvm::outs());
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpSyntaxNode()
|
||||
{
|
||||
if (inputType == MLIR)
|
||||
{
|
||||
llvm::errs() << "Failed to dump hello syntax node when input type is MLIR.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto syntaxNode = parseInputFile(inputFilename);
|
||||
if (syntaxNode == nullptr)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
syntaxNode->dump();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
mlir::registerAsmPrinterCLOptions();
|
||||
mlir::registerMLIRContextCLOptions();
|
||||
mlir::registerPassManagerCLOptions();
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
||||
|
||||
auto module = parseInputFile(inputFilename);
|
||||
|
||||
if (!module)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch (emitAction)
|
||||
{
|
||||
case DumpSyntaxNode:
|
||||
module->dump();
|
||||
return 0;
|
||||
return dumpSyntaxNode();
|
||||
case DumpMLIR:
|
||||
dumpMLIR();
|
||||
return 0;
|
||||
return dumpMLIR();
|
||||
default:
|
||||
llvm::errs() << "Unrecognized action\n";
|
||||
return 1;
|
||||
|
|
Loading…
Reference in New Issue
Block a user