// // Created by ricardo on 02/06/25. // #include #include #include "Dialect.h" #include "Passes.h" namespace mlir::hello { #include "hello/ShapeInferenceInterface.cpp.inc" } #define DEBUG_TYPE "ShapeInference" namespace { struct ShapeInferencePass : mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) void runOnOperation() override { mlir::hello::FuncOp operation = getOperation(); llvm::SmallPtrSet opWorkList; operation.walk([&](mlir::Operation* op) { if (isDynamicShapes(op)) { opWorkList.insert(op); } }); while (!opWorkList.empty()) { auto nextOperation = llvm::find_if(opWorkList, isOperationInferred); if (nextOperation == opWorkList.end()) { break; } mlir::Operation* op = *nextOperation; opWorkList.erase(op); LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); if (auto shapeInference = mlir::dyn_cast(op)) { shapeInference.inferShapes(); } else { op->emitError( std::string("Failed to inference shape for operation '") + op->getName().getIdentifier().str() + "' without shape inference interface."); signalPassFailure(); return; } } if (!opWorkList.empty()) { operation.emitError("Failed to inference shape, ") << opWorkList.size() << " operations failed to inference.\n"; signalPassFailure(); } } static bool isOperationInferred(mlir::Operation* op) { return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType) { return llvm::isa(operandType); }); } static bool isDynamicShapes(mlir::Operation* op) { return llvm::any_of(op->getResultTypes(), [](mlir::Type operandType) { return !llvm::isa(operandType); }); } }; } std::unique_ptr mlir::hello::createShapeInferencePass() { return std::make_unique(); }