hello-mlir/lib/LowerToAffineLoopsPass.cpp
jackfiled c5ab1a6bc0
feat: toy tutorial chapter 5.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
2025-06-05 23:46:52 +08:00

347 lines
14 KiB
C++

//
// Created by ricardo on 05/06/25.
//
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Transforms/DialectConversion.h>
#include "Passes.h"
#include "Dialect.h"
/// Builds the memref type that mirrors `tensorType`: the same shape and the
/// same element type, with every other memref parameter left at its default.
static mlir::MemRefType convertTensorTypeTOMemRefType(const mlir::RankedTensorType& tensorType)
{
    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
    mlir::Type elementType = tensorType.getElementType();
    return mlir::MemRefType::get(shape, elementType);
}
/// Creates a buffer of the given memref `type` together with its matching release.
///
/// Both operations are created at the rewriter's current insertion point and then
/// relocated inside the enclosing block: the allocation becomes the first operation
/// of the block and the deallocation is placed right before the block's last operation.
///
/// Returns the allocated buffer.
static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
                                         mlir::PatternRewriter& rewriter)
{
    auto buffer = rewriter.create<mlir::memref::AllocOp>(location, type);
    mlir::Block* enclosingBlock = buffer->getBlock();

    // Hoist the allocation to the very beginning of the block.
    buffer->moveBefore(&enclosingBlock->front());

    auto release = rewriter.create<mlir::memref::DeallocOp>(location, buffer);

    // Sink the release in front of the last operation of the block, which is
    // usually the return operation.
    release->moveBefore(&enclosingBlock->back());

    return buffer;
}
/// The callback type used to process one iteration of a lowered loop nest.
/// `memRefOperands` are the operands of the operation being lowered.
/// `loopIterators` are the induction variables of the current iteration.
/// The returned value is the value to store at the current index of the iteration.
using LoopIterationFunction = mlir::function_ref<mlir::Value(mlir::OpBuilder& rewriter, mlir::ValueRange memRefOperands,
mlir::ValueRange loopIterators)>;
/// Lowers `op` into an affine loop nest that fills a newly allocated buffer.
///
/// The first result type of `op` must be a ranked tensor; it determines the shape
/// of the buffer and of the loop nest (one loop per dimension, starting at zero
/// with a step of one). For every iteration `function` is invoked with `operands`
/// and the induction variables, and the value it returns is stored into the buffer
/// at those induction variables. Finally `op` is replaced by the buffer.
static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands, mlir::PatternRewriter& rewriter,
                                  LoopIterationFunction function)
{
    auto resultType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());
    auto location = op->getLoc();

    // The buffer that receives the computed elements.
    auto buffer = insertAllocAndDealloc(convertTensorTypeTOMemRefType(resultType), location, rewriter);

    // Every loop of the nest starts at zero ...
    llvm::SmallVector<int64_t, 4> zeroLowerBounds(resultType.getRank(), 0);
    // ... and advances by one.
    llvm::SmallVector<int64_t, 4> unitSteps(resultType.getRank(), 1);

    // Body of the innermost loop: compute one element and store it at the current index.
    auto emitElementStore = [&](mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation,
                                mlir::ValueRange iterators)
    {
        mlir::Value element = function(nestedBuilder, operands, iterators);
        nestedBuilder.create<mlir::affine::AffineStoreOp>(nestedLocation, element, buffer, iterators);
    };

    mlir::affine::buildAffineLoopNest(rewriter, location, zeroLowerBounds, resultType.getShape(), unitSteps,
                                      emitElementStore);

    // The filled buffer takes the place of the original operation.
    rewriter.replaceOp(op, buffer);
}
namespace
{
    /// Generic pattern lowering an element-wise hello binary operation
    /// (`BinaryOperation`) to a loop nest whose body applies `LoweredBinaryOperation`.
    template <typename BinaryOperation, typename LoweredBinaryOperation>
    struct BinaryOperationLoweringPattern : mlir::ConversionPattern
    {
        /// Registers the pattern for `BinaryOperation` with a benefit of 1.
        explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context)
            : ConversionPattern(BinaryOperation::getOperationName(), 1, context)
        {
        }

        /// Replaces `op` by a buffer filled element by element; always succeeds.
        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            auto location = op->getLoc();

            // Produces a single element of the result: load both inputs at the
            // current index and combine them with the lowered operation.
            auto emitElement = [location](mlir::OpBuilder& builder,
                                          mlir::ValueRange memRefOperands,
                                          mlir::ValueRange loopIterators)
            {
                // The adaptor is generated by the framework; it gives named access
                // (lhs/rhs) to the operands handed to this callback.
                typename BinaryOperation::Adaptor namedOperands(memRefOperands);

                auto lhsElement = builder.create<mlir::affine::AffineLoadOp>(
                    location, namedOperands.getLhs(), loopIterators);
                auto rhsElement = builder.create<mlir::affine::AffineLoadOp>(
                    location, namedOperands.getRhs(), loopIterators);

                return builder.create<LoweredBinaryOperation>(location, lhsElement, rhsElement);
            };

            lowerOperationToLoops(op, operands, rewriter, emitElement);
            return mlir::success();
        }
    };
}
/// Lowers hello::AddOp to loops whose body is an arith::AddFOp.
using AddOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;
/// Lowers hello::MulOp to loops whose body is an arith::MulFOp.
using MulOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;
namespace
{
/// Lowers a hello constant by allocating a buffer and storing every element of the
/// constant into it with one affine store per element.
struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
{
    /// Reuse the constructors of the base pattern.
    using OpRewritePattern::OpRewritePattern;

    /// Replaces `op` by the filled buffer; always succeeds.
    mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
    {
        mlir::DenseElementsAttr elements = op.getValue();
        mlir::Location location = op->getLoc();

        // The buffer that will hold the constant.
        auto tensorType = llvm::cast<mlir::RankedTensorType>(op.getType());
        mlir::MemRefType bufferType = convertTensorTypeTOMemRefType(tensorType);
        mlir::Value buffer = insertAllocAndDealloc(bufferType, location, rewriter);

        // Create each index constant only once, up to the largest dimension, and
        // share it between all stores. A rank-0 tensor is a single number, for
        // which one zero index is created.
        auto shape = bufferType.getShape();
        const int64_t indexCount = shape.empty() ? 1 : *llvm::max_element(shape);
        llvm::SmallVector<mlir::Value, 8> indexConstants;
        for (const auto position : llvm::seq<int64_t>(0, indexCount))
        {
            indexConstants.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, position));
        }

        // Index currently being assembled, one entry per dimension visited so far.
        llvm::SmallVector<mlir::Value, 2> currentIndex;
        // Walks the constant's elements in the order they are stored.
        auto nextElement = elements.value_begin<mlir::FloatAttr>();

        // Recursively enumerates all indices, outermost dimension first, and emits
        // one store whenever an index is complete.
        std::function<void(uint64_t)> emitStores = [&](const uint64_t depth)
        {
            if (depth != shape.size())
            {
                // Still assembling: try every position of this dimension.
                for (const auto position : llvm::seq<int64_t>(0, shape[depth]))
                {
                    currentIndex.push_back(indexConstants[position]);
                    emitStores(depth + 1);
                    currentIndex.pop_back();
                }
                return;
            }

            // The index is complete: store the next element and advance.
            auto element = rewriter.create<mlir::arith::ConstantOp>(location, *nextElement);
            rewriter.create<mlir::affine::AffineStoreOp>(location, element, buffer, llvm::ArrayRef(currentIndex));
            ++nextElement;
        };

        // Kick off the enumeration at the outermost dimension.
        emitStores(0);

        rewriter.replaceOp(op, buffer);
        return mlir::success();
    }
};
/// Lowers the hello `main` function to a func::FuncOp that takes over its body.
struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
{
    using OpConversionPattern::OpConversionPattern;

    /// Fails for any function other than `main` and for a `main` that has
    /// arguments or results; otherwise replaces `op` by the new function.
    mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        // Only `main` is lowered, on the assumption that every other function
        // has already been inlined.
        if (op.getName() != "main")
        {
            return mlir::failure();
        }

        // `main` must neither take arguments nor produce results.
        const bool hasEmptySignature = op.getNumArguments() == 0 && op.getNumResults() == 0;
        if (!hasEmptySignature)
        {
            return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
            {
                diagnostic << "Expect the 'main' function to have no arguments and no result.";
            });
        }

        // Create the replacement function and move the original region into it.
        auto loweredFunction = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
        rewriter.inlineRegionBefore(op.getRegion(), loweredFunction.getBody(), loweredFunction.end());

        rewriter.eraseOp(op);
        return mlir::success();
    }
};
/// Lowers hello::print by keeping the operation and only swapping its operands
/// for the ones provided by the conversion adaptor.
struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
{
    using OpConversionPattern::OpConversionPattern;

    /// Updates the operands of `op` in place; always succeeds.
    mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        // The change goes through the rewriter so that it is notified of the
        // in-place modification.
        auto adoptAdaptorOperands = [&]
        {
            op->setOperands(adaptor.getOperands());
        };
        rewriter.modifyOpInPlace(op, adoptAdaptorOperands);
        return mlir::success();
    }
};
/// Lowers an operand-less hello::return to func::return.
struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
{
    using OpRewritePattern::OpRewritePattern;

    /// Succeeds only for a return without operand.
    mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
    {
        // All calls are expected to be inlined by now, so the only return left
        // is the one of `main`, which returns nothing.
        if (!op.hasOperand())
        {
            rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
            return mlir::success();
        }

        // A return that still carries a value is not handled by this pattern.
        return mlir::failure();
    }
};
/// Lowers hello::transpose to an affine loop nest that reads the input with
/// reversed indices.
struct TransposeOpLoweringPattern : mlir::ConversionPattern
{
    /// Registers the pattern for hello::TransposeOp with a benefit of 1.
    explicit TransposeOpLoweringPattern(mlir::MLIRContext* context)
        : ConversionPattern(mlir::hello::TransposeOp::getOperationName(), 1, context)
    {
    }

    /// Replaces `op` by a buffer filled element by element; always succeeds.
    mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                        mlir::ConversionPatternRewriter& rewriter) const final
    {
        auto location = op->getLoc();

        // Produces a single element of the result by loading the input at the
        // mirrored index, i.e. (x, y) reads from (y, x).
        auto emitElement = [location](mlir::OpBuilder& builder,
                                      mlir::ValueRange memRefOperands,
                                      mlir::ValueRange loopIterators)
        {
            mlir::hello::TransposeOpAdaptor namedOperands(memRefOperands);
            mlir::Value source = namedOperands.getInput();

            llvm::SmallVector<mlir::Value, 2> mirroredIndex(llvm::reverse(loopIterators));
            return builder.create<mlir::affine::AffineLoadOp>(location, source, mirroredIndex);
        };

        lowerOperationToLoops(op, operands, rewriter, emitElement);
        return mlir::success();
    }
};
}
namespace
{
    /// Module-level pass that lowers the hello dialect to the affine, func, arith
    /// and memref dialects. The work itself is done in `runOnOperation`.
    struct HelloToAffineLoopsLoweringPass
        : mlir::PassWrapper<HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
    {
        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);

        /// The argument naming this pass in command line options.
        llvm::StringRef getArgument() const override
        {
            return "hello-to-affine-loops";
        }

        /// Declares the dialects whose operations this pass creates.
        void getDependentDialects(mlir::DialectRegistry& registry) const override
        {
            registry.insert<mlir::affine::AffineDialect>();
            registry.insert<mlir::func::FuncDialect>();
            registry.insert<mlir::arith::ArithDialect>();
            registry.insert<mlir::memref::MemRefDialect>();
        }

        /// Performs the lowering on the module; defined out of line.
        void runOnOperation() final;
    };
}
/// Runs a partial conversion of the module: hello operations are illegal and are
/// rewritten by the lowering patterns, while affine, func, arith and memref
/// operations are legal. The pass is marked as failed if the conversion fails.
void HelloToAffineLoopsLoweringPass::runOnOperation()
{
    mlir::ConversionTarget target(getContext());

    // The dialects this lowering is allowed to produce.
    target.addLegalDialect<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
                           mlir::memref::MemRefDialect>();

    // Everything from the hello dialect has to be converted ...
    target.addIllegalDialect<mlir::hello::HelloDialect>();

    // ... except the print operation, which is kept for a later lowering. It is
    // legal as soon as none of its operands is a tensor any more, i.e. once its
    // operands have been replaced by the lowered values.
    target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
    {
        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
    });

    // Collect the lowering patterns; every pattern is registered exactly once.
    mlir::RewritePatternSet set(&getContext());
    set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
            PrintOpLoweringPattern, FuncOpLoweringPattern, ReturnOpLoweringPattern,
            TransposeOpLoweringPattern>(&getContext());

    // Partial conversion: operations that are not explicitly illegal may remain.
    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
    {
        signalPassFailure();
    }
}
namespace mlir::hello
{
    /// Factory returning a new instance of the hello-to-affine-loops lowering pass.
    std::unique_ptr<Pass> createLowerToAffineLoopsPass()
    {
        std::unique_ptr<Pass> loweringPass = std::make_unique<HelloToAffineLoopsLoweringPass>();
        return loweringPass;
    }
}