//
// Created by ricardo on 05/06/25.
//
#include <functional>

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/STLExtras.h>

#include "Passes.h"
#include "Dialect.h"

static mlir::MemRefType convertTensorTypeToMemRefType(const mlir::RankedTensorType& tensorType)
{
    return mlir::MemRefType::get(tensorType.getShape(), tensorType.getElementType());
}

static mlir::Value insertAllocAndDealloc(const mlir::MemRefType& type, const mlir::Location location,
                                         mlir::PatternRewriter& rewriter)
{
    auto allocateOperation = rewriter.create<mlir::memref::AllocOp>(location, type);
    auto* parentBlock = allocateOperation->getBlock();
    // Move the allocate operation before the first operation in this block,
    // so the buffer dominates all of its uses.
    allocateOperation->moveBefore(&parentBlock->front());

    auto deallocateOperation = rewriter.create<mlir::memref::DeallocOp>(location, allocateOperation);
    // Move the deallocate operation before the last operation in this block,
    // as the last operation is usually the return operation.
    deallocateOperation->moveBefore(&parentBlock->back());

    return allocateOperation;
}

/// The function type used to process one iteration of a lowered loop.
/// The `memRefOperands` are the operands of the input operation.
/// The `loopIterators` are the induction variables of the iteration.
/// The return value is the value to store at the current index of the iteration.
using LoopIterationFunction = mlir::function_ref<mlir::Value(
    mlir::OpBuilder& builder, mlir::ValueRange memRefOperands, mlir::ValueRange loopIterators)>;

/// Helper function to lower an operation to a nest of affine loops.
static void lowerOperationToLoops(mlir::Operation* op, mlir::ValueRange operands,
                                  mlir::PatternRewriter& rewriter, LoopIterationFunction function)
{
    auto tensorType = llvm::cast<mlir::RankedTensorType>(op->getResultTypes().front());
    auto location = op->getLoc();
    auto memRefType = convertTensorTypeToMemRefType(tensorType);
    auto allocation = insertAllocAndDealloc(memRefType, location, rewriter);

    // Create a loop for every dimension of the shape.
    // This vector stores the lower bound of each loop; they are all zeros.
    llvm::SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), 0);
    // This vector stores the step of each loop; they are all ones.
    llvm::SmallVector<int64_t, 4> steps(tensorType.getRank(), 1);

    mlir::affine::buildAffineLoopNest(rewriter, location, lowerBounds, tensorType.getShape(), steps,
        [&](mlir::OpBuilder& nestedBuilder, const mlir::Location nestedLocation, mlir::ValueRange iterators)
        {
            mlir::Value storedValue = function(nestedBuilder, operands, iterators);
            nestedBuilder.create<mlir::affine::AffineStoreOp>(nestedLocation, storedValue, allocation, iterators);
        });

    // Replace the original operation with the created buffer.
    rewriter.replaceOp(op, allocation);
}

namespace
{
    /// A generic pattern to convert a hello binary operation into an arith operation.
    template <typename BinaryOperation, typename LoweredBinaryOperation>
    struct BinaryOperationLoweringPattern : mlir::ConversionPattern
    {
        explicit BinaryOperationLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
            BinaryOperation::getOperationName(), 1, context)
        {
        }

        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            auto location = op->getLoc();
            lowerOperationToLoops(op, operands, rewriter,
                [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                           mlir::ValueRange loopIterators)
                {
                    // The adaptor is generated automatically by the framework.
                    // This adaptor will convert the remapped memref operands into named accessors.
                    typename BinaryOperation::Adaptor binaryOperationAdaptor(memRefOperands);

                    auto loadLeftOperand = builder.create<mlir::affine::AffineLoadOp>(
                        location, binaryOperationAdaptor.getLhs(), loopIterators);
                    auto loadRightOperand = builder.create<mlir::affine::AffineLoadOp>(
                        location, binaryOperationAdaptor.getRhs(), loopIterators);

                    return builder.create<LoweredBinaryOperation>(location, loadLeftOperand, loadRightOperand);
                });
            return mlir::success();
        }
    };
}

using AddOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::AddOp, mlir::arith::AddFOp>;
using MulOperationLoweringPattern = BinaryOperationLoweringPattern<mlir::hello::MulOp, mlir::arith::MulFOp>;

namespace
{
    /// Lower the constant values into a buffer.
    struct ConstantOperationLoweringPattern : mlir::OpRewritePattern<mlir::hello::ConstantOp>
    {
        /// Inherit the constructors.
        using OpRewritePattern::OpRewritePattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::ConstantOp op, mlir::PatternRewriter& rewriter) const final
        {
            mlir::DenseElementsAttr constantValues = op.getValue();
            mlir::Location location = op->getLoc();

            // Allocate a buffer for these constant values.
            mlir::RankedTensorType tensorType = llvm::cast<mlir::RankedTensorType>(op.getType());
            mlir::MemRefType memRefType = convertTensorTypeToMemRefType(tensorType);
            mlir::Value allocationBuffer = insertAllocAndDealloc(memRefType, location, rewriter);

            // Generate the constant indices up to the largest dimension,
            // so that redundant index constants are avoided.
            auto valueShape = memRefType.getShape();
            llvm::SmallVector<mlir::Value> constantIndices;

            if (!valueShape.empty())
            {
                for (const auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape)))
                {
                    constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, i));
                }
            }
            else
            {
                // If the value shape is empty, the tensor has rank 0, so it is just a single number.
                constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(location, 0));
            }

            // Store the constant values into the buffer with a recursively defined functor.
            // Vector holding the indices of the element currently being stored.
            llvm::SmallVector<mlir::Value> indices;
            // Iterator over the constant values.
            auto valueIterator = constantValues.value_begin<mlir::FloatAttr>();
            std::function<void(uint64_t)> storeElementFunc = [&](const uint64_t dimension)
            {
                // When the dimension reaches the size of the shape, we reach the end of the recursion
                // and store the current value.
                if (dimension == valueShape.size())
                {
                    rewriter.create<mlir::affine::AffineStoreOp>(
                        location, rewriter.create<mlir::arith::ConstantOp>(location, *valueIterator),
                        allocationBuffer, llvm::ArrayRef(indices));
                    ++valueIterator;
                    return;
                }

                // Otherwise build the indices for this dimension and recurse into the next one.
                for (const auto i : llvm::seq<int64_t>(0, valueShape[dimension]))
                {
                    indices.push_back(constantIndices[i]);
                    storeElementFunc(dimension + 1);
                    indices.pop_back();
                }
            };

            // Start the recursion from dimension 0.
            storeElementFunc(0);

            rewriter.replaceOp(op, allocationBuffer);
            return mlir::success();
        }
    };

    /// Lower hello::FuncOp to func::FuncOp.
    struct FuncOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::FuncOp>
    {
        using OpConversionPattern::OpConversionPattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::FuncOp op, OpAdaptor adaptor,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            // Only lower the main function, assuming that all other functions have been inlined.
            if (op.getName() != "main")
            {
                return mlir::failure();
            }

            // Validate that the main function has no arguments and no return values.
            if (op.getNumArguments() != 0 || op.getNumResults() != 0)
            {
                return rewriter.notifyMatchFailure(op, [](mlir::Diagnostic& diagnostic)
                {
                    diagnostic << "Expect the 'main' function to have no arguments and no results.";
                });
            }

            // Create a new function with the same region.
            auto function = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());
            rewriter.inlineRegionBefore(op.getRegion(), function.getBody(), function.end());
            rewriter.eraseOp(op);
            return mlir::success();
        }
    };

    /// Lower hello::PrintOp by just replacing its operands with the lowered memref values.
    struct PrintOpLoweringPattern : mlir::OpConversionPattern<mlir::hello::PrintOp>
    {
        using OpConversionPattern::OpConversionPattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::PrintOp op, OpAdaptor adaptor,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            rewriter.modifyOpInPlace(op, [&]
            {
                op->setOperands(adaptor.getOperands());
            });
            return mlir::success();
        }
    };

    /// Lower hello::ReturnOp to func::ReturnOp.
    struct ReturnOpLoweringPattern : mlir::OpRewritePattern<mlir::hello::ReturnOp>
    {
        using OpRewritePattern::OpRewritePattern;

        mlir::LogicalResult matchAndRewrite(mlir::hello::ReturnOp op, mlir::PatternRewriter& rewriter) const final
        {
            // As all function calls have been inlined, there is only one return in the main
            // function, and it returns no value.
            if (op.hasOperand())
            {
                return mlir::failure();
            }

            rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
            return mlir::success();
        }
    };

    /// Lower the transpose operation to affine loops.
    struct TransposeOpLoweringPattern : mlir::ConversionPattern
    {
        explicit TransposeOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern(
            mlir::hello::TransposeOp::getOperationName(), 1, context)
        {
        }

        mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef<mlir::Value> operands,
                                            mlir::ConversionPatternRewriter& rewriter) const final
        {
            auto location = op->getLoc();
            lowerOperationToLoops(op, operands, rewriter,
                [location](mlir::OpBuilder& builder, mlir::ValueRange memRefOperands,
                           mlir::ValueRange loopIterators)
                {
                    mlir::hello::TransposeOpAdaptor adaptor(memRefOperands);
                    mlir::Value input = adaptor.getInput();

                    // The transpose operation just reverses the index array, mapping (x, y) -> (y, x).
                    llvm::SmallVector<mlir::Value, 2> reversedIterators(llvm::reverse(loopIterators));
                    return builder.create<mlir::affine::AffineLoadOp>(location, input, reversedIterators);
                });
            return mlir::success();
        }
    };
}

namespace
{
    struct HelloToAffineLoopsLoweringPass : mlir::PassWrapper<
            HelloToAffineLoopsLoweringPass, mlir::OperationPass<mlir::ModuleOp>>
    {
        MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLoopsLoweringPass);

        /// The argument used to invoke this pass from the command line.
        llvm::StringRef getArgument() const override
        {
            return "hello-to-affine-loops";
        }

        void getDependentDialects(mlir::DialectRegistry& registry) const override
        {
            registry.insert<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
        }

        void runOnOperation() final;
    };
}

void HelloToAffineLoopsLoweringPass::runOnOperation()
{
    mlir::ConversionTarget target(getContext());

    // Only allow the affine, func, arith and memref dialects.
    target.addLegalDialect<mlir::affine::AffineDialect, mlir::func::FuncDialect,
                           mlir::arith::ArithDialect, mlir::memref::MemRefDialect>();
    // Mark the hello dialect as illegal.
    target.addIllegalDialect<mlir::hello::HelloDialect>();
    // Manually mark the print op as dynamically legal; it will be lowered to the llvm dialect
    // in a later pass.
    target.addDynamicallyLegalOp<mlir::hello::PrintOp>([](const mlir::hello::PrintOp op)
    {
        // Make sure the operands are memref values, i.e. none of the operand types is a tensor.
        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<mlir::TensorType>);
    });

    // Construct the patterns.
    mlir::RewritePatternSet set(&getContext());
    set.add<AddOperationLoweringPattern, MulOperationLoweringPattern, ConstantOperationLoweringPattern,
            FuncOpLoweringPattern, PrintOpLoweringPattern, ReturnOpLoweringPattern,
            TransposeOpLoweringPattern>(&getContext());

    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(set))))
    {
        signalPassFailure();
    }
}

namespace mlir::hello
{
    std::unique_ptr<mlir::Pass> createLowerToAffineLoopsPass()
    {
        return std::make_unique<HelloToAffineLoopsLoweringPass>();
    }
}
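
// For reference, a rough sketch of the lowering this pass performs (illustrative
// only; the exact SSA names and shapes depend on the input module). A transpose
// such as
//
//   %1 = "hello.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
//
// becomes an allocated buffer filled by an affine loop nest over the output
// shape, loading through the reversed indices:
//
//   %alloc = memref.alloc() : memref<3x2xf64>
//   affine.for %i = 0 to 3 {
//     affine.for %j = 0 to 2 {
//       %v = affine.load %0[%j, %i] : memref<2x3xf64>
//       affine.store %v, %alloc[%i, %j] : memref<3x2xf64>
//     }
//   }
//   memref.dealloc %alloc : memref<3x2xf64>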
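
// A minimal sketch of how this pass can be scheduled. The driver code around it
// (context setup, module parsing, and the `module` variable) is an assumption
// and not part of this file; only createLowerToAffineLoopsPass() is defined here:
//
//   mlir::MLIRContext context;
//   context.getOrLoadDialect<mlir::hello::HelloDialect>();
//
//   mlir::PassManager passManager(&context);
//   passManager.addPass(mlir::hello::createLowerToAffineLoopsPass());
//   if (mlir::failed(passManager.run(module)))
//   {
//       // Handle the lowering failure.
//   }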