// // Created by ricardo on 29/05/25. // #include "Dialect.h" #include "hello/Dialect.cpp.inc" #include #include #include struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override { return true; } bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned, mlir::IRMapping& valueMapping) const override { return true; } bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned, mlir::IRMapping& valueMapping) const override { return true; } void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override { // Only the `hello.returnOp` is the function terminator auto returnOp = llvm::cast(op); assert(returnOp.getNumOperands() == returnValues.size()); for (const auto& it : llvm::enumerate(returnOp.getOperands())) { returnValues[it.index()].replaceAllUsesWith(it.value()); } } mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType, mlir::Location conversionLoc) const override { return builder.create(conversionLoc, resultType, input); } }; void mlir::hello::HelloDialect::initialize() { addOperations< #define GET_OP_LIST #include "hello/Ops.cpp.inc" >(); addInterfaces(); } using namespace mlir; using namespace mlir::hello; /// A generalized parser for binary operations. This parses the different forms /// of 'printBinaryOp' below. static ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) { SmallVector operands; SMLoc operandsLoc = parser.getCurrentLocation(); Type type; if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type)) return failure(); // If the type is a function type, it contains the input and result types of // this operation. if (FunctionType funcType = llvm::dyn_cast(type)) { if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, result.operands)) return failure(); result.addTypes(funcType.getResults()); return success(); } // Otherwise, the parsed type is the type of both operands and results. if (parser.resolveOperands(operands, type, result.operands)) return failure(); result.addTypes(type); return success(); } /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(OpAsmPrinter& printer, Operation* op) { printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; // If all of the types are the same, print the type directly. Type resultType = *op->result_type_begin(); if (llvm::all_of(op->getOperandTypes(), [=](Type type) { return type == resultType; })) { printer << resultType; return; } // Otherwise, print a functional type. printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// /// Build a constant operation. /// The builder is passed as an argument, so is the state that this method is /// expected to fill in order to build the operation. void ConstantOp::build(OpBuilder& builder, OperationState& state, double value) { auto dataType = RankedTensorType::get({}, builder.getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); build(builder, state, dataType, dataAttribute); } /// The 'OpAsmParser' class provides a collection of methods for parsing /// various punctuation, as well as attributes, operands, types, etc. Each of /// these methods returns a `ParseResult`. This class is a wrapper around /// `LogicalResult` that can be converted to a boolean `true` value on failure, /// or `false` on success. This allows for easily chaining together a set of /// parser rules. These rules are used to populate an `mlir::OperationState` /// similarly to the `build` methods described above. ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { DenseElementsAttr value; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(value, "value", result.attributes)) return failure(); result.addTypes(value.getType()); return success(); } /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. void ConstantOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); printer << getValue(); } /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. auto resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); // Check that the rank of the attribute type matches the rank of the constant // result type. auto attrType = llvm::cast(getValue().getType()); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") << attrType.getRank() << " != " << resultType.getRank(); } // Check that each of the dimensions match between the two types. for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { if (attrType.getShape()[dim] != resultType.getShape()[dim]) { return emitOpError( "return type shape mismatches its attribute at dimension ") << dim << ": " << attrType.getShape()[dim] << " != " << resultType.getShape()[dim]; } } return success(); } //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// void AddOp::build(OpBuilder& builder, OperationState& state, Value lhs, Value rhs) { state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands({lhs, rhs}); } ParseResult AddOp::parse(OpAsmParser& parser, OperationState& result) { return parseBinaryOp(parser, result); } void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); } void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// void GenericCallOp::build(OpBuilder& builder, OperationState& state, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); state.addAttribute("callee", SymbolRefAttr::get(builder.getContext(), callee)); } CallInterfaceCallable GenericCallOp::getCallableForCallee() { return (*this)->getAttrOfType("callee"); } void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { (*this)->setAttr("callee", mlir::cast(callee)); } Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } MutableOperandRange GenericCallOp::getArgOperandsMutable() { return getInputsMutable(); } //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// void FuncOp::build(OpBuilder& builder, OperationState& state, StringRef name, FunctionType type, ArrayRef attrs) { // FunctionOpInterface provides a convenient `build` method that will populate // the state of our FuncOp, and create an entry block. buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); } ParseResult FuncOp::parse(OpAsmParser& parser, OperationState& result) { // Dispatch to the FunctionOpInterface provided utility method that parses the // function operation. auto buildFuncType = [](Builder& builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string&) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter& p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// void MulOp::build(OpBuilder& builder, OperationState& state, Value lhs, Value rhs) { state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands({lhs, rhs}); } ParseResult MulOp::parse(OpAsmParser& parser, OperationState& result) { return parseBinaryOp(parser, result); } void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); } void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. auto function = cast((*this)->getParentOp()); /// ReturnOps can only have a single optional operand. if (getNumOperands() > 1) return emitOpError() << "expects at most 1 return operand"; // The operand number and types must match the function signature. const auto& results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError() << "does not return the same number of values (" << getNumOperands() << ") as the enclosing function (" << results.size() << ")"; // If the operation does not have an input, we are done. if (!hasOperand()) return success(); auto inputType = *operand_type_begin(); auto resultType = results.front(); // Check that the result type of the function matches the operand type. if (inputType == resultType || llvm::isa(inputType) || llvm::isa(resultType)) return success(); return emitError() << "type of return operand (" << inputType << ") doesn't match function result type (" << resultType << ")"; } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// void TransposeOp::build(OpBuilder& builder, OperationState& state, Value value) { state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(value); } LogicalResult TransposeOp::verify() { auto inputType = llvm::dyn_cast(getOperand().getType()); auto resultType = llvm::dyn_cast(getType()); if (!inputType || !resultType) return success(); auto inputShape = inputType.getShape(); if (!std::equal(inputShape.begin(), inputShape.end(), resultType.getShape().rbegin())) { return emitError() << "expected result shape to be a transpose of the input"; } return success(); } void TransposeOp::inferShapes() { // Transpose will reverse the shape of tensor. auto tensorType = llvm::cast(getOperand().getType()); // And assume that transpose only applies for matrix. SmallVector dimensions(llvm::reverse(tensorType.getShape())); getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType())); } // CastOp bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) { return false; } const auto inputTensorType = mlir::dyn_cast(inputs.front()); const auto outputTensorType = mlir::dyn_cast(outputs.front()); if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType()) { return false; } // If both have rank, they must be to the size. // And the known size can be cast into unknown size. return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType; } void CastOp::inferShapes() { getResult().setType(getInput().getType()); } #define GET_OP_CLASSES #include "hello/Ops.cpp.inc"