diff --git a/CMakeLists.txt b/CMakeLists.txt index 18c7ae6..54e1361 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,16 +30,22 @@ include_directories(include) include_directories(${CMAKE_BINARY_DIR}/include) add_subdirectory(include) +set(LLVM_TARGET_DEFINITIONS lib/HelloCombine.td) +mlir_tablegen(HelloCombine.inc -gen-rewriters) +include_directories(${CMAKE_BINARY_DIR}) +add_public_tablegen_target(HelloCombineIncGen) + add_library(SyntaxNode STATIC lib/SyntaxNode.cpp lib/Dialect.cpp lib/MLIRGen.cpp + lib/HelloCombine.cpp include/SyntaxNode.h include/Parser.h include/Lexer.h ) -add_dependencies(SyntaxNode HelloOpsIncGen) +add_dependencies(SyntaxNode HelloOpsIncGen HelloCombineIncGen) target_link_libraries(SyntaxNode PRIVATE diff --git a/examples/reshape_reshape.hello b/examples/reshape_reshape.hello new file mode 100644 index 0000000..1d4ca8a --- /dev/null +++ b/examples/reshape_reshape.hello @@ -0,0 +1,6 @@ +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} \ No newline at end of file diff --git a/examples/transpose_transpose.hello b/examples/transpose_transpose.hello new file mode 100644 index 0000000..3986630 --- /dev/null +++ b/examples/transpose_transpose.hello @@ -0,0 +1,3 @@ +def transpose_transpose(x) { + return transpose(transpose(x)); +} \ No newline at end of file diff --git a/include/hello/Ops.td b/include/hello/Ops.td index 87078d3..25b38cc 100644 --- a/include/hello/Ops.td +++ b/include/hello/Ops.td @@ -246,6 +246,8 @@ def ReshapeOp : Hello_Op<"reshape"> { let assemblyFormat = [{ `(` $input `:` type($input) `)` attr-dict `to` type(results) }]; + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -294,7 +296,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">, // TransposeOp //===----------------------------------------------------------------------===// -def TransposeOp : Hello_Op<"transpose"> { +def TransposeOp : Hello_Op<"transpose", [Pure]> { let summary = "transpose operation"; let arguments = (ins F64Tensor:$input); @@ -311,6 +313,7 @@ def TransposeOp : Hello_Op<"transpose"> { // Invoke a static verify method to verify this transpose operation. let hasVerifier = 1; + let hasCanonicalizer = 1; } #endif diff --git a/lib/HelloCombine.cpp b/lib/HelloCombine.cpp new file mode 100644 index 0000000..e90c517 --- /dev/null +++ b/lib/HelloCombine.cpp @@ -0,0 +1,41 @@ +// +// Created by ricardo on 02/06/25. +// + +#include +#include "Dialect.h" +#include "HelloCombine.inc" + + +struct SimplifyRedundantTranspose final : mlir::OpRewritePattern +{ + explicit SimplifyRedundantTranspose(mlir::MLIRContext* context) : OpRewritePattern( + context) + { + } + + /// Transpose(Transpose(x)) = x + mlir::LogicalResult matchAndRewrite(mlir::hello::TransposeOp op, mlir::PatternRewriter& rewriter) const override + { + mlir::Value transposeInput = op.getOperand(); + auto transposeInputOp = transposeInput.getDefiningOp(); + + if (!transposeInputOp) + { + return mlir::failure(); + } + + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return mlir::success(); + } +}; + +void mlir::hello::TransposeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context) +{ + set.add(context); +} + +void mlir::hello::ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context) +{ + set.add(context); +} diff --git a/lib/HelloCombine.td b/lib/HelloCombine.td new file mode 100644 index 0000000..a03de9d --- /dev/null +++ b/lib/HelloCombine.td @@ -0,0 +1,23 @@ +#ifndef HELLO_COMBINE +#define HELLO_COMBINE + +include "mlir/IR/PatternBase.td" +include "hello/Ops.td" + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), (ReshapeOp $arg)>; + +// Reshape(Consant(x)) = x' + +def ReshapeConstant : NativeCodeCall<"$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">; + +def FoldConstantReshapeOptPattern : Pat<(ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; + +// Reshape(x) =x , where input and output shapes are the same. +def TypesAreSame : Constraint>; + +def RedundantShapeOptPattern : Pat< + (ReshapeOp: $res $arg), (replaceWithValue $arg), + [(TypesAreSame $res, $arg)]>; + +#endif diff --git a/main.cpp b/main.cpp index 5bd6edc..8d51c76 100644 --- a/main.cpp +++ b/main.cpp @@ -5,10 +5,10 @@ #include #include #include -#include -#include #include #include +#include +#include #include "Dialect.h" #include "MLIRGen.h" @@ -30,6 +30,7 @@ namespace enum InputType { Hello, MLIR }; } +/// The input file type static llvm::cl::opt inputType("x", llvm::cl::init(Hello), llvm::cl::desc("Decided the kind of input desired."), llvm::cl::values( @@ -37,10 +38,13 @@ static llvm::cl::opt inputType("x", llvm::cl::init(Hello), llvm::cl::values( clEnumValN(MLIR, "mlir", "load the input file as a mlir source."))); +/// What is the action the compiler will do static llvm::cl::opt emitAction("emit", llvm::cl::desc("Select the kind of output desired"), llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")), llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code"))); +static llvm::cl::opt enableOpt("opt", llvm::cl::desc("Enable optimizations")); + std::unique_ptr parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = @@ -56,27 +60,18 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) return parser.parseModule(); } -int dumpMLIR() +int loadMLIR(llvm::SourceMgr& sourceManager, mlir::MLIRContext& context, mlir::OwningOpRef& module) { - mlir::MLIRContext context; - context.getOrLoadDialect(); - if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { - auto module = parseInputFile(inputFilename); - if (module == nullptr) + auto syntaxNode = parseInputFile(inputFilename); + if (syntaxNode == nullptr) { return 1; } - mlir::OwningOpRef mlirModule = hello::mlirGen(context, *module); + module = hello::mlirGen(context, *syntaxNode); - if (!mlirModule) - { - return 1; - } - - mlirModule->dump(); - return 0; + return module ? 0 : 1; } // Then the input file is mlir @@ -91,7 +86,7 @@ int dumpMLIR() // Parse the input mlir. llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - mlir::OwningOpRef module = + module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { @@ -99,29 +94,78 @@ int dumpMLIR() return 1; } - module->dump(); + return 0; +} + +int dumpMLIR() +{ + mlir::MLIRContext context; + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + llvm::SourceMgr sourceManager; + + if (int error = loadMLIR(sourceManager, context, module)) + { + return error; + } + + if (enableOpt) + { + mlir::PassManager manager(module.get()->getName()); + + if (mlir::failed(mlir::applyPassManagerCLOptions(manager))) + { + return 1; + } + + manager.addNestedPass(mlir::createCanonicalizerPass()); + + if (mlir::failed(manager.run(*module))) + { + return 1; + } + + module->print(llvm::outs()); + return 0; + } + + module->print(llvm::outs()); + return 0; +} + +int dumpSyntaxNode() +{ + if (inputType == MLIR) + { + llvm::errs() << "Failed to dump hello syntax node when input type is MLIR."; + return 1; + } + + auto syntaxNode = parseInputFile(inputFilename); + if (syntaxNode == nullptr) + { + return 1; + } + + syntaxNode->dump(); return 0; } int main(int argc, char** argv) { + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n"); - auto module = parseInputFile(inputFilename); - - if (!module) - { - return 1; - } - switch (emitAction) { case DumpSyntaxNode: - module->dump(); - return 0; + return dumpSyntaxNode(); case DumpMLIR: - dumpMLIR(); - return 0; + return dumpMLIR(); default: llvm::errs() << "Unrecognized action\n"; return 1;