// Combine / canonicalization patterns for the `hello` dialect, written in
// MLIR's Declarative Rewrite Rules (DRR) TableGen syntax.
#ifndef HELLO_COMBINE
#define HELLO_COMBINE

include "mlir/IR/PatternBase.td"
include "hello/Ops.td"

//===----------------------------------------------------------------------===//
// Reshape(Reshape(x)) = Reshape(x)
//===----------------------------------------------------------------------===//

// A ReshapeOp whose operand is produced by another ReshapeOp is rewritten to a
// single ReshapeOp applied directly to the innermost operand `$arg`, skipping
// the intermediate reshape.
def ReshapeReshapeOptPattern : Pat<(ReshapeOp (ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;

//===----------------------------------------------------------------------===//
// Reshape(Constant(x)) = x'
//===----------------------------------------------------------------------===//

// Native C++ helper used by the rewrite below.
//   $0: the attribute captured from the matched ConstantOp.
//   $1: the result value of the matched ReshapeOp.
// Calls `reshape` on the attribute with the reshape result's type (cast to
// ShapedType), yielding an attribute laid out in the target shape.
def ReshapeConstant : NativeCodeCall<
  "$0.reshape(::llvm::cast<::mlir::ShapedType>($1.getType()))">;

// Folds a reshape of a constant into a new ConstantOp holding the reshaped
// attribute, so no ReshapeOp remains at runtime.
// NOTE(review): assumes ConstantOp's attribute is an elements attribute that
// provides `reshape` (e.g. DenseElementsAttr) — confirm in hello/Ops.td.
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$res (ConstantOp $arg)),
  (ConstantOp (ReshapeConstant $arg, $res))>;

//===----------------------------------------------------------------------===//
// Reshape(x) = x, where the input and output types are identical
//===----------------------------------------------------------------------===//

// Holds when both values have exactly the same type. This compares the whole
// type (shape and element type), not just the shape.
def TypesAreSame : Constraint<CPred<"$0.getType() == $1.getType()">>;

// Removes a ReshapeOp that does not change its operand's type by replacing
// all uses of its result with the operand itself.
def RedundantShapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreSame $res, $arg)]>;

#endif // HELLO_COMBINE