From 8d2f844e2bd097e9b08cabe1436479dfbbe1bc13 Mon Sep 17 00:00:00 2001 From: jackfiled Date: Mon, 2 Jun 2025 16:17:45 +0800 Subject: [PATCH] feat: toy tutorial chapter 2. --- CMakeLists.txt | 23 +- examples/multiply_transpose.hello | 26 ++ include/CMakeLists.txt | 1 + include/Dialect.h | 24 ++ include/MLIRGen.h | 24 ++ include/hello/CMakeLists.txt | 6 + include/hello/Ops.td | 316 ++++++++++++++++++++ lib/Dialect.cpp | 313 ++++++++++++++++++++ lib/MLIRGen.cpp | 464 ++++++++++++++++++++++++++++++ main.cpp | 77 ++++- 10 files changed, 1269 insertions(+), 5 deletions(-) create mode 100644 examples/multiply_transpose.hello create mode 100644 include/CMakeLists.txt create mode 100644 include/Dialect.h create mode 100644 include/MLIRGen.h create mode 100644 include/hello/CMakeLists.txt create mode 100644 include/hello/Ops.td create mode 100644 lib/Dialect.cpp create mode 100644 lib/MLIRGen.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 529cabd..18c7ae6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,19 +20,36 @@ include(AddLLVM) include(AddMLIR) include(HandleLLVMOptions) -message(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) include_directories(include) +# Add include directory in cmake output directory for lint. 
+include_directories(${CMAKE_BINARY_DIR}/include) +add_subdirectory(include) -add_library(SyntaxNode SHARED lib/SyntaxNode.cpp include/SyntaxNode.h include/Parser.h include/Lexer.h) +add_library(SyntaxNode STATIC + lib/SyntaxNode.cpp + lib/Dialect.cpp + lib/MLIRGen.cpp + include/SyntaxNode.h + include/Parser.h + include/Lexer.h +) + +add_dependencies(SyntaxNode HelloOpsIncGen) target_link_libraries(SyntaxNode PRIVATE - MLIRSupport) + MLIRSupport + MLIRAnalysis + MLIRFunctionInterfaces + MLIRIR + MLIRParser + MLIRSideEffectInterfaces + MLIRTransforms) add_executable(hello-mlir main.cpp) diff --git a/examples/multiply_transpose.hello b/examples/multiply_transpose.hello new file mode 100644 index 0000000..e174055 --- /dev/null +++ b/examples/multiply_transpose.hello @@ -0,0 +1,26 @@ +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + var a = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <3, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return <3, 2>. + var d = multiply_transpose(b, a); + + # A new call with <3, 2> (instead of <2, 3>) for both dimensions will + # trigger another specialization of `multiply_transpose`. + var e = multiply_transpose(c, d); + + # Finally, calling into `multiply_transpose` with incompatible shapes + # (<2, 3> and <3, 2>) will trigger a shape inference error. 
+ var f = multiply_transpose(a, c); +} \ No newline at end of file diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt new file mode 100644 index 0000000..cc7bded --- /dev/null +++ b/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(hello) \ No newline at end of file diff --git a/include/Dialect.h b/include/Dialect.h new file mode 100644 index 0000000..7a8924b --- /dev/null +++ b/include/Dialect.h @@ -0,0 +1,24 @@ +// +// Created by ricardo on 29/05/25. +// + +#ifndef DIALECT_H +#define DIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "hello/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "hello/Ops.h.inc" + +#endif //DIALECT_H diff --git a/include/MLIRGen.h b/include/MLIRGen.h new file mode 100644 index 0000000..4636057 --- /dev/null +++ b/include/MLIRGen.h @@ -0,0 +1,24 @@ +// +// Created by ricardo on 29/05/25. 
+// + +#ifndef MLIRGEN_H +#define MLIRGEN_H + +#include + +#include "SyntaxNode.h" + +namespace mlir +{ + class MLIRContext; + template + class OwningOpRef; +} + +namespace hello +{ + mlir::OwningOpRef mlirGen(mlir::MLIRContext& context, Module& helloModule); +} + +#endif //MLIRGEN_H diff --git a/include/hello/CMakeLists.txt b/include/hello/CMakeLists.txt new file mode 100644 index 0000000..f53458d --- /dev/null +++ b/include/hello/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_public_tablegen_target(HelloOpsIncGen) diff --git a/include/hello/Ops.td b/include/hello/Ops.td new file mode 100644 index 0000000..87078d3 --- /dev/null +++ b/include/hello/Ops.td @@ -0,0 +1,316 @@ +#ifndef HELLO_OPS +#define HELLO_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Hello_Dialect : Dialect { + let name = "hello"; + let cppNamespace = "::mlir::hello"; +} + +class Hello_Op traits = []> : Op; + + +//===----------------------------------------------------------------------===// +// Hello Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a hello operation by inheriting from our base 'Hello_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Hello_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. 
This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = hello.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Hello_Op<"add"> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. 
+ let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Hello_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "hello.func" operation represents a user defined function. These are + callable SSA-region operations that contain hello computations. + + Example: + + ```mlir + hello.func @main() { + %0 = hello.constant dense<5.500000e+00> : tensor + %1 = hello.reshape(%0 : tensor) to tensor<2x2xf64> + hello.print %1 : tensor<2x2xf64> + hello.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. 
+ ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } + }]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Hello_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = hello.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Hello_Op<"mul"> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. 
The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Hello_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Hello_Op<"reshape"> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = hello.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. 
+ let results = (outs StaticShapeTensorOf<[F64]>); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + hello.func @foo() -> tensor<2xf64> { + ... + hello.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. 
+ let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Hello_Op<"transpose"> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Invoke a static verify method to verify this transpose operation. + let hasVerifier = 1; +} + +#endif diff --git a/lib/Dialect.cpp b/lib/Dialect.cpp new file mode 100644 index 0000000..ce7c2c7 --- /dev/null +++ b/lib/Dialect.cpp @@ -0,0 +1,313 @@ +// +// Created by ricardo on 29/05/25. +// + +#include "Dialect.h" +#include "hello/Dialect.cpp.inc" + +#include + +void mlir::hello::HelloDialect::initialize() +{ + addOperations< +#define GET_OP_LIST +#include "hello/Ops.cpp.inc" + >(); +} + +using namespace mlir; +using namespace mlir::hello; + + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static ParseResult parseBinaryOp(OpAsmParser& parser, + OperationState& result) +{ + SmallVector operands; + llvm::SMLoc operandsLoc = parser.getCurrentLocation(); + mlir::Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. 
+ if (mlir::FunctionType funcType = llvm::dyn_cast(type)) + { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op) +{ + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + mlir::Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](mlir::Type type) { return type == resultType; })) + { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state, + double value) +{ + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. 
Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser, + mlir::OperationState& result) +{ + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter& printer) +{ + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() +{ + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) + { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) + { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) + { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, + mlir::Value lhs, mlir::Value rhs) +{ + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser, + mlir::OperationState& result) +{ + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, + StringRef callee, ArrayRef arguments) +{ + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) +{ + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. 
+ buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(OpAsmParser& parser, + OperationState& result) +{ + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](Builder& builder, ArrayRef argTypes, + ArrayRef results, + function_interface_impl::VariadicFlag, + std::string&) + { + return builder.getFunctionType(argTypes, results); + }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter& p) +{ + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, + mlir::Value lhs, mlir::Value rhs) +{ + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser, + mlir::OperationState& result) +{ + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() +{ + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. 
+ auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto& results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state, + mlir::Value value) +{ + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +mlir::LogicalResult TransposeOp::verify() +{ + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) + { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + + +#define GET_OP_CLASSES +#include "hello/Ops.cpp.inc" diff --git a/lib/MLIRGen.cpp b/lib/MLIRGen.cpp 
new file mode 100644 index 0000000..19d70fc --- /dev/null +++ b/lib/MLIRGen.cpp @@ -0,0 +1,464 @@ +// +// Created by ricardo on 29/05/25. +// +#include "MLIRGen.h" + +#include + +#include "Dialect.h" + +#include +#include +#include +#include + +#include + +using namespace mlir::hello; +using namespace hello; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedFatalErrorHandler; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace +{ + class MLIRGenImpl + { + public: + MLIRGenImpl(mlir::MLIRContext& context) : builder(&context) + { + } + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(Module& moduleAST) + { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (Function& f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (mlir::failed(mlir::verify(theModule))) + { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + + private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. 
+ llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location& loc) + { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) + { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + FuncOp mlirGen(FunctionPrototype& proto) + { + auto location = loc(proto.getLocation()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getParameters().size(), + getType(ValueType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + FuncOp mlirGen(Function& funcAST) + { + // Create a scope in the symbol table to hold variable declarations. + llvm::ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + FuncOp function = mlirGen(*funcAST.getPrototype()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block& entryBlock = function.front(); + auto protoArgs = funcAST.getPrototype()->getParameters(); + + // Declare all the function arguments in the symbol table. 
+ for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) + { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) + { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) + { + builder.create(loc(funcAST.getPrototype()->getLocation())); + } + else if (returnOp.hasOperand()) + { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(ValueType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExpression& binop) + { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. 
+            //
+            mlir::Value lhs = mlirGen(*binop.getLeft());
+            if (!lhs)
+                return nullptr;
+            mlir::Value rhs = mlirGen(*binop.getRight());
+            if (!rhs)
+                return nullptr;
+            auto location = loc(binop.getLocation());
+
+            // Derive the operation name from the binary operator. At the moment we only
+            // support '+' and '*'.
+            switch (binop.getOperator())
+            {
+            case '+':
+                return builder.create<AddOp>(location, lhs, rhs);
+            case '*':
+                return builder.create<MulOp>(location, lhs, rhs);
+            default:
+                emitError(location, "invalid binary operator '") << binop.getOperator() << "'";
+                return nullptr;
+            }
+        }
+
+        /// This is a reference to a variable in an expression. The variable is
+        /// expected to have been declared and so should have a value in the symbol
+        /// table, otherwise emit an error and return nullptr.
+        mlir::Value mlirGen(VariableExpression& expr)
+        {
+            if (auto variable = symbolTable.lookup(expr.getName()))
+                return variable;
+
+            emitError(loc(expr.getLocation()), "error: unknown variable '")
+                << expr.getName() << "'";
+            return nullptr;
+        }
+
+        /// Emit a return operation. This will return failure if any generation fails.
+        mlir::LogicalResult mlirGen(ReturnExpression& ret)
+        {
+            auto location = loc(ret.getLocation());
+
+            // 'return' takes an optional expression, handle that case here.
+            mlir::Value expr = nullptr;
+            if (ret.getReturnExpression().has_value())
+            {
+                expr = mlirGen(**ret.getReturnExpression());
+                if (!expr)
+                    return mlir::failure();
+            }
+
+            // Otherwise, this return operation has zero operands.
+            builder.create<ReturnOp>(location,
+                                     expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
+            return mlir::success();
+        }
+
+        /// Emit a literal/constant array. It will be emitted as a flattened array of
+        /// data in an Attribute attached to a `toy.constant` operation.
+        /// See documentation on [Attributes](LangRef.md#attributes) for more details.
+        /// Here is an excerpt:
+        ///
+        ///   Attributes are the mechanism for specifying constant data in MLIR in
+        ///   places where a variable is never allowed [...]. They consist of a name
+        ///   and a concrete attribute value. The set of expected attributes, their
+        ///   structure, and their interpretation are all contextually dependent on
+        ///   what they are attached to.
+        ///
+        /// Example, the source level statement:
+        ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+        /// will be converted to:
+        ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+        ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+        ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+        ///
+        mlir::Value mlirGen(LiteralExpression& lit)
+        {
+            auto type = getType(lit.getDimensions());
+
+            // The attribute is a vector with a floating point value per element
+            // (number) in the array, see `collectData()` below for more details.
+            std::vector<double> data;
+            data.reserve(std::accumulate(lit.getDimensions().begin(), lit.getDimensions().end(), 1,
+                                         std::multiplies<int>()));
+            collectData(lit, data);
+
+            // The type of this attribute is tensor of 64-bit floating-point with the
+            // shape of the literal.
+            mlir::Type elementType = builder.getF64Type();
+            auto dataType = mlir::RankedTensorType::get(lit.getDimensions(), elementType);
+
+            // This is the actual attribute that holds the list of values for this
+            // tensor literal.
+            auto dataAttribute =
+                mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
+
+            // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
+            // method.
+            return builder.create<ConstantOp>(loc(lit.getLocation()), type, dataAttribute);
+        }
+
+        /// Recursive helper function to accumulate the data that compose an array
+        /// literal. It flattens the nested structure in the supplied vector. For
+        /// example with this array:
+        ///  [[1, 2], [3, 4]]
+        /// we will generate:
+        ///  [ 1, 2, 3, 4 ]
+        /// Individual numbers are represented as doubles.
+        /// Attributes are the way MLIR attaches constant to operations.
+        void collectData(ExpressionNodeBase& expr, std::vector<double>& data)
+        {
+            if (auto* lit = dyn_cast<LiteralExpression>(&expr))
+            {
+                for (auto& value : lit->getValues())
+                    collectData(*value, data);
+                return;
+            }
+
+            assert(isa<NumberExpression>(expr) && "expected literal or number expr");
+            data.push_back(cast<NumberExpression>(expr).getValue());
+        }
+
+        /// Emit a call expression. It emits specific operations for the `transpose`
+        /// builtin. Other identifiers are assumed to be user-defined functions.
+        mlir::Value mlirGen(CallExpression& call)
+        {
+            llvm::StringRef callee = call.getName();
+            auto location = loc(call.getLocation());
+
+            // Codegen the operands first.
+            SmallVector<mlir::Value, 4> operands;
+            for (auto& expr : call.getArguments())
+            {
+                auto arg = mlirGen(*expr);
+                if (!arg)
+                    return nullptr;
+                operands.push_back(arg);
+            }
+
+            // Builtin calls have their custom operation, meaning this is a
+            // straightforward emission.
+            if (callee == "transpose")
+            {
+                if (call.getArguments().size() != 1)
+                {
+                    emitError(location, "MLIR codegen encountered an error: toy.transpose "
+                                        "does not accept multiple arguments");
+                    return nullptr;
+                }
+                return builder.create<TransposeOp>(location, operands[0]);
+            }
+
+            // Otherwise this is a call to a user-defined function. Calls to
+            // user-defined functions are mapped to a custom call that takes the callee
+            // name as an attribute.
+            return builder.create<GenericCallOp>(location, callee, operands);
+        }
+
+        /// Emit a print expression. It emits specific operations for two builtins:
+        /// transpose(x) and print(x).
+        mlir::LogicalResult mlirGen(PrintExpression& call)
+        {
+            auto arg = mlirGen(*call.getArgument());
+            if (!arg)
+                return mlir::failure();
+
+            builder.create<PrintOp>(loc(call.getLocation()), arg);
+            return mlir::success();
+        }
+
+        /// Emit a constant for a single number (FIXME: semantic? broadcast?)
+        mlir::Value mlirGen(NumberExpression& num)
+        {
+            return builder.create<ConstantOp>(loc(num.getLocation()), num.getValue());
+        }
+
+        /// Dispatch codegen for the right expression subclass using RTTI.
+        mlir::Value mlirGen(ExpressionNodeBase& expr)
+        {
+            switch (expr.getKind())
+            {
+            case ExpressionNodeBase::BinaryOperation:
+                return mlirGen(cast<BinaryExpression>(expr));
+            case ExpressionNodeBase::Variable:
+                return mlirGen(cast<VariableExpression>(expr));
+            case ExpressionNodeBase::Literal:
+                return mlirGen(cast<LiteralExpression>(expr));
+            case ExpressionNodeBase::Call:
+                return mlirGen(cast<CallExpression>(expr));
+            case ExpressionNodeBase::Number:
+                return mlirGen(cast<NumberExpression>(expr));
+            default:
+                emitError(loc(expr.getLocation()))
+                    << "MLIR codegen encountered an unhandled expr kind '"
+                    << Twine(expr.getKind()) << "'";
+                return nullptr;
+            }
+        }
+
+        /// Handle a variable declaration, we'll codegen the expression that forms the
+        /// initializer and record the value in the symbol table before returning it.
+        /// Future expressions will be able to reference this variable through symbol
+        /// table lookup.
+        mlir::Value mlirGen(VariableDeclarationExpression& vardecl)
+        {
+            auto* init = vardecl.getInitialValue();
+            if (!init)
+            {
+                emitError(loc(vardecl.getLocation()),
+                          "missing initializer in variable declaration");
+                return nullptr;
+            }
+
+            mlir::Value value = mlirGen(*init);
+            if (!value)
+                return nullptr;
+
+            // We have the initializer value, but in case the variable was declared
+            // with specific shape, we emit a "reshape" operation. It will get
+            // optimized out later as needed.
+            if (!vardecl.getType().shape.empty())
+            {
+                value = builder.create<ReshapeOp>(loc(vardecl.getLocation()),
+                                                  getType(vardecl.getType()), value);
+            }
+
+            // Register the value in the symbol table.
+            if (failed(declare(vardecl.getName(), value)))
+                return nullptr;
+            return value;
+        }
+
+        /// Codegen a list of expression, return failure if one of them hit an error.
+        mlir::LogicalResult mlirGen(ExpressionList& blockAST)
+        {
+            llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
+            for (auto& expr : blockAST)
+            {
+                // Specific handling for variable declarations, return statement, and
+                // print. These can only appear in block list and not in nested
+                // expressions.
+                if (auto* vardecl = dyn_cast<VariableDeclarationExpression>(expr.get()))
+                {
+                    if (!mlirGen(*vardecl))
+                        return mlir::failure();
+                    continue;
+                }
+                if (auto* ret = dyn_cast<ReturnExpression>(expr.get()))
+                    return mlirGen(*ret);
+                if (auto* print = dyn_cast<PrintExpression>(expr.get()))
+                {
+                    if (mlir::failed(mlirGen(*print)))
+                        return mlir::success();
+                    continue;
+                }
+
+                // Generic expression dispatch codegen.
+                if (!mlirGen(*expr))
+                    return mlir::failure();
+            }
+            return mlir::success();
+        }
+
+        /// Build a tensor type from a list of shape dimensions.
+        mlir::Type getType(ArrayRef<int64_t> shape)
+        {
+            // If the shape is empty, then this type is unranked.
+            if (shape.empty())
+                return mlir::UnrankedTensorType::get(builder.getF64Type());
+
+            // Otherwise, we use the given shape.
+            return mlir::RankedTensorType::get(shape, builder.getF64Type());
+        }
+
+        /// Build an MLIR type from a Toy AST variable type (forward to the generic
+        /// getType above).
+        mlir::Type getType(const ValueType& type) { return getType(type.shape); }
+    };
+}
+
+namespace hello
+{
+    mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext& context, Module& helloModule)
+    {
+        return MLIRGenImpl(context).mlirGen(helloModule);
+    }
+}
diff --git a/main.cpp b/main.cpp
index 0d99340..5bd6edc 100644
--- a/main.cpp
+++ b/main.cpp
@@ -4,6 +4,19 @@
 #include <llvm/Support/CommandLine.h>
 #include <llvm/Support/ErrorOr.h>
 #include <llvm/Support/MemoryBuffer.h>
+#include <llvm/Support/SourceMgr.h>
+#include <llvm/Support/raw_ostream.h>
+#include <mlir/IR/MLIRContext.h>
+#include <mlir/IR/OwningOpRef.h>
+#include <mlir/Parser/Parser.h>
+
+#include "Dialect.h"
+#include "MLIRGen.h"
+
+namespace mlir
+{
+    class ModuleOp;
+}
 
 static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                 llvm::cl::desc("<input hello file>"),
@@ -12,11 +25,21 @@
 
 namespace
 {
-    enum Action { None, DumpSyntaxNode };
+    enum Action { None, DumpSyntaxNode, DumpMLIR };
+
+    enum InputType { Hello, MLIR };
 }
 
+static llvm::cl::opt<enum InputType> inputType("x", llvm::cl::init(Hello),
+    llvm::cl::desc("Decided the kind of input desired."),
+    llvm::cl::values(
+        clEnumValN(Hello, "hello", "load the input file as a hello source.")),
+    llvm::cl::values(
+        clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
+
 static llvm::cl::opt<enum Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
-    llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")));
+    llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
+    llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
 
 std::unique_ptr<Module> parseInputFile(llvm::StringRef filename)
 {
@@ -33,6 +56,53 @@ std::unique_ptr<Module> parseInputFile(llvm::StringRef filename)
     return parser.parseModule();
 }
 
+int dumpMLIR()
+{
+    mlir::MLIRContext context;
+    context.getOrLoadDialect<hello::HelloDialect>();
+
+    if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
+    {
+        auto module = parseInputFile(inputFilename);
+        if (module == nullptr)
+        {
+            return 1;
+        }
+        mlir::OwningOpRef<mlir::ModuleOp> mlirModule = hello::mlirGen(context, *module);
+
+        if (!mlirModule)
+        {
+            return 1;
+        }
+
+        mlirModule->dump();
+        return 0;
+    }
+
+    // Then the input file is mlir
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+    if (std::error_code ec = fileOrErr.getError())
+    {
+        llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+        return 1;
+    }
+
+    // Parse the input mlir.
+    llvm::SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+    mlir::OwningOpRef<mlir::ModuleOp> module =
+        mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+    if (!module)
+    {
+        llvm::errs() << "Error can't load file " << inputFilename << "\n";
+        return 1;
+    }
+
+    module->dump();
+    return 0;
+}
+
 int main(int argc, char** argv)
 {
     llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
@@ -49,6 +119,9 @@ int main(int argc, char** argv)
     case DumpSyntaxNode:
         module->dump();
         return 0;
+    case DumpMLIR:
+        dumpMLIR();
+        return 0;
     default:
         llvm::errs() << "Unrecognized action\n";
         return 1;