feat: toy tutorial chapter 4.

Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
2025-06-03 16:03:17 +08:00
parent eacf20fe3c
commit 902915a57b
12 changed files with 380 additions and 68 deletions

View File

@@ -9,8 +9,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "hello/ShapeInferenceInterface.h"
/// Include the auto-generated header file containing the declaration of the toy
/// dialect.

20
include/Passes.h Normal file
View File

@@ -0,0 +1,20 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef PASSES_H
#define PASSES_H
#include <memory>
namespace mlir
{
class Pass;
namespace hello
{
std::unique_ptr<Pass> createShapeInferencePass();
}
}
#endif //PASSES_H

View File

@@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(HelloOpsIncGen)
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(HelloInterfaceIncGen)

View File

@@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "hello/ShapeInferenceInterface.td"
def Hello_Dialect : Dialect {
let name = "hello";
let cppNamespace = "::mlir::hello";
}
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
@@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
// AddOp
//===----------------------------------------------------------------------===//
def AddOp : Hello_Op<"add"> {
def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
@@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
// GenericCallOp
//===----------------------------------------------------------------------===//
def GenericCallOp : Hello_Op<"generic_call"> {
def GenericCallOp : Hello_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to
@@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
// MulOp
//===----------------------------------------------------------------------===//
def MulOp : Hello_Op<"mul"> {
def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation";
let description = [{
The "mul" operation performs element-wise multiplication between two
@@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
// TransposeOp
//===----------------------------------------------------------------------===//
def TransposeOp : Hello_Op<"transpose", [Pure]> {
def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
@@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
let hasCanonicalizer = 1;
}
def CastOp : Hello_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
Pure,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked,
then shape is required to match. The operation is invalid if converting
to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
#endif

View File

@@ -0,0 +1,18 @@
//
// Created by ricardo on 02/06/25.
//
#ifndef SHAPEINFERENCEINTERFACE_H
#define SHAPEINFERENCEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
namespace mlir
{
namespace hello
{
#include "hello/ShapeInferenceInterface.h.inc"
}
}
#endif //SHAPEINFERENCEINTERFACE_H

View File

@@ -0,0 +1,18 @@
#ifndef SHAPE_INFERENCE_INTERFACE
#define SHAPE_INFERENCE_INTERFACE
include "mlir/IR/OpBase.td"
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
];
}
#endif