hello-mlir/lib/ShapeInferencePass.cpp
jackfiled 902915a57b
feat: toy tutorial chapter 4.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
2025-06-03 16:03:17 +08:00

96 lines
2.7 KiB
C++

//
// Created by ricardo on 02/06/25.
//
#include <llvm/Support/Debug.h>
#include <mlir/Pass/Pass.h>
#include "Dialect.h"
#include "Passes.h"
// Textually include hello/ShapeInferenceInterface.cpp.inc inside namespace
// mlir::hello so that everything it defines lands in that namespace.
// NOTE(review): presumably this is the generated implementation of the
// ShapeInference op interface that the pass below casts to — confirm against
// the build rules that produce the .inc file.
namespace mlir::hello
{
#include "hello/ShapeInferenceInterface.cpp.inc"
}
// Debug category name for this file; the LLVM_DEBUG(...) statement in the pass
// below is emitted under it.
// NOTE(review): presumably selectable at run time with -debug-only=ShapeInference
// in builds that enable LLVM debug output — confirm with the driver's options.
#define DEBUG_TYPE "ShapeInference"
namespace
{
/// Pass over a single mlir::hello::FuncOp that drives shape inference.
///
/// It collects every operation that has at least one result whose type is not
/// a RankedTensorType, then repeatedly picks one whose operands are all ranked
/// tensors and calls ShapeInference::inferShapes() on it. The pass fails when a
/// picked operation does not implement the ShapeInference interface, or when
/// operations remain in the work list but none of them is ready to be picked.
struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
{
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)

    /// Returns true when every operand type of `op` is a RankedTensorType
    /// (vacuously true for an operation without operands), i.e. the operation
    /// is ready to have its own shapes inferred.
    static bool isOperationInferred(mlir::Operation* op)
    {
        const bool hasUnrankedOperand = llvm::any_of(op->getOperandTypes(), [](mlir::Type type)
        {
            return !llvm::isa<mlir::RankedTensorType>(type);
        });
        return !hasUnrankedOperand;
    }

    /// Returns true when at least one result type of `op` is not a
    /// RankedTensorType (false for an operation without results).
    static bool isDynamicShapes(mlir::Operation* op)
    {
        const bool allResultsRanked = llvm::all_of(op->getResultTypes(), [](mlir::Type type)
        {
            return llvm::isa<mlir::RankedTensorType>(type);
        });
        return !allResultsRanked;
    }

    /// Runs the work-list algorithm described on the struct.
    void runOnOperation() override
    {
        mlir::hello::FuncOp function = getOperation();

        // Seed the work list with every operation that still has a result
        // which is not a ranked tensor.
        llvm::SmallPtrSet<mlir::Operation*, 16> pending;
        function.walk([&pending](mlir::Operation* candidate)
        {
            if (isDynamicShapes(candidate))
            {
                pending.insert(candidate);
            }
        });

        while (true)
        {
            // Pick any pending operation whose operands are all ranked tensors.
            // Reaching end() means the list is empty or nothing is ready.
            const auto ready = llvm::find_if(pending, isOperationInferred);
            if (ready == pending.end())
            {
                break;
            }

            mlir::Operation* current = *ready;
            pending.erase(current);
            LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *current << "\n");

            auto inferable = mlir::dyn_cast<mlir::hello::ShapeInference>(current);
            if (!inferable)
            {
                // Without the interface there is no way to ask this operation
                // for its shapes, so the whole pass is abandoned.
                std::string message("Failed to inference shape for operation '");
                message += current->getName().getIdentifier().str();
                message += "' without shape inference interface.";
                current->emitError(message);
                signalPassFailure();
                return;
            }

            inferable.inferShapes();
        }

        if (pending.empty())
        {
            return;
        }

        // Operations are left over, but none of them has all operands ranked.
        function.emitError("Failed to inference shape, ") << pending.size() <<
            " operations failed to inference.\n";
        signalPassFailure();
    }
};
}
/// Factory for the shape inference pass: returns a newly allocated
/// ShapeInferencePass owned through a std::unique_ptr<mlir::Pass>.
std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
{
    std::unique_ptr<mlir::Pass> pass = std::make_unique<ShapeInferencePass>();
    return pass;
}