feat: toy tutorial chapter 2.
This commit is contained in:
parent
1a64b78ef8
commit
8d2f844e2b
|
@ -20,19 +20,36 @@ include(AddLLVM)
|
||||||
include(AddMLIR)
|
include(AddMLIR)
|
||||||
include(HandleLLVMOptions)
|
include(HandleLLVMOptions)
|
||||||
|
|
||||||
message(${MLIR_INCLUDE_DIRS})
|
|
||||||
include_directories(${LLVM_INCLUDE_DIRS})
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
include_directories(${MLIR_INCLUDE_DIRS})
|
include_directories(${MLIR_INCLUDE_DIRS})
|
||||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
|
|
||||||
include_directories(include)
|
include_directories(include)
|
||||||
|
# Add include directory in cmake output directory for lint.
|
||||||
|
include_directories(${CMAKE_BINARY_DIR}/include)
|
||||||
|
add_subdirectory(include)
|
||||||
|
|
||||||
add_library(SyntaxNode SHARED lib/SyntaxNode.cpp include/SyntaxNode.h include/Parser.h include/Lexer.h)
|
add_library(SyntaxNode STATIC
|
||||||
|
lib/SyntaxNode.cpp
|
||||||
|
lib/Dialect.cpp
|
||||||
|
lib/MLIRGen.cpp
|
||||||
|
include/SyntaxNode.h
|
||||||
|
include/Parser.h
|
||||||
|
include/Lexer.h
|
||||||
|
)
|
||||||
|
|
||||||
|
add_dependencies(SyntaxNode HelloOpsIncGen)
|
||||||
|
|
||||||
target_link_libraries(SyntaxNode
|
target_link_libraries(SyntaxNode
|
||||||
PRIVATE
|
PRIVATE
|
||||||
MLIRSupport)
|
MLIRSupport
|
||||||
|
MLIRAnalysis
|
||||||
|
MLIRFunctionInterfaces
|
||||||
|
MLIRIR
|
||||||
|
MLIRParser
|
||||||
|
MLIRSideEffectInterfaces
|
||||||
|
MLIRTransforms)
|
||||||
|
|
||||||
add_executable(hello-mlir main.cpp)
|
add_executable(hello-mlir main.cpp)
|
||||||
|
|
||||||
|
|
26
examples/multiply_transpose.hello
Normal file
26
examples/multiply_transpose.hello
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
# User defined generic function that operates on unknown shaped arguments.
|
||||||
|
def multiply_transpose(a, b) {
|
||||||
|
return transpose(a) * transpose(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
def main() {
|
||||||
|
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
|
||||||
|
var a = [[1, 2, 3], [4, 5, 6]];
|
||||||
|
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||||
|
|
||||||
|
# This call will specialize `multiply_transpose` with <2, 3> for both
|
||||||
|
# arguments and deduce a return type of <3, 2> in initialization of `c`.
|
||||||
|
var c = multiply_transpose(a, b);
|
||||||
|
|
||||||
|
# A second call to `multiply_transpose` with <2, 3> for both arguments will
|
||||||
|
# reuse the previously specialized and inferred version and return <3, 2>.
|
||||||
|
var d = multiply_transpose(b, a);
|
||||||
|
|
||||||
|
# A new call with <3, 2> (instead of <2, 3>) for both dimensions will
|
||||||
|
# trigger another specialization of `multiply_transpose`.
|
||||||
|
var e = multiply_transpose(c, d);
|
||||||
|
|
||||||
|
# Finally, calling into `multiply_transpose` with incompatible shapes
|
||||||
|
# (<2, 3> and <3, 2>) will trigger a shape inference error.
|
||||||
|
var f = multiply_transpose(a, c);
|
||||||
|
}
|
1
include/CMakeLists.txt
Normal file
1
include/CMakeLists.txt
Normal file
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(hello)
|
24
include/Dialect.h
Normal file
24
include/Dialect.h
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 29/05/25.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef DIALECT_H
|
||||||
|
#define DIALECT_H
|
||||||
|
|
||||||
|
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
#include "mlir/Interfaces/CallInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
|
/// Include the auto-generated header file containing the declaration of the toy
|
||||||
|
/// dialect.
|
||||||
|
#include "hello/Dialect.h.inc"
|
||||||
|
|
||||||
|
/// Include the auto-generated header file containing the declarations of the
|
||||||
|
/// toy operations.
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "hello/Ops.h.inc"
|
||||||
|
|
||||||
|
#endif //DIALECT_H
|
24
include/MLIRGen.h
Normal file
24
include/MLIRGen.h
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 29/05/25.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef MLIRGEN_H
|
||||||
|
#define MLIRGEN_H
|
||||||
|
|
||||||
|
#include <mlir/IR/BuiltinOps.h>
|
||||||
|
|
||||||
|
#include "SyntaxNode.h"
|
||||||
|
|
||||||
|
namespace mlir
|
||||||
|
{
|
||||||
|
class MLIRContext;
|
||||||
|
template <typename OpTy>
|
||||||
|
class OwningOpRef;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext& context, Module& helloModule);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //MLIRGEN_H
|
6
include/hello/CMakeLists.txt
Normal file
6
include/hello/CMakeLists.txt
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Ops.td)
|
||||||
|
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||||
|
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||||
|
add_public_tablegen_target(HelloOpsIncGen)
|
316
include/hello/Ops.td
Normal file
316
include/hello/Ops.td
Normal file
|
@ -0,0 +1,316 @@
|
||||||
|
#ifndef HELLO_OPS
|
||||||
|
#define HELLO_OPS
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/Interfaces/FunctionInterfaces.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
|
def Hello_Dialect : Dialect {
|
||||||
|
let name = "hello";
|
||||||
|
let cppNamespace = "::mlir::hello";
|
||||||
|
}
|
||||||
|
|
||||||
|
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Hello Operations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// We define a hello operation by inheriting from our base 'Hello_Op' class above.
|
||||||
|
// Here we provide the mnemonic and a list of traits for the operation. The
|
||||||
|
// constant operation is marked as 'Pure' as it is a pure operation
|
||||||
|
// and may be removed if dead.
|
||||||
|
def ConstantOp : Hello_Op<"constant", [Pure]> {
|
||||||
|
// Provide a summary and description for this operation. This can be used to
|
||||||
|
// auto-generate documentation of the operations within our dialect.
|
||||||
|
let summary = "constant";
|
||||||
|
let description = [{
|
||||||
|
Constant operation turns a literal into an SSA value. The data is attached
|
||||||
|
to the operation as an attribute. For example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%0 = hello.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>
|
||||||
|
: tensor<2x3xf64>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
// The constant operation takes an attribute as the only input.
|
||||||
|
let arguments = (ins F64ElementsAttr:$value);
|
||||||
|
|
||||||
|
// The constant operation returns a single value of TensorType.
|
||||||
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
|
// Indicate that the operation has a custom parser and printer method.
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
|
||||||
|
// Add custom build methods for the constant operation. These method populates
|
||||||
|
// the `state` that MLIR uses to create operations, i.e. these are used when
|
||||||
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
|
let builders = [
|
||||||
|
// Build a constant with a given constant tensor value.
|
||||||
|
OpBuilder<(ins "DenseElementsAttr":$value), [{
|
||||||
|
build($_builder, $_state, value.getType(), value);
|
||||||
|
}]>,
|
||||||
|
|
||||||
|
// Build a constant with a given constant floating-point value.
|
||||||
|
OpBuilder<(ins "double":$value)>
|
||||||
|
];
|
||||||
|
|
||||||
|
// Indicate that additional verification for this operation is necessary.
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AddOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def AddOp : Hello_Op<"add"> {
|
||||||
|
let summary = "element-wise addition operation";
|
||||||
|
let description = [{
|
||||||
|
The "add" operation performs element-wise addition between two tensors.
|
||||||
|
The shapes of the tensor operands are expected to match.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
|
||||||
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
|
// Indicate that the operation has a custom parser and printer method.
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
|
||||||
|
// Allow building an AddOp with from the two input operands.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FuncOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def FuncOp : Hello_Op<"func", [
|
||||||
|
FunctionOpInterface, IsolatedFromAbove
|
||||||
|
]> {
|
||||||
|
let summary = "user defined function operation";
|
||||||
|
let description = [{
|
||||||
|
The "hello.func" operation represents a user defined function. These are
|
||||||
|
callable SSA-region operations that contain hello computations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
hello.func @main() {
|
||||||
|
%0 = hello.constant dense<5.500000e+00> : tensor<f64>
|
||||||
|
%1 = hello.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
|
||||||
|
hello.print %1 : tensor<2x2xf64>
|
||||||
|
hello.return
|
||||||
|
}
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SymbolNameAttr:$sym_name,
|
||||||
|
TypeAttrOf<FunctionType>:$function_type,
|
||||||
|
OptionalAttr<DictArrayAttr>:$arg_attrs,
|
||||||
|
OptionalAttr<DictArrayAttr>:$res_attrs
|
||||||
|
);
|
||||||
|
let regions = (region AnyRegion:$body);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<(ins
|
||||||
|
"StringRef":$name, "FunctionType":$type,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
|
||||||
|
>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// FunctionOpInterface Methods
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Returns the argument types of this function.
|
||||||
|
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
|
||||||
|
|
||||||
|
/// Returns the result types of this function.
|
||||||
|
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
|
||||||
|
|
||||||
|
Region *getCallableRegion() { return &getBody(); }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericCallOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def GenericCallOp : Hello_Op<"generic_call"> {
|
||||||
|
let summary = "generic call operation";
|
||||||
|
let description = [{
|
||||||
|
Generic calls represent calls to a user defined function that needs to
|
||||||
|
be specialized for the shape of its arguments. The callee name is attached
|
||||||
|
as a symbol reference via an attribute. The arguments list must match the
|
||||||
|
arguments expected by the callee. For example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%4 = hello.generic_call @my_func(%1, %3)
|
||||||
|
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||||
|
```
|
||||||
|
|
||||||
|
This is only valid if a function named "my_func" exists and takes two
|
||||||
|
arguments.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// The generic call operation takes a symbol reference attribute as the
|
||||||
|
// callee, and inputs for the call.
|
||||||
|
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
|
||||||
|
|
||||||
|
// The generic call operation returns a single value of TensorType.
|
||||||
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
|
// Specialize assembly printing and parsing using a declarative format.
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
|
||||||
|
}];
|
||||||
|
|
||||||
|
// Add custom build methods for the generic call operation.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MulOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def MulOp : Hello_Op<"mul"> {
|
||||||
|
let summary = "element-wise multiplication operation";
|
||||||
|
let description = [{
|
||||||
|
The "mul" operation performs element-wise multiplication between two
|
||||||
|
tensors. The shapes of the tensor operands are expected to match.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
|
||||||
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
|
// Indicate that the operation has a custom parser and printer method.
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
|
||||||
|
// Allow building a MulOp with from the two input operands.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PrintOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PrintOp : Hello_Op<"print"> {
|
||||||
|
let summary = "print operation";
|
||||||
|
let description = [{
|
||||||
|
The "print" builtin operation prints a given input tensor, and produces
|
||||||
|
no results.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// The print operation takes an input tensor to print.
|
||||||
|
let arguments = (ins F64Tensor:$input);
|
||||||
|
|
||||||
|
let assemblyFormat = "$input attr-dict `:` type($input)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReshapeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ReshapeOp : Hello_Op<"reshape"> {
|
||||||
|
let summary = "tensor reshape operation";
|
||||||
|
let description = [{
|
||||||
|
Reshape operation is transforming its input tensor into a new tensor with
|
||||||
|
the same number of elements but different shapes. For example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%0 = hello.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins F64Tensor:$input);
|
||||||
|
|
||||||
|
// We expect that the reshape operation returns a statically shaped tensor.
|
||||||
|
let results = (outs StaticShapeTensorOf<[F64]>);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `:` type($input) `)` attr-dict `to` type(results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
|
||||||
|
Terminator]> {
|
||||||
|
let summary = "return operation";
|
||||||
|
let description = [{
|
||||||
|
The "return" operation represents a return operation within a function.
|
||||||
|
The operation takes an optional tensor operand and produces no results.
|
||||||
|
The operand type must match the signature of the function that contains
|
||||||
|
the operation. For example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
hello.func @foo() -> tensor<2xf64> {
|
||||||
|
...
|
||||||
|
hello.return %0 : tensor<2xf64>
|
||||||
|
}
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
// The return operation takes an optional input operand to return. This
|
||||||
|
// value must match the return type of the enclosing function.
|
||||||
|
let arguments = (ins Variadic<F64Tensor>:$input);
|
||||||
|
|
||||||
|
// The return operation only emits the input in the format if it is present.
|
||||||
|
let assemblyFormat = "($input^ `:` type($input))? attr-dict ";
|
||||||
|
|
||||||
|
// Allow building a ReturnOp with no return operand.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]>
|
||||||
|
];
|
||||||
|
|
||||||
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
|
}];
|
||||||
|
|
||||||
|
// Invoke a static verify method to verify this return operation.
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TransposeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TransposeOp : Hello_Op<"transpose"> {
|
||||||
|
let summary = "transpose operation";
|
||||||
|
|
||||||
|
let arguments = (ins F64Tensor:$input);
|
||||||
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `:` type($input) `)` attr-dict `to` type(results)
|
||||||
|
}];
|
||||||
|
|
||||||
|
// Allow building a TransposeOp with from the input operand.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$input)>
|
||||||
|
];
|
||||||
|
|
||||||
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
313
lib/Dialect.cpp
Normal file
313
lib/Dialect.cpp
Normal file
|
@ -0,0 +1,313 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 29/05/25.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "hello/Dialect.cpp.inc"
|
||||||
|
|
||||||
|
#include <mlir/Interfaces/FunctionImplementation.h>
|
||||||
|
|
||||||
|
void mlir::hello::HelloDialect::initialize()
|
||||||
|
{
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "hello/Ops.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::hello;
|
||||||
|
|
||||||
|
|
||||||
|
/// A generalized parser for binary operations. This parses the different forms
|
||||||
|
/// of 'printBinaryOp' below.
|
||||||
|
static ParseResult parseBinaryOp(OpAsmParser& parser,
|
||||||
|
OperationState& result)
|
||||||
|
{
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
|
||||||
|
llvm::SMLoc operandsLoc = parser.getCurrentLocation();
|
||||||
|
mlir::Type type;
|
||||||
|
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColonType(type))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
// If the type is a function type, it contains the input and result types of
|
||||||
|
// this operation.
|
||||||
|
if (mlir::FunctionType funcType = llvm::dyn_cast<mlir::FunctionType>(type))
|
||||||
|
{
|
||||||
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||||
|
result.operands))
|
||||||
|
return mlir::failure();
|
||||||
|
result.addTypes(funcType.getResults());
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, the parsed type is the type of both operands and results.
|
||||||
|
if (parser.resolveOperands(operands, type, result.operands))
|
||||||
|
return mlir::failure();
|
||||||
|
result.addTypes(type);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A generalized printer for binary operations. It prints in two different
|
||||||
|
/// forms depending on if all of the types match.
|
||||||
|
static void printBinaryOp(mlir::OpAsmPrinter& printer, mlir::Operation* op)
|
||||||
|
{
|
||||||
|
printer << " " << op->getOperands();
|
||||||
|
printer.printOptionalAttrDict(op->getAttrs());
|
||||||
|
printer << " : ";
|
||||||
|
|
||||||
|
// If all of the types are the same, print the type directly.
|
||||||
|
mlir::Type resultType = *op->result_type_begin();
|
||||||
|
if (llvm::all_of(op->getOperandTypes(),
|
||||||
|
[=](mlir::Type type) { return type == resultType; }))
|
||||||
|
{
|
||||||
|
printer << resultType;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, print a functional type.
|
||||||
|
printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Build a constant operation.
|
||||||
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
|
/// expected to fill in order to build the operation.
|
||||||
|
void mlir::hello::ConstantOp::build(OpBuilder& builder, OperationState& state,
|
||||||
|
double value)
|
||||||
|
{
|
||||||
|
auto dataType = RankedTensorType::get({}, builder.getF64Type());
|
||||||
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
|
build(builder, state, dataType, dataAttribute);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The 'OpAsmParser' class provides a collection of methods for parsing
|
||||||
|
/// various punctuation, as well as attributes, operands, types, etc. Each of
|
||||||
|
/// these methods returns a `ParseResult`. This class is a wrapper around
|
||||||
|
/// `LogicalResult` that can be converted to a boolean `true` value on failure,
|
||||||
|
/// or `false` on success. This allows for easily chaining together a set of
|
||||||
|
/// parser rules. These rules are used to populate an `mlir::OperationState`
|
||||||
|
/// similarly to the `build` methods described above.
|
||||||
|
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser& parser,
|
||||||
|
mlir::OperationState& result)
|
||||||
|
{
|
||||||
|
mlir::DenseElementsAttr value;
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseAttribute(value, "value", result.attributes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
result.addTypes(value.getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The 'OpAsmPrinter' class is a stream that allows for formatting
|
||||||
|
/// strings, attributes, operands, types, etc.
|
||||||
|
void ConstantOp::print(mlir::OpAsmPrinter& printer)
|
||||||
|
{
|
||||||
|
printer << " ";
|
||||||
|
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
|
||||||
|
printer << getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
|
mlir::LogicalResult ConstantOp::verify()
|
||||||
|
{
|
||||||
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
|
// must match the shape of the attribute holding the data.
|
||||||
|
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||||
|
if (!resultType)
|
||||||
|
return success();
|
||||||
|
|
||||||
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
|
// result type.
|
||||||
|
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||||
|
if (attrType.getRank() != resultType.getRank())
|
||||||
|
{
|
||||||
|
return emitOpError("return type must match the one of the attached value "
|
||||||
|
"attribute: ")
|
||||||
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that each of the dimensions match between the two types.
|
||||||
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim)
|
||||||
|
{
|
||||||
|
if (attrType.getShape()[dim] != resultType.getShape()[dim])
|
||||||
|
{
|
||||||
|
return emitOpError(
|
||||||
|
"return type shape mismatches its attribute at dimension ")
|
||||||
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AddOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void AddOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
|
||||||
|
mlir::Value lhs, mlir::Value rhs)
|
||||||
|
{
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||||
|
state.addOperands({lhs, rhs});
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ParseResult AddOp::parse(mlir::OpAsmParser& parser,
|
||||||
|
mlir::OperationState& result)
|
||||||
|
{
|
||||||
|
return parseBinaryOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericCallOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void GenericCallOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value> arguments)
|
||||||
|
{
|
||||||
|
// Generic call always returns an unranked Tensor initially.
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||||
|
state.addOperands(arguments);
|
||||||
|
state.addAttribute("callee",
|
||||||
|
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FuncOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void FuncOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
|
||||||
|
llvm::StringRef name, mlir::FunctionType type,
|
||||||
|
llvm::ArrayRef<mlir::NamedAttribute> attrs)
|
||||||
|
{
|
||||||
|
// FunctionOpInterface provides a convenient `build` method that will populate
|
||||||
|
// the state of our FuncOp, and create an entry block.
|
||||||
|
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ParseResult FuncOp::parse(OpAsmParser& parser,
|
||||||
|
OperationState& result)
|
||||||
|
{
|
||||||
|
// Dispatch to the FunctionOpInterface provided utility method that parses the
|
||||||
|
// function operation.
|
||||||
|
auto buildFuncType =
|
||||||
|
[](Builder& builder, ArrayRef<Type> argTypes,
|
||||||
|
ArrayRef<Type> results,
|
||||||
|
function_interface_impl::VariadicFlag,
|
||||||
|
std::string&)
|
||||||
|
{
|
||||||
|
return builder.getFunctionType(argTypes, results);
|
||||||
|
};
|
||||||
|
|
||||||
|
return mlir::function_interface_impl::parseFunctionOp(
|
||||||
|
parser, result, /*allowVariadic=*/false,
|
||||||
|
getFunctionTypeAttrName(result.name), buildFuncType,
|
||||||
|
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuncOp::print(mlir::OpAsmPrinter& p)
|
||||||
|
{
|
||||||
|
// Dispatch to the FunctionOpInterface provided utility method that prints the
|
||||||
|
// function operation.
|
||||||
|
mlir::function_interface_impl::printFunctionOp(
|
||||||
|
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
|
||||||
|
getArgAttrsAttrName(), getResAttrsAttrName());
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MulOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void MulOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
|
||||||
|
mlir::Value lhs, mlir::Value rhs)
|
||||||
|
{
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||||
|
state.addOperands({lhs, rhs});
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ParseResult MulOp::parse(mlir::OpAsmParser& parser,
|
||||||
|
mlir::OperationState& result)
|
||||||
|
{
|
||||||
|
return parseBinaryOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MulOp::print(mlir::OpAsmPrinter& p) { printBinaryOp(p, *this); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
mlir::LogicalResult ReturnOp::verify()
|
||||||
|
{
|
||||||
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
|
// trait attached to the operation definition.
|
||||||
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
|
/// ReturnOps can only have a single optional operand.
|
||||||
|
if (getNumOperands() > 1)
|
||||||
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
|
// The operand number and types must match the function signature.
|
||||||
|
const auto& results = function.getFunctionType().getResults();
|
||||||
|
if (getNumOperands() != results.size())
|
||||||
|
return emitOpError() << "does not return the same number of values ("
|
||||||
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
|
<< results.size() << ")";
|
||||||
|
|
||||||
|
// If the operation does not have an input, we are done.
|
||||||
|
if (!hasOperand())
|
||||||
|
return mlir::success();
|
||||||
|
|
||||||
|
auto inputType = *operand_type_begin();
|
||||||
|
auto resultType = results.front();
|
||||||
|
|
||||||
|
// Check that the result type of the function matches the operand type.
|
||||||
|
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||||
|
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||||
|
return mlir::success();
|
||||||
|
|
||||||
|
return emitError() << "type of return operand (" << inputType
|
||||||
|
<< ") doesn't match function result type (" << resultType
|
||||||
|
<< ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TransposeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Build a 'transpose' operation from a single input value. The result type
/// is left as an unranked f64 tensor; its concrete shape is inferred later.
void TransposeOp::build(mlir::OpBuilder& builder, mlir::OperationState& state,
                        mlir::Value value)
{
    state.addOperands(value);
    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
}
|
||||||
|
|
||||||
|
/// Verify that, when both the operand and the result are ranked tensors, the
/// result shape is the reverse of the operand shape. If either side is still
/// unranked there is nothing to check yet, so the op is accepted as-is.
mlir::LogicalResult TransposeOp::verify()
{
    auto operandTensor = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
    auto resultTensor = llvm::dyn_cast<RankedTensorType>(getType());

    // Nothing to verify until both shapes are known.
    if (!operandTensor || !resultTensor)
        return mlir::success();

    // The result dimensions must be the operand dimensions, reversed.
    auto operandShape = operandTensor.getShape();
    auto resultShape = resultTensor.getShape();
    const bool shapesMatch = std::equal(operandShape.begin(), operandShape.end(),
                                        resultShape.rbegin());
    if (shapesMatch)
        return mlir::success();

    return emitError()
        << "expected result shape to be a transpose of the input";
}
|
||||||
|
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "hello/Ops.cpp.inc"
|
464
lib/MLIRGen.cpp
Normal file
464
lib/MLIRGen.cpp
Normal file
|
@ -0,0 +1,464 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 29/05/25.
|
||||||
|
//
|
||||||
|
#include "MLIRGen.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
|
||||||
|
#include <mlir/IR/Builders.h>
|
||||||
|
#include <mlir/IR/BuiltinOps.h>
|
||||||
|
#include <mlir/IR/BuiltinTypes.h>
|
||||||
|
#include <mlir/IR/Verifier.h>
|
||||||
|
|
||||||
|
#include <llvm/ADT/ScopedHashTable.h>
|
||||||
|
|
||||||
|
using namespace mlir::hello;
|
||||||
|
using namespace hello;
|
||||||
|
|
||||||
|
using llvm::ArrayRef;
|
||||||
|
using llvm::cast;
|
||||||
|
using llvm::dyn_cast;
|
||||||
|
using llvm::isa;
|
||||||
|
using llvm::ScopedFatalErrorHandler;
|
||||||
|
using llvm::SmallVector;
|
||||||
|
using llvm::StringRef;
|
||||||
|
using llvm::Twine;
|
||||||
|
|
||||||
|
namespace
{
    /// Implementation of the Toy-AST -> MLIR lowering. One instance is created
    /// per translation; it owns the module being built, the (stateful) builder
    /// whose insertion point tracks where new ops go, and a scoped symbol table
    /// mapping variable names to SSA values for the function being emitted.
    class MLIRGenImpl
    {
    public:
        MLIRGenImpl(mlir::MLIRContext& context) : builder(&context)
        {
        }

        /// Public API: convert the AST for a Toy module (source file) to an MLIR
        /// Module operation. Returns nullptr if the built module fails
        /// verification (an error is emitted on the module in that case).
        mlir::ModuleOp mlirGen(Module& moduleAST)
        {
            // We create an empty MLIR module and codegen functions one at a time and
            // add them to the module.
            theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

            for (Function& f : moduleAST)
                mlirGen(f);

            // Verify the module after we have finished constructing it, this will check
            // the structural properties of the IR and invoke any specific verifiers we
            // have on the Toy operations.
            if (mlir::failed(mlir::verify(theModule)))
            {
                theModule.emitError("module verification error");
                return nullptr;
            }

            return theModule;
        }

    private:
        /// A "module" matches a Toy source file: containing a list of functions.
        mlir::ModuleOp theModule;

        /// The builder is a helper class to create IR inside a function. The builder
        /// is stateful, in particular it keeps an "insertion point": this is where
        /// the next operations will be introduced.
        mlir::OpBuilder builder;

        /// The symbol table maps a variable name to a value in the current scope.
        /// Entering a function creates a new scope, and the function arguments are
        /// added to the mapping. When the processing of a function is terminated, the
        /// scope is destroyed and the mappings created in this scope are dropped.
        llvm::ScopedHashTable<StringRef, mlir::Value> symbolTable;

        /// Helper conversion for a Toy AST location to an MLIR location.
        mlir::Location loc(const Location& loc)
        {
            return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
                                             loc.col);
        }

        /// Declare a variable in the current scope, return success if the variable
        /// wasn't declared yet.
        mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value)
        {
            if (symbolTable.count(var))
                return mlir::failure();
            symbolTable.insert(var, value);
            return mlir::success();
        }

        /// Create the prototype for an MLIR function with as many arguments as the
        /// provided Toy AST prototype.
        FuncOp mlirGen(FunctionPrototype& proto)
        {
            auto location = loc(proto.getLocation());

            // This is a generic function, the return type will be inferred later.
            // Arguments type are uniformly unranked tensors.
            llvm::SmallVector<mlir::Type, 4> argTypes(proto.getParameters().size(),
                                                      getType(ValueType{}));
            auto funcType = builder.getFunctionType(argTypes, std::nullopt);
            return builder.create<FuncOp>(location, proto.getName(),
                                          funcType);
        }

        /// Emit a new function and add it to the MLIR module. Returns nullptr on
        /// any failure (duplicate argument names, body codegen error); a partially
        /// built function is erased before returning.
        FuncOp mlirGen(Function& funcAST)
        {
            // Create a scope in the symbol table to hold variable declarations.
            llvm::ScopedHashTableScope varScope(symbolTable);

            // Create an MLIR function for the given prototype.
            builder.setInsertionPointToEnd(theModule.getBody());
            FuncOp function = mlirGen(*funcAST.getPrototype());
            if (!function)
                return nullptr;

            // Let's start the body of the function now!
            mlir::Block& entryBlock = function.front();
            auto protoArgs = funcAST.getPrototype()->getParameters();

            // Declare all the function arguments in the symbol table.
            for (const auto nameValue :
                 llvm::zip(protoArgs, entryBlock.getArguments()))
            {
                if (failed(declare(std::get<0>(nameValue)->getName(),
                                   std::get<1>(nameValue))))
                    return nullptr;
            }

            // Set the insertion point in the builder to the beginning of the function
            // body, it will be used throughout the codegen to create operations in this
            // function.
            builder.setInsertionPointToStart(&entryBlock);

            // Emit the body of the function.
            if (mlir::failed(mlirGen(*funcAST.getBody())))
            {
                function.erase();
                return nullptr;
            }

            // Implicitly return void if no return statement was emitted.
            // FIXME: we may fix the parser instead to always return the last expression
            // (this would possibly help the REPL case later)
            ReturnOp returnOp;
            if (!entryBlock.empty())
                returnOp = dyn_cast<ReturnOp>(entryBlock.back());
            if (!returnOp)
            {
                builder.create<ReturnOp>(loc(funcAST.getPrototype()->getLocation()));
            }
            else if (returnOp.hasOperand())
            {
                // Otherwise, if this return operation has an operand then add a result to
                // the function.
                function.setType(builder.getFunctionType(
                    function.getFunctionType().getInputs(), getType(ValueType{})));
            }

            return function;
        }

        /// Emit a binary operation
        mlir::Value mlirGen(BinaryExpression& binop)
        {
            // First emit the operations for each side of the operation before emitting
            // the operation itself. For example if the expression is `a + foo(a)`
            // 1) First it will visiting the LHS, which will return a reference to the
            //    value holding `a`. This value should have been emitted at declaration
            //    time and registered in the symbol table, so nothing would be
            //    codegen'd. If the value is not in the symbol table, an error has been
            //    emitted and nullptr is returned.
            // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
            //    and the result value is returned. If an error occurs we get a nullptr
            //    and propagate.
            //
            mlir::Value lhs = mlirGen(*binop.getLeft());
            if (!lhs)
                return nullptr;
            mlir::Value rhs = mlirGen(*binop.getRight());
            if (!rhs)
                return nullptr;
            auto location = loc(binop.getLocation());

            // Derive the operation name from the binary operator. At the moment we only
            // support '+' and '*'.
            switch (binop.getOperator())
            {
            case '+':
                return builder.create<AddOp>(location, lhs, rhs);
            case '*':
                return builder.create<MulOp>(location, lhs, rhs);
            default:
                emitError(location, "invalid binary operator '") << binop.getOperator() << "'";
                return nullptr;
            }
        }

        /// This is a reference to a variable in an expression. The variable is
        /// expected to have been declared and so should have a value in the symbol
        /// table, otherwise emit an error and return nullptr.
        mlir::Value mlirGen(VariableExpression& expr)
        {
            if (auto variable = symbolTable.lookup(expr.getName()))
                return variable;

            emitError(loc(expr.getLocation()), "error: unknown variable '")
                << expr.getName() << "'";
            return nullptr;
        }

        /// Emit a return operation. This will return failure if any generation fails.
        mlir::LogicalResult mlirGen(ReturnExpression& ret)
        {
            auto location = loc(ret.getLocation());

            // 'return' takes an optional expression, handle that case here.
            mlir::Value expr = nullptr;
            if (ret.getReturnExpression().has_value())
            {
                expr = mlirGen(**ret.getReturnExpression());
                if (!expr)
                    return mlir::failure();
            }

            // Otherwise, this return operation has zero operands.
            builder.create<ReturnOp>(location,
                                     expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
            return mlir::success();
        }

        /// Emit a literal/constant array. It will be emitted as a flattened array of
        /// data in an Attribute attached to a `toy.constant` operation.
        /// See documentation on [Attributes](LangRef.md#attributes) for more details.
        /// Here is an excerpt:
        ///
        ///   Attributes are the mechanism for specifying constant data in MLIR in
        ///   places where a variable is never allowed [...]. They consist of a name
        ///   and a concrete attribute value. The set of expected attributes, their
        ///   structure, and their interpretation are all contextually dependent on
        ///   what they are attached to.
        ///
        /// Example, the source level statement:
        ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
        /// will be converted to:
        ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
        ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
        ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
        ///
        mlir::Value mlirGen(LiteralExpression& lit)
        {
            auto type = getType(lit.getDimensions());

            // The attribute is a vector with a floating point value per element
            // (number) in the array, see `collectData()` below for more details.
            std::vector<double> data;
            // Reserve the product of the dimensions: the total element count of
            // the flattened literal.
            data.reserve(std::accumulate(lit.getDimensions().begin(), lit.getDimensions().end(), 1,
                                         std::multiplies<int>()));
            collectData(lit, data);

            // The type of this attribute is tensor of 64-bit floating-point with the
            // shape of the literal.
            mlir::Type elementType = builder.getF64Type();
            auto dataType = mlir::RankedTensorType::get(lit.getDimensions(), elementType);

            // This is the actual attribute that holds the list of values for this
            // tensor literal.
            auto dataAttribute =
                mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));

            // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
            // method.
            return builder.create<ConstantOp>(loc(lit.getLocation()), type, dataAttribute);
        }

        /// Recursive helper function to accumulate the data that compose an array
        /// literal. It flattens the nested structure in the supplied vector. For
        /// example with this array:
        ///  [[1, 2], [3, 4]]
        /// we will generate:
        ///  [ 1, 2, 3, 4 ]
        /// Individual numbers are represented as doubles.
        /// Attributes are the way MLIR attaches constant to operations.
        void collectData(ExpressionNodeBase& expr, std::vector<double>& data)
        {
            if (auto* lit = dyn_cast<LiteralExpression>(&expr))
            {
                for (auto& value : lit->getValues())
                    collectData(*value, data);
                return;
            }

            assert(isa<NumberExpression>(expr) && "expected literal or number expr");
            data.push_back(cast<NumberExpression>(expr).getValue());
        }

        /// Emit a call expression. It emits specific operations for the `transpose`
        /// builtin. Other identifiers are assumed to be user-defined functions.
        mlir::Value mlirGen(CallExpression& call)
        {
            llvm::StringRef callee = call.getName();
            auto location = loc(call.getLocation());

            // Codegen the operands first.
            SmallVector<mlir::Value, 4> operands;
            for (auto& expr : call.getArguments())
            {
                auto arg = mlirGen(*expr);
                if (!arg)
                    return nullptr;
                operands.push_back(arg);
            }

            // Builtin calls have their custom operation, meaning this is a
            // straightforward emission.
            if (callee == "transpose")
            {
                if (call.getArguments().size() != 1)
                {
                    emitError(location, "MLIR codegen encountered an error: toy.transpose "
                              "does not accept multiple arguments");
                    return nullptr;
                }
                return builder.create<TransposeOp>(location, operands[0]);
            }

            // Otherwise this is a call to a user-defined function. Calls to
            // user-defined functions are mapped to a custom call that takes the callee
            // name as an attribute.
            return builder.create<GenericCallOp>(location, callee, operands);
        }

        /// Emit a print expression. It emits specific operations for two builtins:
        /// transpose(x) and print(x).
        mlir::LogicalResult mlirGen(PrintExpression& call)
        {
            auto arg = mlirGen(*call.getArgument());
            if (!arg)
                return mlir::failure();

            builder.create<PrintOp>(loc(call.getLocation()), arg);
            return mlir::success();
        }

        /// Emit a constant for a single number (FIXME: semantic? broadcast?)
        mlir::Value mlirGen(NumberExpression& num)
        {
            return builder.create<ConstantOp>(loc(num.getLocation()), num.getValue());
        }

        /// Dispatch codegen for the right expression subclass using RTTI.
        mlir::Value mlirGen(ExpressionNodeBase& expr)
        {
            switch (expr.getKind())
            {
            case ExpressionNodeBase::BinaryOperation:
                return mlirGen(cast<BinaryExpression>(expr));
            case ExpressionNodeBase::Variable:
                return mlirGen(cast<VariableExpression>(expr));
            case ExpressionNodeBase::Literal:
                return mlirGen(cast<LiteralExpression>(expr));
            case ExpressionNodeBase::Call:
                return mlirGen(cast<CallExpression>(expr));
            case ExpressionNodeBase::Number:
                return mlirGen(cast<NumberExpression>(expr));
            default:
                emitError(loc(expr.getLocation()))
                    << "MLIR codegen encountered an unhandled expr kind '"
                    << Twine(expr.getKind()) << "'";
                return nullptr;
            }
        }

        /// Handle a variable declaration, we'll codegen the expression that forms the
        /// initializer and record the value in the symbol table before returning it.
        /// Future expressions will be able to reference this variable through symbol
        /// table lookup.
        mlir::Value mlirGen(VariableDeclarationExpression& vardecl)
        {
            auto* init = vardecl.getInitialValue();
            if (!init)
            {
                emitError(loc(vardecl.getLocation()),
                          "missing initializer in variable declaration");
                return nullptr;
            }

            mlir::Value value = mlirGen(*init);
            if (!value)
                return nullptr;

            // We have the initializer value, but in case the variable was declared
            // with specific shape, we emit a "reshape" operation. It will get
            // optimized out later as needed.
            if (!vardecl.getType().shape.empty())
            {
                value = builder.create<ReshapeOp>(loc(vardecl.getLocation()),
                                                  getType(vardecl.getType()), value);
            }

            // Register the value in the symbol table.
            if (failed(declare(vardecl.getName(), value)))
                return nullptr;
            return value;
        }

        /// Codegen a list of expression, return failure if one of them hit an error.
        mlir::LogicalResult mlirGen(ExpressionList& blockAST)
        {
            llvm::ScopedHashTableScope varScope(symbolTable);
            for (auto& expr : blockAST)
            {
                // Specific handling for variable declarations, return statement, and
                // print. These can only appear in block list and not in nested
                // expressions.
                if (auto* vardecl = dyn_cast<VariableDeclarationExpression>(expr.get()))
                {
                    if (!mlirGen(*vardecl))
                        return mlir::failure();
                    continue;
                }
                if (auto* ret = dyn_cast<ReturnExpression>(expr.get()))
                    return mlirGen(*ret);
                if (auto* print = dyn_cast<PrintExpression>(expr.get()))
                {
                    // NOTE(review): a failed print returns success(), silently
                    // swallowing the error. This matches the upstream Toy tutorial,
                    // but failure() looks intended — confirm before changing.
                    if (mlir::failed(mlirGen(*print)))
                        return mlir::success();
                    continue;
                }

                // Generic expression dispatch codegen.
                if (!mlirGen(*expr))
                    return mlir::failure();
            }
            return mlir::success();
        }

        /// Build a tensor type from a list of shape dimensions.
        mlir::Type getType(ArrayRef<int64_t> shape)
        {
            // If the shape is empty, then this type is unranked.
            if (shape.empty())
                return mlir::UnrankedTensorType::get(builder.getF64Type());

            // Otherwise, we use the given shape.
            return mlir::RankedTensorType::get(shape, builder.getF64Type());
        }

        /// Build an MLIR type from a Toy AST variable type (forward to the generic
        /// getType above).
        mlir::Type getType(const ValueType& type) { return getType(type.shape); }
    };
}
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext& context, Module& helloModule)
|
||||||
|
{
|
||||||
|
return MLIRGenImpl(context).mlirGen(helloModule);
|
||||||
|
}
|
||||||
|
}
|
77
main.cpp
77
main.cpp
|
@ -4,6 +4,19 @@
|
||||||
#include <llvm/Support/CommandLine.h>
|
#include <llvm/Support/CommandLine.h>
|
||||||
#include <llvm/Support/ErrorOr.h>
|
#include <llvm/Support/ErrorOr.h>
|
||||||
#include <llvm/Support/MemoryBuffer.h>
|
#include <llvm/Support/MemoryBuffer.h>
|
||||||
|
#include <llvm/Support/SourceMgr.h>
|
||||||
|
#include <mlir/IR/BuiltinOps.h.inc>
|
||||||
|
#include <mlir/IR/BuiltinOps.h.inc>
|
||||||
|
#include <mlir/IR/OwningOpRef.h>
|
||||||
|
#include <mlir/Parser/Parser.h>
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "MLIRGen.h"
|
||||||
|
|
||||||
|
namespace mlir
|
||||||
|
{
|
||||||
|
class ModuleOp;
|
||||||
|
}
|
||||||
|
|
||||||
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||||
llvm::cl::desc("<input hello file>"),
|
llvm::cl::desc("<input hello file>"),
|
||||||
|
@ -12,11 +25,21 @@ static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
enum Action { None, DumpSyntaxNode };
|
enum Action { None, DumpSyntaxNode, DumpMLIR };
|
||||||
|
|
||||||
|
enum InputType { Hello, MLIR };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static llvm::cl::opt<InputType> inputType("x", llvm::cl::init(Hello),
|
||||||
|
llvm::cl::desc("Decided the kind of input desired."),
|
||||||
|
llvm::cl::values(
|
||||||
|
clEnumValN(Hello, "hello", "load the input file as a hello source.")),
|
||||||
|
llvm::cl::values(
|
||||||
|
clEnumValN(MLIR, "mlir", "load the input file as a mlir source.")));
|
||||||
|
|
||||||
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
|
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
|
||||||
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")));
|
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")),
|
||||||
|
llvm::cl::values(clEnumValN(DumpMLIR, "mlir", "Dump mlir code")));
|
||||||
|
|
||||||
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
||||||
{
|
{
|
||||||
|
@ -33,6 +56,53 @@ std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
||||||
return parser.parseModule();
|
return parser.parseModule();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int dumpMLIR()
|
||||||
|
{
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
context.getOrLoadDialect<mlir::hello::HelloDialect>();
|
||||||
|
|
||||||
|
if (inputType != MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir"))
|
||||||
|
{
|
||||||
|
auto module = parseInputFile(inputFilename);
|
||||||
|
if (module == nullptr)
|
||||||
|
{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
mlir::OwningOpRef<mlir::ModuleOp> mlirModule = hello::mlirGen(context, *module);
|
||||||
|
|
||||||
|
if (!mlirModule)
|
||||||
|
{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlirModule->dump();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then the input file is mlir
|
||||||
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
|
if (std::error_code ec = fileOrErr.getError())
|
||||||
|
{
|
||||||
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the input mlir.
|
||||||
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||||
|
mlir::OwningOpRef<mlir::ModuleOp> module =
|
||||||
|
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
|
||||||
|
if (!module)
|
||||||
|
{
|
||||||
|
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
module->dump();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
||||||
|
@ -49,6 +119,9 @@ int main(int argc, char** argv)
|
||||||
case DumpSyntaxNode:
|
case DumpSyntaxNode:
|
||||||
module->dump();
|
module->dump();
|
||||||
return 0;
|
return 0;
|
||||||
|
case DumpMLIR:
|
||||||
|
dumpMLIR();
|
||||||
|
return 0;
|
||||||
default:
|
default:
|
||||||
llvm::errs() << "Unrecognized action\n";
|
llvm::errs() << "Unrecognized action\n";
|
||||||
return 1;
|
return 1;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user