feat: toy tutorial chapter 3.

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
jackfiled 2025-06-02 17:22:53 +08:00
parent 8d2f844e2b
commit eacf20fe3c
Signed by: jackfiled
GPG Key ID: 5F7234760472A46A
7 changed files with 157 additions and 31 deletions

View File

@ -30,16 +30,22 @@ include_directories(include)
include_directories(${CMAKE_BINARY_DIR}/include)
add_subdirectory(include)
set(LLVM_TARGET_DEFINITIONS lib/HelloCombine.td)
mlir_tablegen(HelloCombine.inc -gen-rewriters)
include_directories(${CMAKE_BINARY_DIR})
add_public_tablegen_target(HelloCombineIncGen)
add_library(SyntaxNode STATIC
lib/SyntaxNode.cpp
lib/Dialect.cpp
lib/MLIRGen.cpp
lib/HelloCombine.cpp
include/SyntaxNode.h
include/Parser.h
include/Lexer.h
)
add_dependencies(SyntaxNode HelloOpsIncGen)
add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen)
target_link_libraries(SyntaxNode
PRIVATE

View File

@ -0,0 +1,6 @@
def main() {
var a<2,1> = [1, 2];
var b<2,1> = a;
var c<2,1> = b;
print(c);
}

View File

@ -0,0 +1,3 @@
def transpose_transpose(x) {
return transpose(transpose(x));
}

View File

@ -246,6 +246,8 @@ def ReshapeOp : Hello_Op<"reshape"> {
let assemblyFormat = [{
`(` $input `:` type($input) `)` attr-dict `to` type(results)
}];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@ -294,7 +296,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
// TransposeOp
//===----------------------------------------------------------------------===//
def TransposeOp : Hello_Op<"transpose"> {
def TransposeOp : Hello_Op<"transpose", [Pure]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
@ -311,6 +313,7 @@ def TransposeOp : Hello_Op<"transpose"> {
// Invoke a static verify method to verify this transpose operation.
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
#endif

41
lib/HelloCombine.cpp Normal file
View File

@ -0,0 +1,41 @@
//
// Created by ricardo on 02/06/25.
//
#include <mlir/IR/PatternMatch.h>
#include "Dialect.h"
#include "HelloCombine.inc"
struct SimplifyRedundantTranspose final : mlir::OpRewritePattern<mlir::hello::TransposeOp>
{
explicit SimplifyRedundantTranspose(mlir::MLIRContext* context) : OpRewritePattern(
context)
{
}
/// Transpose(Transpose(x)) = x
mlir::LogicalResult matchAndRewrite(mlir::hello::TransposeOp op, mlir::PatternRewriter& rewriter) const override
{
mlir::Value transposeInput = op.getOperand();
auto transposeInputOp = transposeInput.getDefiningOp<mlir::hello::TransposeOp>();
if (!transposeInputOp)
{
return mlir::failure();
}
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return mlir::success();
}
};
void mlir::hello::TransposeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
{
set.add<SimplifyRedundantTranspose>(context);
}
void mlir::hello::ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
{
set.add<ReshapeReshapeOptPattern, RedundantShapeOptPattern, FoldConstantReshapeOptPattern>(context);
}

23
lib/HelloCombine.td Normal file
View File

@ -0,0 +1,23 @@
#ifndef HELLO_COMBINE
#define HELLO_COMBINE
include "mlir/IR/PatternBase.td"
include "hello/Ops.td"
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), (ReshapeOp $arg)>;
// Reshape(Consant(x)) = x'
def ReshapeConstant : NativeCodeCall<"$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">;
def FoldConstantReshapeOptPattern : Pat<(ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>;
// Reshape(x) =x , where input and output shapes are the same.
def TypesAreSame : Constraint<CPred<"$0.getType() == $1.getType()">>;
def RedundantShapeOptPattern : Pat<
(ReshapeOp: $res $arg), (replaceWithValue $arg),
[(TypesAreSame $res, $arg)]>;
#endif

102
main.cpp
View File

@ -5,10 +5,10 @@
#include <llvm/Support/ErrorOr.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h.inc>
#include <mlir/IR/BuiltinOps.h.inc>
#include <mlir/IR/OwningOpRef.h>
#include <mlir/Parser/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include "Dialect.h"
#include "MLIRGen.h"
@ -30,6 +30,7 @@ namespace
enum InputType { Hello, MLIR };
}
/// The input file type
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
llvm::cl::desc("Decided the kind of input desired."),
llvm::cl::values(
@ -37,10 +38,13 @@ static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
llvm::cl::values(
clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
/// What is the action the compiler will do
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
static llvm::cl::opt<bool> enableOpt("opt", llvm::cl::desc("Enable optimizations"));
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
{
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
@ -56,27 +60,18 @@ std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
return parser.parseModule();
}
int dumpMLIR()
int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef<mlir::ModuleOp>& module)
{
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::hello::HelloDialect>();
if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
{
auto module = parseInputFile(inputFilename);
if (module == nullptr)
auto syntaxNode = parseInputFile(inputFilename);
if (syntaxNode == nullptr)
{
return 1;
}
mlir::OwningOpRef<mlir::ModuleOp> mlirModule = hello::mlirGen(context, *module);
module = hello::mlirGen(context, *syntaxNode);
if (!mlirModule)
{
return 1;
}
mlirModule->dump();
return 0;
return module ? 0 : 1;
}
// Then the input file is mlir
@ -91,7 +86,7 @@ int dumpMLIR()
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
mlir::OwningOpRef<mlir::ModuleOp> module =
module =
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
if (!module)
{
@ -99,29 +94,78 @@ int dumpMLIR()
return 1;
}
module->dump();
return 0;
}
int dumpMLIR()
{
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::hello::HelloDialect>();
mlir::OwningOpRef<mlir::ModuleOp> module;
llvm::SourceMgr sourceManager;
if (int error = loadMLIR(sourceManager, context, module))
{
return error;
}
if (enableOpt)
{
mlir::PassManager manager(module.get()->getName());
if (mlir::failed(mlir::applyPassManagerCLOptions(manager)))
{
return 1;
}
manager.addNestedPass<mlir::hello::FuncOp>(mlir::createCanonicalizerPass());
if (mlir::failed(manager.run(*module)))
{
return 1;
}
module->print(llvm::outs());
return 0;
}
module->print(llvm::outs());
return 0;
}
int dumpSyntaxNode()
{
if (inputType == MLIR)
{
llvm::errs() << "Failed to dump hello syntax node when input type is MLIR.";
return 1;
}
auto syntaxNode = parseInputFile(inputFilename);
if (syntaxNode == nullptr)
{
return 1;
}
syntaxNode->dump();
return 0;
}
int main(int argc, char** argv)
{
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
auto module = parseInputFile(inputFilename);
if (!module)
{
return 1;
}
switch (emitAction)
{
case DumpSyntaxNode:
module->dump();
return 0;
return dumpSyntaxNode();
case DumpMLIR:
dumpMLIR();
return 0;
return dumpMLIR();
default:
llvm::errs() << "Unrecognized action\n";
return 1;