42 lines
1.2 KiB
C++
42 lines
1.2 KiB
C++
//
|
|
// Created by ricardo on 02/06/25.
|
|
//
|
|
|
|
#include <mlir/IR/PatternMatch.h>
|
|
#include "Dialect.h"
|
|
#include "HelloCombine.inc"
|
|
|
|
|
|
namespace
{
/// Canonicalization pattern that folds two back-to-back transposes:
///
///     transpose(transpose(x)) -> x
///
/// The outer TransposeOp is replaced by the operand of the inner TransposeOp.
/// The inner TransposeOp itself is not touched here; once it has no remaining
/// uses it becomes dead. NOTE(review): erasing it relies on TransposeOp being
/// declared side-effect free in the ODS definition -- confirm there.
///
/// Placed in an anonymous namespace because it is only referenced from this
/// translation unit; internal linkage avoids ODR clashes with a same-named
/// pattern in another file.
struct SimplifyRedundantTranspose final : mlir::OpRewritePattern<mlir::hello::TransposeOp>
{
    /// @param context  MLIR context the pattern is created in.
    /// @param benefit  Relative priority of this pattern; defaults to 1, which
    ///                 matches the previous behaviour of this constructor.
    explicit SimplifyRedundantTranspose(mlir::MLIRContext* context, mlir::PatternBenefit benefit = 1)
        : OpRewritePattern(context, benefit)
    {
    }

    /// Transpose(Transpose(x)) = x
    ///
    /// @param op        The (outer) TransposeOp being matched.
    /// @param rewriter  Rewriter used to perform the replacement.
    /// @return success() if `op`'s input is produced by another TransposeOp and
    ///         `op` was replaced; failure() otherwise (IR left unchanged).
    mlir::LogicalResult matchAndRewrite(mlir::hello::TransposeOp op, mlir::PatternRewriter& rewriter) const override
    {
        // Only fire when our input is itself the result of a TransposeOp.
        mlir::Value transposeInput = op.getOperand();
        auto transposeInputOp = transposeInput.getDefiningOp<mlir::hello::TransposeOp>();
        if (!transposeInputOp)
        {
            return mlir::failure();
        }

        // Forward the inner transpose's input directly to all users of `op`.
        rewriter.replaceOp(op, {transposeInputOp.getOperand()});
        return mlir::success();
    }
};
} // namespace
|
|
|
|
/// Canonicalization hook for TransposeOp.
///
/// Registers the hand-written SimplifyRedundantTranspose pattern, which folds
/// transpose(transpose(x)) into x.
///
/// @param set      Pattern set to append to.
/// @param context  MLIR context used to construct the pattern.
void mlir::hello::TransposeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
{
    using DoubleTransposeFold = SimplifyRedundantTranspose;
    set.add<DoubleTransposeFold>(context);
}
|
|
|
|
/// Canonicalization hook for ReshapeOp.
///
/// Registers three reshape-simplification patterns, in this order:
/// ReshapeReshapeOptPattern, RedundantShapeOptPattern and
/// FoldConstantReshapeOptPattern. They are not defined in this file;
/// presumably they are generated from the declarative rewrite rules pulled in
/// via "HelloCombine.inc" -- confirm against the .td source.
///
/// @param set      Pattern set to append to.
/// @param context  MLIR context used to construct each pattern.
void mlir::hello::ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& set, MLIRContext* context)
{
    set.add<ReshapeReshapeOptPattern>(context)
       .add<RedundantShapeOptPattern>(context)
       .add<FoldConstantReshapeOptPattern>(context);
}
|