feat: toy tutorial chapter 5.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
lib/LowerToAffineLoopsPass.cpp (new file, 346 lines)
@@ -0,0 +1,346 @@
//
// Created by ricardo on 05/06/25.
//

#include <functional>

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Transforms/DialectConversion.h>

#include "Passes.h"
#include "Dialect.h"

static mlir::MemRefType convertTensorTypeToMemRefType(const mlir::RankedTensorType& tensorType)
{
    return mlir::MemRefType::get(tensorType.getShape(), tensorType.getElementType());
}

static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
                                         mlir::PatternRewriter& rewriter)
{
    auto allocateOperation = rewriter.create<mlir::memref::AllocOp>(location, type);
    auto* parentBlock = allocateOperation->getBlock();

    // Move the allocate operation before the first operation in this block.
    allocateOperation->moveBefore(&parentBlock->front());

    auto deallocateOperation = rewriter.create<mlir::memref::DeallocOp>(location, allocateOperation);
    // Move the deallocate operation before the last operation in this block, as the last
    // operation is usually the return operation.
    deallocateOperation->moveBefore(&parentBlock->back());

    return allocateOperation;
}

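// A sketch of the resulting placement (assuming the enclosing block ends with a
// return-like terminator, which the comments above rely on):
//
//   ^block:
//     %buffer = memref.alloc() : memref<...>   // hoisted to the block front
//     ...                                      // uses of %buffer
//     memref.dealloc %buffer : memref<...>     // placed right before the terminator
//     func.return
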
/// The function type used to process one iteration of a lowered loop.
/// `memRefOperands` holds the operands of the input operation.
/// `loopIterators` holds the induction variables for the iteration.
/// The return value is the value to store at the current index of the iteration.
using LoopIterationFunction = mlir::function_ref<mlir::Value(mlir::OpBuilder& rewriter, mlir::ValueRange memRefOperands,
                                                             mlir::ValueRange loopIterators)>;

/// Helper function to lower an operation to a nest of affine loops: allocate the result
/// buffer, build one loop per dimension of the result shape, and on every iteration store
/// the value produced by `function`.
static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands, mlir::PatternRewriter& rewriter,
                                  LoopIterationFunction function)
{
    auto tensorType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());
    auto location = op->getLoc();

    auto memRefType = convertTensorTypeToMemRefType(tensorType);
    auto allocation = insertAllocAndDealloc(memRefType, location, rewriter);

    // Create a loop for every dimension of the shape.
    // This vector stores the lower bound of each loop; they are all zeros.
    llvm::SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), 0);
    // This vector stores the step of each loop; they are all ones.
    llvm::SmallVector<int64_t, 4> steps(tensorType.getRank(), 1);

    mlir::affine::buildAffineLoopNest(rewriter, location, lowerBounds, tensorType.getShape(), steps,
                                      [&](mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation,
                                          mlir::ValueRange iterators)
                                      {
                                          mlir::Value storedValue = function(nestedBuilder, operands, iterators);
                                          nestedBuilder.create<mlir::affine::AffineStoreOp>(
                                              nestedLocation, storedValue, allocation, iterators);
                                      });

    // Use the created buffer to replace the operation.
    rewriter.replaceOp(op, allocation);
}

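// For illustration, a sketch (not verbatim compiler output) of what
// lowerOperationToLoops builds for an elementwise operation producing a
// tensor<2x3xf64> result:
//
//   %alloc = memref.alloc() : memref<2x3xf64>
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       // ... value computed by the LoopIterationFunction callback ...
//       affine.store %value, %alloc[%i, %j] : memref<2x3xf64>
//     }
//   }
//   memref.dealloc %alloc : memref<2x3xf64>
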
namespace
{
    /// A generic pattern to convert a hello binary operation into an arith operation.
    template <typename BinaryOperation, typename LoweredBinaryOperation>
    struct BinaryOperationLoweringPattern : mlir::ConversionPattern
    {
        explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
            BinaryOperation::getOperationName(), 1, context)
        {
        }

        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            auto location = op->getLoc();
            lowerOperationToLoops(op, operands, rewriter,
                                  [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                                             mlir::ValueRange loopIterators)
                                  {
                                      // The adaptor is generated automatically by the framework.
                                      // It gives named access to the already-remapped memref operands.
                                      typename BinaryOperation::Adaptor binaryOperationAdaptor(memRefOperands);

                                      auto loadLeftOperand = builder.create<mlir::affine::AffineLoadOp>(
                                          location, binaryOperationAdaptor.getLhs(), loopIterators);
                                      auto loadRightOperand = builder.create<mlir::affine::AffineLoadOp>(
                                          location, binaryOperationAdaptor.getRhs(), loopIterators);

                                      return builder.create<LoweredBinaryOperation>(
                                          location, loadLeftOperand, loadRightOperand);
                                  });

            return mlir::success();
        }
    };
}

using AddOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;
using MulOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;

namespace
{
    /// Lower the constant values into the buffer.
    struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
    {
        /// Constructor.
        using OpRewritePattern::OpRewritePattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
        {
            mlir::DenseElementsAttr constantValues = op.getValue();
            mlir::Location location = op->getLoc();

            // Allocate a buffer for these constant values.
            mlir::RankedTensorType tensorType = llvm::cast<mlir::RankedTensorType>(op.getType());
            mlir::MemRefType memRefType = convertTensorTypeToMemRefType(tensorType);
            mlir::Value allocationBuffer = insertAllocAndDealloc(memRefType, location, rewriter);

            // Generate the constant indices up to the largest dimension,
            // so that redundant index constants are avoided.
            auto valueShape = memRefType.getShape();
            llvm::SmallVector<mlir::Value, 8> constantIndices;

            if (!valueShape.empty())
            {
                for (const auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape)))
                {
                    constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, i));
                }
            }
            else
            {
                // If the value shape is empty, the tensor has rank 0, so this is just a single number.
                constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, 0));
            }

            // Then store the constant values into the buffer,
            // using a recursive function that walks the dimensions.

            // Vector holding the current indices.
            llvm::SmallVector<mlir::Value, 2> indices;
            // Iterator over the constant values.
            auto valueIterator = constantValues.value_begin<mlir::FloatAttr>();

            std::function<void(uint64_t)> storeElementFunc = [&](const uint64_t dimension)
            {
                // When the dimension reaches the size of the shape, we reach the end of the recursion,
                // where we store the value.
                if (dimension == valueShape.size())
                {
                    rewriter.create<mlir::affine::AffineStoreOp>(
                        location, rewriter.create<mlir::arith::ConstantOp>(location, *valueIterator),
                        allocationBuffer, llvm::ArrayRef(indices));
                    ++valueIterator;
                    return;
                }

                // Otherwise build the indices for the current dimension.
                for (const auto i : llvm::seq<int64_t>(0, valueShape[dimension]))
                {
                    indices.push_back(constantIndices[i]);
                    storeElementFunc(dimension + 1);
                    indices.pop_back();
                }
            };

            // Start the recursion from dimension 0.
            storeElementFunc(0);

            rewriter.replaceOp(op, allocationBuffer);
            return mlir::success();
        }
    };

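    // A sketch of what the pattern above produces for
    // `hello.constant dense<[1.0, 2.0]> : tensor<2xf64>` (values and indices illustrative):
    //
    //   %alloc = memref.alloc() : memref<2xf64>
    //   %c0 = arith.constant 0 : index
    //   %c1 = arith.constant 1 : index
    //   %cst0 = arith.constant 1.000000e+00 : f64
    //   affine.store %cst0, %alloc[%c0] : memref<2xf64>
    //   %cst1 = arith.constant 2.000000e+00 : f64
    //   affine.store %cst1, %alloc[%c1] : memref<2xf64>
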
    /// Lower hello::func to func::func.
    struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
    {
        using OpConversionPattern::OpConversionPattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            // Only lower the main function, assuming that all other functions have been inlined.
            if (op.getName() != "main")
            {
                return mlir::failure();
            }

            // Validate that the main function has no arguments and no return values.
            if (op.getNumArguments() != 0 || op.getNumResults() != 0)
            {
                return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
                {
                    diagnostic << "Expect the 'main' function to have no arguments and no results.";
                });
            }

            // Create a new function with the same region.
            auto function = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
            rewriter.inlineRegionBefore(op.getRegion(), function.getBody(), function.end());
            rewriter.eraseOp(op);
            return mlir::success();
        }
    };

    /// Lower hello::print by just replacing its operands with the lowered ones.
    struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
    {
        using OpConversionPattern::OpConversionPattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            rewriter.modifyOpInPlace(op, [&]
            {
                op->setOperands(adaptor.getOperands());
            });

            return mlir::success();
        }
    };

    /// Lower hello::return to func::return.
    struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
    {
        using OpRewritePattern::OpRewritePattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
        {
            // As all function calls have been inlined, there is only one return in the main
            // function, and it returns no value.
            if (op.hasOperand())
            {
                return mlir::failure();
            }

            rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
            return mlir::success();
        }
    };

    /// Lower the transpose operation to affine loops.
    struct TransposeOpLoweringPattern : mlir::ConversionPattern
    {
        explicit TransposeOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
            mlir::hello::TransposeOp::getOperationName(), 1, context)
        {
        }

        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            auto location = op->getLoc();

            lowerOperationToLoops(op, operands, rewriter,
                                  [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                                             mlir::ValueRange loopIterators)
                                  {
                                      mlir::hello::TransposeOpAdaptor adaptor(memRefOperands);
                                      mlir::Value input = adaptor.getInput();

                                      // The transpose operation just reverses the indices, as (x, y) -> (y, x).
                                      llvm::SmallVector<mlir::Value, 2> reversedIterators(llvm::reverse(loopIterators));
                                      return builder.create<mlir::affine::AffineLoadOp>(
                                          location, input, reversedIterators);
                                  });

            return mlir::success();
        }
    };
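
    // A sketch of the loop body this generates for a 2-D transpose
    // (indices swapped on the load; M and N stand for the result dimensions):
    //
    //   affine.for %i = 0 to M {
    //     affine.for %j = 0 to N {
    //       %v = affine.load %input[%j, %i] : memref<NxMxf64>
    //       affine.store %v, %output[%i, %j] : memref<MxNxf64>
    //     }
    //   }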
}

namespace
{
    struct HelloToAffineLoopsLoweringPass : mlir::PassWrapper<
            HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
    {
        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);

        /// The argument used to invoke this pass from the command line.
        llvm::StringRef getArgument() const override
        {
            return "hello-to-affine-loops";
        }

        void getDependentDialects(mlir::DialectRegistry& registry) const override
        {
            registry.insert<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
                            mlir::memref::MemRefDialect>();
        }

        void runOnOperation() final;
    };
}

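// With the argument above registered, the pass can be invoked through a pass driver,
// e.g. `<your-opt-tool> --hello-to-affine-loops input.mlir`; the actual tool name
// depends on how this tutorial's driver binary is built.
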
void HelloToAffineLoopsLoweringPass::runOnOperation()
{
    mlir::ConversionTarget target(getContext());

    // Only allow the affine, func, arith and memref dialects.
    target.addLegalDialect<mlir::affine::AffineDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
                           mlir::memref::MemRefDialect>();
    // Disallow the hello dialect.
    target.addIllegalDialect<mlir::hello::HelloDialect>();
    // Manually mark the print op as legal once its operands are lowered;
    // it will be lowered to the LLVM dialect in a later pass.
    target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
    {
        // Make sure the operands have been converted to memrefs, i.e. none of the
        // operand types is still a tensor.
        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
    });

    // Construct the patterns.
    mlir::RewritePatternSet set(&getContext());
    set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
            PrintOpLoweringPattern, FuncOpLoweringPattern, ReturnOpLoweringPattern,
            TransposeOpLoweringPattern>(&getContext());

    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
    {
        signalPassFailure();
    }
}

namespace mlir::hello
{
    std::unique_ptr<Pass> createLowerToAffineLoopsPass()
    {
        return std::make_unique<HelloToAffineLoopsLoweringPass>();
    }
}
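
// A minimal usage sketch (assuming a `mlir::ModuleOp` named `module` has already been
// parsed within an `mlir::MLIRContext` named `context`):
//
//   mlir::PassManager passManager(&context);
//   passManager.addPass(mlir::hello::createLowerToAffineLoopsPass());
//   if (mlir::failed(passManager.run(module)))
//   {
//       // Handle the lowering failure.
//   }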