// // Created by ricardo on 07/06/25. // #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "Dialect.h" #include "Passes.h" namespace { class PrintOpLoweringPattern : public mlir::ConversionPattern { public: explicit PrintOpLoweringPattern(mlir::MLIRContext* context) : ConversionPattern( mlir::hello::PrintOp::getOperationName(), 1, context) { } mlir::LogicalResult matchAndRewrite(mlir::Operation* op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter& rewriter) const final { auto* context = rewriter.getContext(); auto memRefType = llvm::cast(op->getOperandTypes().front()); auto memRefShape = memRefType.getShape(); auto location = op->getLoc(); auto parentModule = op->getParentOfType(); // Get the `printf` function declaration. auto printfRef = getOrInsertPrintf(rewriter, parentModule); // Create the format string in C format. mlir::Value formatSpecifierString = getOrCreateGlobalString(location, rewriter, "format_specifier", "%f \0", parentModule); // Create the LF format string in C format. mlir::Value newLineString = getOrCreateGlobalString(location, rewriter, "new_line", "\n\0", parentModule); // Create a loop to print the value in the tensor. llvm::SmallVector loopIterators; for (const auto i : llvm::seq(0, memRefShape.size())) { auto lowerBound = rewriter.create(location, 0); auto upperBound = rewriter.create(location, memRefShape[i]); auto step = rewriter.create(location, 1); auto loop = rewriter.create(location, lowerBound, upperBound, step); // FIXME: Remove the nested operation in loop, Why? for (mlir::Operation& nestedOperation : *loop.getBody()) { rewriter.eraseOp(&nestedOperation); } loopIterators.push_back(loop.getInductionVar()); // Place the new line output and terminator in the end of loop. rewriter.setInsertionPointToEnd(loop.getBody()); if (i != memRefShape.size() - 1) { // Add change line when in a row. rewriter.create(location, getPrintfFunctionType(context), printfRef, newLineString); } rewriter.create(location); // Then place the rewriter at the start of the loop, so when finishing once, // ths rewriter will at the start of newly created loop. rewriter.setInsertionPointToStart(loop.getBody()); } auto printOperation = llvm::cast(op); auto loadedElement = rewriter.create(location, printOperation.getInput(), loopIterators); rewriter.create(location, getPrintfFunctionType(context), printfRef, llvm::ArrayRef({formatSpecifierString, loadedElement})); // At last remove the print operation. rewriter.eraseOp(printOperation); return mlir::success(); } private: static mlir::LLVM::LLVMFunctionType getPrintfFunctionType(mlir::MLIRContext* context) { const auto llvmInteger32Type = mlir::IntegerType::get(context, 32); const auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(context); // The `printf` is `int printf(char *, ...)`. const auto llvmFunctionType = mlir::LLVM::LLVMFunctionType::get(llvmInteger32Type, llvmPointerType, true); return llvmFunctionType; } static mlir::FlatSymbolRefAttr getOrInsertPrintf(mlir::PatternRewriter& rewriter, mlir::ModuleOp& module) { auto* context = module.getContext(); if (module.lookupSymbol("printf")) { return mlir::SymbolRefAttr::get(context, "printf"); } // Insert the `printf` declarations in the body of the parent module. mlir::PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", getPrintfFunctionType(context)); return mlir::SymbolRefAttr::get(context, "printf"); } /// Create or get a global string used for print. static mlir::Value getOrCreateGlobalString(mlir::Location& loc, mlir::OpBuilder& builder, llvm::StringRef name, llvm::StringRef value, mlir::ModuleOp& module) { auto global = module.lookupSymbol(name); if (!global) { // Failed to find the global value, create it. mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); auto stringType = mlir::LLVM::LLVMArrayType::get( mlir::IntegerType::get(builder.getContext(), 8), value.size()); global = builder.create(loc, stringType, true, mlir::LLVM::Linkage::Internal, name, builder.getStringAttr(value), 0); } // Get the pointer to the first character in the global string. mlir::Value globalPointer = builder.create(loc, global); mlir::Value constantZero = builder.create( loc, builder.getI64Type(), builder.getIndexAttr(0)); return builder.create(loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPointer, llvm::ArrayRef({ constantZero, constantZero })); } }; struct HelloToLLVMLoweringPass : mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToLLVMLoweringPass) llvm::StringRef getArgument() const final { return "hello-to-llvm"; } void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); } void runOnOperation() final; }; } void HelloToLLVMLoweringPass::runOnOperation() { mlir::LLVMConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); mlir::LLVMTypeConverter typeConverter(&getContext()); mlir::RewritePatternSet patterns(&getContext()); // The lower path is a little bit complicated. // Convert affine dialect to standard dialect. mlir::populateAffineToStdConversionPatterns(patterns); // Convert scf dialect to cf dialect. mlir::populateSCFToControlFlowConversionPatterns(patterns); // Convert arith to llvm dialect. mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns); patterns.add(&getContext()); // As we need to convert all operations to LLVM dialect, so // we perform a full convertion, which will permit the legal opertions in result. if (const auto module = getOperation(); mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns)))) { signalPassFailure(); } } std::unique_ptr mlir::hello::createLowerToLLVMPass() { return std::make_unique(); }