347 lines
14 KiB
C++
347 lines
14 KiB
C++
//
|
|
// Created by ricardo on 05/06/25.
|
|
//
|
|
|
|
#include <mlir/IR/BuiltinOps.h>
|
|
#include <mlir/Pass/Pass.h>
|
|
#include <mlir/Dialect/Affine/IR/AffineOps.h>
|
|
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
|
#include <mlir/Dialect/Arith/IR/Arith.h>
|
|
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
|
#include <mlir/Transforms/DialectConversion.h>
|
|
|
|
#include "Passes.h"
|
|
#include "Dialect.h"
|
|
|
|
/// Build the memref type that will back a ranked tensor value.
///
/// The result keeps the tensor's shape and element type; no explicit layout
/// or memory space is supplied to `MemRefType::get`.
static mlir::MemRefType convertTensorTypeTOMemRefType(const mlir::RankedTensorType& tensorType)
{
    const auto shape = tensorType.getShape();
    const auto elementType = tensorType.getElementType();
    return mlir::MemRefType::get(shape, elementType);
}
|
|
|
|
/// Allocate a buffer of `type` and schedule its release in the same block.
///
/// A `memref.alloc` is created and hoisted to the very start of the block that
/// holds the rewriter's insertion point. A matching `memref.dealloc` is placed
/// immediately before that block's last operation.
///
/// @param type     Memref type of the buffer to allocate.
/// @param location Location attached to both the alloc and the dealloc.
/// @param rewriter Rewriter whose current insertion point selects the block.
/// @return The value produced by the `memref.alloc`.
static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
                                         mlir::PatternRewriter& rewriter)
{
    auto allocOp = rewriter.create<mlir::memref::AllocOp>(location, type);
    auto* block = allocOp->getBlock();
    allocOp->moveBefore(&block->front());

    // The last operation of the block is presumably its terminator (the
    // function return), so the release goes right before it.
    auto deallocOp = rewriter.create<mlir::memref::DeallocOp>(location, allocOp);
    deallocOp->moveBefore(&block->back());

    return allocOp;
}
|
|
|
|
/// The function type used for process an iteration of a lowered loop.
|
|
/// The `memRefOperands` is the operands of the input operations.
|
|
/// The `loopIterators` is the induction variables for the iteration.
|
|
/// The return value is the value to store and the current index of the iteration.
|
|
using LoopIterationFunction = mlir::function_ref<mlir::Value(mlir::OpBuilder& rewriter, mlir::ValueRange memRefOperands,
|
|
mlir::ValueRange loopIterators)>;
|
|
|
|
/// Helper function to lower operation to loops.
|
|
///
|
|
static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands, mlir::PatternRewriter& rewriter,
|
|
LoopIterationFunction function)
|
|
{
|
|
auto tensorType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());
|
|
auto location = op->getLoc();
|
|
|
|
auto memRefType = convertTensorTypeTOMemRefType(tensorType);
|
|
auto allocation = insertAllocAndDealloc(memRefType, location, rewriter);
|
|
|
|
// Create a loop for every dimension of the shape.
|
|
// This vector stores the beginning value of each loop, as there all zeros.
|
|
llvm::SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), 0);
|
|
// This vector stores the step of each loop, as there all ones.
|
|
llvm::SmallVector<int64_t, 4> steps(tensorType.getRank(), 1);
|
|
|
|
|
|
mlir::affine::buildAffineLoopNest(rewriter, location, lowerBounds, tensorType.getShape(), steps, [&](
|
|
mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation,
|
|
mlir::ValueRange iterators)
|
|
{
|
|
mlir::Value storedValue = function(nestedBuilder, operands, iterators);
|
|
nestedBuilder.create<mlir::affine::AffineStoreOp>(
|
|
nestedLocation, storedValue, allocation, iterators);
|
|
});
|
|
|
|
// Use the created buffer to replace the operation.
|
|
rewriter.replaceOp(op, allocation);
|
|
}
|
|
|
|
namespace
|
|
{
|
|
/// A generic pattern to convert a hello binary operation into an arith operation.
|
|
template <typename BinaryOperation, typename LoweredBinaryOperation>
|
|
struct BinaryOperationLoweringPattern : mlir::ConversionPattern
|
|
{
|
|
explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
|
|
BinaryOperation::getOperationName(), 1, context)
|
|
{
|
|
}
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
|
|
mlir::ConversionPatternRewriter& rewriter) const final
|
|
{
|
|
auto location = op->getLoc();
|
|
lowerOperationToLoops(op, operands, rewriter, [location](mlir::OpBuilder& builder,
|
|
mlir::ValueRange memRefOperands,
|
|
mlir::ValueRange loopIterators)
|
|
{
|
|
// The adaptor is generated automatically by the framework.
|
|
// This adaptor will convert
|
|
typename BinaryOperation::Adaptor binaryOperationAdaptor(memRefOperands);
|
|
|
|
auto loadLeftOperand = builder.create<mlir::affine::AffineLoadOp>(
|
|
location, binaryOperationAdaptor.getLhs(), loopIterators);
|
|
auto loadRightOperand = builder.create<mlir::affine::AffineLoadOp>(
|
|
location, binaryOperationAdaptor.getRhs(), loopIterators);
|
|
|
|
return builder.create<LoweredBinaryOperation>(location, loadLeftOperand, loadRightOperand);
|
|
});
|
|
|
|
return mlir::success();
|
|
}
|
|
};
|
|
}
|
|
|
|
/// Lowers `hello.add` to an affine loop nest performing `arith.addf` per element.
using AddOperationLoweringPattern =
    BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;

/// Lowers `hello.mul` to an affine loop nest performing `arith.mulf` per element.
using MulOperationLoweringPattern =
    BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;
|
|
|
|
namespace
|
|
{
|
|
/// Lower the constant values into the buffer.
|
|
struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
|
|
{
|
|
/// Constructor.
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
|
|
{
|
|
mlir::DenseElementsAttr constantValues = op.getValue();
|
|
mlir::Location location = op->getLoc();
|
|
|
|
// Allocate buffer for these constant values.
|
|
mlir::RankedTensorType tensorType = llvm::cast<mlir::RankedTensorType>(op.getType());
|
|
mlir::MemRefType memRefTypes = convertTensorTypeTOMemRefType(tensorType);
|
|
mlir::Value allocationBuffer = insertAllocAndDealloc(memRefTypes, location, rewriter);
|
|
|
|
// Generate the constant indices up to the largest dimension,
|
|
// so that we can avoid a amount of redundant operations.
|
|
auto valueShape = memRefTypes.getShape();
|
|
llvm::SmallVector<mlir::Value, 8> constantIndices;
|
|
|
|
if (!valueShape.empty())
|
|
{
|
|
for (const auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape)))
|
|
{
|
|
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, i));
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// If the value shape is empty, as the tensor of rank 0, so this is just a number.
|
|
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, 0));
|
|
}
|
|
|
|
// The store the constant values into the buffer.
|
|
// Define the recursive function to store.
|
|
|
|
// Vector to store the indices.
|
|
llvm::SmallVector<mlir::Value, 2> indices;
|
|
// Iterator to iterate the constant values.
|
|
auto valueIterator = constantValues.value_begin<mlir::FloatAttr>();
|
|
|
|
std::function<void(uint64_t)> storeElementFunc = [&](const uint64_t dimension)
|
|
{
|
|
// When the dimension reach the size of shape, we reach the end of recursion,
|
|
// where we store the value.
|
|
if (dimension == valueShape.size())
|
|
{
|
|
rewriter.create<mlir::affine::AffineStoreOp>(
|
|
location, rewriter.create<mlir::arith::ConstantOp>(location, *valueIterator), allocationBuffer,
|
|
llvm::ArrayRef(indices));
|
|
++valueIterator;
|
|
return;
|
|
}
|
|
|
|
// There build the indices.
|
|
for (const auto i : llvm::seq<int64_t>(0, valueShape[dimension]))
|
|
{
|
|
indices.push_back(constantIndices[i]);
|
|
storeElementFunc(dimension + 1);
|
|
indices.pop_back();
|
|
}
|
|
};
|
|
|
|
// Start the recursion from dimension 0;
|
|
storeElementFunc(0);
|
|
|
|
rewriter.replaceOp(op, allocationBuffer);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Lower the hello::func to func:func.
|
|
struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
|
|
{
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter& rewriter) const final
|
|
{
|
|
// Only lowering the main function as assuming that other functions have been inlined.
|
|
if (op.getName() != "main")
|
|
{
|
|
return mlir::failure();
|
|
}
|
|
|
|
// Validate the main function no arguments and no return values.
|
|
if (op.getNumArguments() != 0 || op.getNumResults() != 0)
|
|
{
|
|
return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
|
|
{
|
|
diagnostic << "Expect the 'main' function to have no arguments and no result.";
|
|
});
|
|
}
|
|
|
|
// Create a new function with the same region.
|
|
auto function = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
|
|
rewriter.inlineRegionBefore(op.getRegion(), function.getBody(), function.end());
|
|
rewriter.eraseOp(op);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Lower the hello::print, just replace the operand.
|
|
struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
|
|
{
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter& rewriter) const final
|
|
{
|
|
rewriter.modifyOpInPlace(op, [&]
|
|
{
|
|
op->setOperands(adaptor.getOperands());
|
|
});
|
|
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Lower hello::return to func::return.
|
|
struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
|
|
{
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
|
|
{
|
|
// As all function calls have been inlined,
|
|
// So there is only one return in main function and return no value.
|
|
if (op.hasOperand())
|
|
{
|
|
return mlir::failure();
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Lower the transpose operation to affine loops.
|
|
struct TransposeOpLoweringPattern : mlir::ConversionPattern
|
|
{
|
|
explicit TransposeOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
|
|
mlir::hello::TransposeOp::getOperationName(),
|
|
1, context)
|
|
{
|
|
}
|
|
|
|
mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
|
|
mlir::ConversionPatternRewriter& rewriter) const final
|
|
{
|
|
auto location = op->getLoc();
|
|
|
|
lowerOperationToLoops(op, operands, rewriter, [location](mlir::OpBuilder& builder,
|
|
mlir::ValueRange memRefOperands,
|
|
mlir::ValueRange loopIterators)
|
|
{
|
|
mlir::hello::TransposeOpAdaptor adaptor(memRefOperands);
|
|
mlir::Value input = adaptor.getInput();
|
|
|
|
// Transpose operation just reverse the indices array as (x, y) -> (y, x).
|
|
llvm::SmallVector<mlir::Value, 2> reversedIterators(llvm::reverse(loopIterators));
|
|
return builder.create<mlir::affine::AffineLoadOp>(location, input, reversedIterators);
|
|
});
|
|
|
|
return mlir::success();
|
|
}
|
|
};
|
|
}
|
|
|
|
namespace
|
|
{
|
|
struct HelloToAffineLoopsLoweringPass : mlir::PassWrapper<
|
|
HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
|
|
{
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);
|
|
|
|
/// Argument used by command line options.
|
|
llvm::StringRef getArgument() const override
|
|
{
|
|
return "hello-to-affine-loops";
|
|
}
|
|
|
|
void getDependentDialects(mlir::DialectRegistry& registry) const override
|
|
{
|
|
registry.insert<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
|
|
mlir::memref::MemRefDialect>();
|
|
}
|
|
|
|
void runOnOperation() final;
|
|
};
|
|
}
|
|
|
|
void HelloToAffineLoopsLoweringPass::runOnOperation()
|
|
{
|
|
mlir::ConversionTarget target(getContext());
|
|
|
|
// Only allow the affine, func, arith and memRef dialect.
|
|
target.addLegalDialect<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
|
|
mlir::memref::MemRefDialect>();
|
|
// Disable hello dialect.
|
|
target.addIllegalDialect<mlir::hello::HelloDialect>();
|
|
// Manual allow the print op which will be allowed to llvm dialect.
|
|
target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
|
|
{
|
|
// Make sure the operand is memRef op, as non of the operand types is tensor.
|
|
return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
|
|
});
|
|
|
|
// Construct the patterns.
|
|
mlir::RewritePatternSet set(&getContext());
|
|
set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
|
|
ConstantOperationLoweringPattern, PrintOpLoweringPattern, FuncOpLoweringPattern, ReturnOpLoweringPattern,
|
|
TransposeOpLoweringPattern>(&getContext());
|
|
|
|
|
|
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
|
|
{
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
|
|
|
|
namespace mlir::hello
|
|
{
|
|
std::unique_ptr<Pass> createLowerToAffineLoopsPass()
|
|
{
|
|
return std::make_unique<HelloToAffineLoopsLoweringPass>();
|
|
}
|
|
}
|