//
// Created by ricardo on 02/06/25.
//

#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Debug.h>

#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Pass/Pass.h>

#include "Dialect.h"
#include "Passes.h"

namespace mlir::hello
{
// Pull in the auto-generated definitions of the ShapeInference interface
// (produced by TableGen from the interface declaration).
#include "hello/ShapeInferenceInterface.cpp.inc"
}

#define DEBUG_TYPE "ShapeInference"

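// For reference, an operation handled by this pass is expected to implement
// the ShapeInference interface's inferShapes() hook. A minimal sketch, assuming
// a hypothetical binary op whose result shape matches its left operand
// (accessor names are illustrative, not taken from this repository):
//
//   void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
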
namespace
{
// An intra-procedural shape inference pass: it collects every operation whose
// result shapes are still unknown and resolves them one by one, propagating
// ranked tensor types through the function until no more progress can be made.
struct ShapeInferencePass : mlir::PassWrapper<ShapeInferencePass, mlir::OperationPass<mlir::hello::FuncOp>>
{
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)

    void runOnOperation() override
    {
        mlir::hello::FuncOp operation = getOperation();
        llvm::SmallPtrSet<mlir::Operation*, 16> opWorkList;

        // Seed the worklist with every operation that still produces a
        // dynamically shaped (unranked) result.
        operation.walk([&](mlir::Operation* op)
        {
            if (isDynamicShapes(op))
            {
                opWorkList.insert(op);
            }
        });

        // Repeatedly pick an operation whose operands all have inferred
        // (ranked) types and ask it to infer its own result shapes.
        while (!opWorkList.empty())
        {
            auto nextOperation = llvm::find_if(opWorkList, isOperationInferred);
            if (nextOperation == opWorkList.end())
            {
                break;
            }

            mlir::Operation* op = *nextOperation;
            opWorkList.erase(op);
            LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

            if (auto shapeInference = mlir::dyn_cast<mlir::hello::ShapeInference>(op))
            {
                shapeInference.inferShapes();
            }
            else
            {
                op->emitError(
                    std::string("Failed to infer shape for operation '") + op->getName().getIdentifier().str() +
                    "': it does not implement the shape inference interface.");
                signalPassFailure();
                return;
            }
        }

        // Any operations left in the worklist could not be resolved.
        if (!opWorkList.empty())
        {
            operation.emitError("Failed to infer all shapes, ") << opWorkList.size() <<
                " operations could not be inferred.";
            signalPassFailure();
        }
    }

    // Returns true if every operand type of `op` is a ranked tensor, i.e. the
    // operation has all the information it needs to infer its result shapes.
    static bool isOperationInferred(mlir::Operation* op)
    {
        return llvm::all_of(op->getOperandTypes(), [](mlir::Type operandType)
        {
            return llvm::isa<mlir::RankedTensorType>(operandType);
        });
    }

    // Returns true if any result type of `op` is not a ranked tensor, i.e.
    // its shape is still dynamic and needs to be inferred.
    static bool isDynamicShapes(mlir::Operation* op)
    {
        return llvm::any_of(op->getResultTypes(), [](mlir::Type resultType)
        {
            return !llvm::isa<mlir::RankedTensorType>(resultType);
        });
    }
};
}

std::unique_ptr<mlir::Pass> mlir::hello::createShapeInferencePass()
{
    return std::make_unique<ShapeInferencePass>();
}
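
// Usage note: a minimal sketch of how this pass might be scheduled, assuming a
// standard mlir::PassManager and that Passes.h declares createShapeInferencePass();
// the surrounding pipeline setup is illustrative, not part of this file:
//
//   mlir::MLIRContext context;
//   // ... load the hello dialect and parse `module` ...
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::hello::FuncOp>(mlir::hello::createShapeInferencePass());
//   if (mlir::failed(pm.run(module)))
//       /* report the failure */;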