24 lines
738 B
TableGen
24 lines
738 B
TableGen
#ifndef HELLO_COMBINE
|
|
#define HELLO_COMBINE
|
|
|
|
include "mlir/IR/PatternBase.td"
|
|
include "hello/Ops.td"
|
|
|
|
// Reshape(Reshape(x)) = Reshape(x)
|
|
// Collapse a chain of two reshapes: the outer ReshapeOp is rewritten into a
// single ReshapeOp applied directly to the inner reshape's operand.
def ReshapeReshapeOptPattern : Pat<
  (ReshapeOp (ReshapeOp $input)),
  (ReshapeOp $input)>;
|
|
|
|
// Reshape(Constant(x)) = Constant(x'), where x' is x reshaped to the result type.
|
|
|
|
// Native C++ helper used by FoldConstantReshapeOptPattern: calls `reshape` on
// $0 (the constant's value) with the type of $1 (the reshape's result) cast
// to ::mlir::ShapedType.
// NOTE(review): presumably $0 is a DenseElementsAttr and the target type must
// be statically shaped with the same element count -- confirm against
// hello/Ops.td.
def ReshapeConstant : NativeCodeCall<"$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">;
|
|
|
|
// Fold a reshape of a constant into a fresh ConstantOp whose value is the
// original constant value reshaped (via ReshapeConstant) to the type of the
// reshape's result.
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$reshaped (ConstantOp $value)),
  (ConstantOp (ReshapeConstant $value, $reshaped))>;
|
|
|
|
// Reshape(x) = x, where the input and output types are the same.
|
|
// Constraint satisfied when the two bound values ($0 and $1) have identical
// types; used to detect reshapes that do not change the type.
def TypesAreSame : Constraint<
  CPred<"$1.getType() == $0.getType()">>;
|
|
|
|
// Remove a no-op reshape: when the ReshapeOp's result type equals its
// operand's type, uses of the reshape are replaced by the operand itself.
def RedundantShapeOptPattern : Pat<
  (ReshapeOp:$reshaped $input),
  (replaceWithValue $input),
  [(TypesAreSame $reshaped, $input)]>;
|
|
|
|
#endif
|