feat: toy tutorial chapter 4.
Signed-off-by: jackfiled <xcrenchangjun@outlook.com>
This commit is contained in:
@@ -9,8 +9,10 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "hello/ShapeInferenceInterface.h"
|
||||
|
||||
/// Include the auto-generated header file containing the declaration of the toy
|
||||
/// dialect.
|
||||
|
20
include/Passes.h
Normal file
20
include/Passes.h
Normal file
@@ -0,0 +1,20 @@
|
||||
//
|
||||
// Created by ricardo on 02/06/25.
|
||||
//
|
||||
|
||||
#ifndef PASSES_H
|
||||
#define PASSES_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir
|
||||
{
|
||||
class Pass;
|
||||
|
||||
namespace hello
|
||||
{
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
}
|
||||
}
|
||||
|
||||
#endif //PASSES_H
|
@@ -4,3 +4,8 @@ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
add_public_tablegen_target(HelloOpsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
|
||||
mlir_tablegen(ShapeInferenceInterface.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(ShapeInferenceInterface.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(HelloInterfaceIncGen)
|
||||
|
@@ -5,12 +5,17 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/FunctionInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "hello/ShapeInferenceInterface.td"
|
||||
|
||||
def Hello_Dialect : Dialect {
|
||||
let name = "hello";
|
||||
let cppNamespace = "::mlir::hello";
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Hello_Op<string mnemonic, list<Trait> traits = []> : Op<Hello_Dialect, mnemonic, traits>;
|
||||
|
||||
|
||||
@@ -70,7 +75,7 @@ def ConstantOp : Hello_Op<"constant", [Pure]> {
|
||||
// AddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AddOp : Hello_Op<"add"> {
|
||||
def AddOp : Hello_Op<"add", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "element-wise addition operation";
|
||||
let description = [{
|
||||
The "add" operation performs element-wise addition between two tensors.
|
||||
@@ -148,7 +153,8 @@ def FuncOp : Hello_Op<"func", [
|
||||
// GenericCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def GenericCallOp : Hello_Op<"generic_call"> {
|
||||
def GenericCallOp : Hello_Op<"generic_call",
|
||||
[DeclareOpInterfaceMethods<CallOpInterface>]> {
|
||||
let summary = "generic call operation";
|
||||
let description = [{
|
||||
Generic calls represent calls to a user defined function that needs to
|
||||
@@ -187,7 +193,7 @@ def GenericCallOp : Hello_Op<"generic_call"> {
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MulOp : Hello_Op<"mul"> {
|
||||
def MulOp : Hello_Op<"mul", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "element-wise multiplication operation";
|
||||
let description = [{
|
||||
The "mul" operation performs element-wise multiplication between two
|
||||
@@ -296,7 +302,7 @@ def ReturnOp : Hello_Op<"return", [Pure, HasParent<"FuncOp">,
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TransposeOp : Hello_Op<"transpose", [Pure]> {
|
||||
def TransposeOp : Hello_Op<"transpose", [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "transpose operation";
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
@@ -316,4 +322,25 @@ def TransposeOp : Hello_Op<"transpose", [Pure]> {
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def CastOp : Hello_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
Pure,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "shape cast operation";
|
||||
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked,
|
||||
then shape is required to match. The operation is invalid if converting
|
||||
to a mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
18
include/hello/ShapeInferenceInterface.h
Normal file
18
include/hello/ShapeInferenceInterface.h
Normal file
@@ -0,0 +1,18 @@
|
||||
//
|
||||
// Created by ricardo on 02/06/25.
|
||||
//
|
||||
|
||||
#ifndef SHAPEINFERENCEINTERFACE_H
|
||||
#define SHAPEINFERENCEINTERFACE_H
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir
|
||||
{
|
||||
namespace hello
|
||||
{
|
||||
#include "hello/ShapeInferenceInterface.h.inc"
|
||||
}
|
||||
}
|
||||
|
||||
#endif //SHAPEINFERENCEINTERFACE_H
|
18
include/hello/ShapeInferenceInterface.td
Normal file
18
include/hello/ShapeInferenceInterface.td
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef SHAPE_INFERENCE_INTERFACE
|
||||
#define SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the return types for an
|
||||
operation that can be used during type inference.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"void", "inferShapes">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user