feat: toy tutorial chapter 3.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
41
lib/HelloCombine.cpp
Normal file
41
lib/HelloCombine.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
//
|
||||
// Created by ricardo on 02/06/25.
|
||||
//
|
||||
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include "Dialect.h"
|
||||
#include "HelloCombine.inc"
|
||||
|
||||
|
||||
struct SimplifyRedundantTranspose final : mlir::OpRewritePattern<mlir::hello::TransposeOp>
|
||||
{
|
||||
explicit SimplifyRedundantTranspose(mlir::MLIRContext* context) : OpRewritePattern(
|
||||
context)
|
||||
{
|
||||
}
|
||||
|
||||
/// Transpose(Transpose(x)) = x
|
||||
mlir::LogicalResult matchAndRewrite(mlir::hello::TransposeOp op, mlir::PatternRewriter& rewriter) const override
|
||||
{
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
auto transposeInputOp = transposeInput.getDefiningOp<mlir::hello::TransposeOp>();
|
||||
|
||||
if (!transposeInputOp)
|
||||
{
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void mlir::hello::TransposeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
|
||||
{
|
||||
set.add<SimplifyRedundantTranspose>(context);
|
||||
}
|
||||
|
||||
void mlir::hello::ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
|
||||
{
|
||||
set.add<ReshapeReshapeOptPattern, RedundantShapeOptPattern, FoldConstantReshapeOptPattern>(context);
|
||||
}
|
23
lib/HelloCombine.td
Normal file
23
lib/HelloCombine.td
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef HELLO_COMBINE
|
||||
#define HELLO_COMBINE
|
||||
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "hello/Ops.td"
|
||||
|
||||
// Reshape(Reshape(x)) = Reshape(x)
|
||||
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), (ReshapeOp $arg)>;
|
||||
|
||||
// Reshape(Consant(x)) = x'
|
||||
|
||||
def ReshapeConstant : NativeCodeCall<"$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">;
|
||||
|
||||
def FoldConstantReshapeOptPattern : Pat<(ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
||||
// Reshape(x) =x , where input and output shapes are the same.
|
||||
def TypesAreSame : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def RedundantShapeOptPattern : Pat<
|
||||
(ReshapeOp: $res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreSame $res, $arg)]>;
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user