// Dialect.cpp — implementation of the Hello dialect and its operations.
//
|
|
// Created by ricardo on 29/05/25.
|
|
//
|
|
|
|
#include "Dialect.h"
|
|
#include "hello/Dialect.cpp.inc"
|
|
|
|
#include <mlir/Interfaces/FunctionImplementation.h>
|
|
#include <mlir/Transforms/InliningUtils.h>
|
|
#include <oneapi/tbb/detail/_template_helpers.h>
|
|
|
|
/// Inliner interface for the Hello dialect, registered in
/// HelloDialect::initialize(). It marks every hello call, region and
/// operation as legal to inline, rewires `hello.return` terminators, and
/// materializes `hello.cast` ops to reconcile type mismatches at call sites.
struct HelloDialectInlinerInterface : mlir::DialectInlinerInterface
{
    using DialectInlinerInterface::DialectInlinerInterface;

    /// All hello call operations may be inlined.
    bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, bool wouldBeCloned) const override
    {
        return true;
    }

    /// All hello regions may be inlined into any destination region.
    bool isLegalToInline(mlir::Region* dest, mlir::Region* src, bool wouldBeCloned,
                         mlir::IRMapping& valueMapping) const override
    {
        return true;
    }

    /// All hello operations may be inlined into any region.
    bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, bool wouldBeCloned,
                         mlir::IRMapping& valueMapping) const override
    {
        return true;
    }

    /// When a function body is inlined its terminator becomes dead; replace
    /// all uses of the inlined call's results with the terminator's operands.
    void handleTerminator(mlir::Operation* op, mlir::ValueRange returnValues) const override
    {
        // Only the `hello.returnOp` is the function terminator
        auto returnOp = llvm::cast<mlir::hello::ReturnOp>(op);

        assert(returnOp.getNumOperands() == returnValues.size());

        // Replace each call result with the corresponding return operand.
        for (const auto& it : llvm::enumerate(returnOp.getOperands()))
        {
            returnValues[it.index()].replaceAllUsesWith(it.value());
        }
    }

    /// Create a `hello.cast` to convert `input` to `resultType` when an
    /// inlined value's type does not match the type expected at the call site.
    mlir::Operation* materializeCallConversion(mlir::OpBuilder& builder, mlir::Value input, mlir::Type resultType,
                                               mlir::Location conversionLoc) const override
    {
        return builder.create<mlir::hello::CastOp>(conversionLoc, resultType, input);
    }
};
|
|
|
|
/// Register all operations generated from the ODS definitions (Ops.td) and
/// attach the inliner interface so hello functions can be inlined.
void mlir::hello::HelloDialect::initialize()
{
    addOperations<
#define GET_OP_LIST
#include "hello/Ops.cpp.inc"
    >();
    addInterfaces<HelloDialectInlinerInterface>();
}
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::hello;
|
|
|
|
|
|
/// A generalized parser for binary operations. This parses the different forms
|
|
/// of 'printBinaryOp' below.
|
|
static ParseResult parseBinaryOp(OpAsmParser& parser,
|
|
OperationState& result)
|
|
{
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
|
|
SMLoc operandsLoc = parser.getCurrentLocation();
|
|
Type type;
|
|
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
|
|
parser.parseOptionalAttrDict(result.attributes) ||
|
|
parser.parseColonType(type))
|
|
return failure();
|
|
|
|
// If the type is a function type, it contains the input and result types of
|
|
// this operation.
|
|
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type))
|
|
{
|
|
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
|
result.operands))
|
|
return failure();
|
|
result.addTypes(funcType.getResults());
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, the parsed type is the type of both operands and results.
|
|
if (parser.resolveOperands(operands, type, result.operands))
|
|
return failure();
|
|
result.addTypes(type);
|
|
return success();
|
|
}
|
|
|
|
/// A generalized printer for binary operations. It prints in two different
|
|
/// forms depending on if all of the types match.
|
|
static void printBinaryOp(OpAsmPrinter& printer, Operation* op)
|
|
{
|
|
printer << " " << op->getOperands();
|
|
printer.printOptionalAttrDict(op->getAttrs());
|
|
printer << " : ";
|
|
|
|
// If all of the types are the same, print the type directly.
|
|
Type resultType = *op->result_type_begin();
|
|
if (llvm::all_of(op->getOperandTypes(),
|
|
[=](Type type) { return type == resultType; }))
|
|
{
|
|
printer << resultType;
|
|
return;
|
|
}
|
|
|
|
// Otherwise, print a functional type.
|
|
printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a constant operation.
|
|
/// The builder is passed as an argument, so is the state that this method is
|
|
/// expected to fill in order to build the operation.
|
|
void ConstantOp::build(OpBuilder& builder, OperationState& state,
|
|
double value)
|
|
{
|
|
auto dataType = RankedTensorType::get({}, builder.getF64Type());
|
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
|
build(builder, state, dataType, dataAttribute);
|
|
}
|
|
|
|
/// The 'OpAsmParser' class provides a collection of methods for parsing
|
|
/// various punctuation, as well as attributes, operands, types, etc. Each of
|
|
/// these methods returns a `ParseResult`. This class is a wrapper around
|
|
/// `LogicalResult` that can be converted to a boolean `true` value on failure,
|
|
/// or `false` on success. This allows for easily chaining together a set of
|
|
/// parser rules. These rules are used to populate an `mlir::OperationState`
|
|
/// similarly to the `build` methods described above.
|
|
ParseResult ConstantOp::parse(OpAsmParser& parser,
|
|
OperationState& result)
|
|
{
|
|
DenseElementsAttr value;
|
|
if (parser.parseOptionalAttrDict(result.attributes) ||
|
|
parser.parseAttribute(value, "value", result.attributes))
|
|
return failure();
|
|
|
|
result.addTypes(value.getType());
|
|
return success();
|
|
}
|
|
|
|
/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
void ConstantOp::print(OpAsmPrinter& printer)
{
    printer << " ";
    // The "value" attribute is printed separately below, so elide it from the
    // generic attribute dictionary.
    printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
    printer << getValue();
}
|
|
|
|
/// Verifier for the constant operation. This corresponds to the
/// `let hasVerifier = 1` in the op definition.
/// Checks that the result type's rank and dimensions match the shape of the
/// "value" attribute; unranked results are accepted without checking.
LogicalResult ConstantOp::verify()
{
    // If the return type of the constant is not an unranked tensor, the shape
    // must match the shape of the attribute holding the data.
    auto resultType = llvm::dyn_cast<RankedTensorType>(getResult().getType());
    if (!resultType)
        return success();

    // Check that the rank of the attribute type matches the rank of the constant
    // result type.
    auto attrType = llvm::cast<RankedTensorType>(getValue().getType());
    if (attrType.getRank() != resultType.getRank())
    {
        return emitOpError("return type must match the one of the attached value "
                           "attribute: ")
            << attrType.getRank() << " != " << resultType.getRank();
    }

    // Check that each of the dimensions match between the two types.
    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim)
    {
        if (attrType.getShape()[dim] != resultType.getShape()[dim])
        {
            return emitOpError(
                       "return type shape mismatches its attribute at dimension ")
                << dim << ": " << attrType.getShape()[dim]
                << " != " << resultType.getShape()[dim];
        }
    }
    return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build an AddOp from two operands. The result type starts as an unranked
/// f64 tensor; shape inference refines it later (see inferShapes).
void AddOp::build(OpBuilder& builder, OperationState& state,
                  Value lhs, Value rhs)
{
    state.addOperands({lhs, rhs});
    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
}
|
|
|
|
/// Parse an AddOp using the shared binary-operation parser defined above.
ParseResult AddOp::parse(OpAsmParser& parser,
                         OperationState& result)
{
    return parseBinaryOp(parser, result);
}
|
|
|
|
/// Print an AddOp using the shared binary-operation printer defined above.
void AddOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
|
|
|
|
/// Infer the result shape of AddOp: the result takes the (already inferred)
/// type of the lhs operand.
void AddOp::inferShapes()
{
    getResult().setType(getLhs().getType());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GenericCallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a GenericCallOp to `callee` with the given arguments. The callee is
/// recorded as the "callee" symbol reference attribute.
void GenericCallOp::build(OpBuilder& builder, OperationState& state,
                          StringRef callee, ArrayRef<Value> arguments)
{
    state.addOperands(arguments);
    state.addAttribute("callee",
                       SymbolRefAttr::get(builder.getContext(), callee));
    // Generic call always returns an unranked Tensor initially.
    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
}
|
|
|
|
/// Return the callee of this call (CallOpInterface hook). The callee is
/// stored in the "callee" symbol reference attribute.
CallInterfaceCallable GenericCallOp::getCallableForCallee()
{
    return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
|
|
|
|
/// Set the callee of this call (CallOpInterface hook) by overwriting the
/// "callee" symbol reference attribute.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee)
{
    (*this)->setAttr("callee", mlir::cast<SymbolRefAttr>(callee));
}
|
|
|
|
/// Return the operands passed as call arguments (CallOpInterface hook).
Operation::operand_range GenericCallOp::getArgOperands()
{
    return getInputs();
}
|
|
|
|
/// Return a mutable range over the call arguments (CallOpInterface hook).
MutableOperandRange GenericCallOp::getArgOperandsMutable()
{
    return getInputsMutable();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a FuncOp named `name` with the given function type and attributes.
void FuncOp::build(OpBuilder& builder, OperationState& state,
                   StringRef name, FunctionType type,
                   ArrayRef<NamedAttribute> attrs)
{
    // FunctionOpInterface provides a convenient `build` method that will populate
    // the state of our FuncOp, and create an entry block.
    buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
|
|
|
|
/// Parse a FuncOp by dispatching to the FunctionOpInterface utility.
ParseResult FuncOp::parse(OpAsmParser& parser,
                          OperationState& result)
{
    // Callback that assembles the op's FunctionType from the parsed argument
    // and result types.
    const auto makeFunctionType =
        [](Builder& builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
           function_interface_impl::VariadicFlag, std::string&)
    {
        return builder.getFunctionType(argTypes, results);
    };

    // The interface helper handles the full function syntax (symbol name,
    // signature, attributes, and body region).
    return function_interface_impl::parseFunctionOp(
        parser, result, /*allowVariadic=*/false,
        getFunctionTypeAttrName(result.name), makeFunctionType,
        getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
|
|
|
|
/// Print a FuncOp by dispatching to the FunctionOpInterface utility.
void FuncOp::print(OpAsmPrinter& p)
{
    // Dispatch to the FunctionOpInterface provided utility method that prints the
    // function operation.
    function_interface_impl::printFunctionOp(
        p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
        getArgAttrsAttrName(), getResAttrsAttrName());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a MulOp from two operands. The result type starts as an unranked
/// f64 tensor; shape inference refines it later (see inferShapes).
void MulOp::build(OpBuilder& builder, OperationState& state,
                  Value lhs, Value rhs)
{
    state.addOperands({lhs, rhs});
    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
}
|
|
|
|
/// Parse a MulOp using the shared binary-operation parser defined above.
ParseResult MulOp::parse(OpAsmParser& parser,
                         OperationState& result)
{
    return parseBinaryOp(parser, result);
}
|
|
|
|
/// Print a MulOp using the shared binary-operation printer defined above.
void MulOp::print(OpAsmPrinter& p) { printBinaryOp(p, *this); }
|
|
|
|
/// Infer the result shape of MulOp: the result takes the (already inferred)
/// type of the lhs operand.
void MulOp::inferShapes()
{
    getResult().setType(getLhs().getType());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifier for ReturnOp: at most one operand, operand count must match the
/// enclosing function's result count, and the operand type must match the
/// function result type (unranked tensors are accepted on either side).
LogicalResult ReturnOp::verify()
{
    // We know that the parent operation is a function, because of the 'HasParent'
    // trait attached to the operation definition.
    auto function = cast<FuncOp>((*this)->getParentOp());

    /// ReturnOps can only have a single optional operand.
    if (getNumOperands() > 1)
        return emitOpError() << "expects at most 1 return operand";

    // The operand number and types must match the function signature.
    const auto& results = function.getFunctionType().getResults();
    if (getNumOperands() != results.size())
        return emitOpError() << "does not return the same number of values ("
                             << getNumOperands() << ") as the enclosing function ("
                             << results.size() << ")";

    // If the operation does not have an input, we are done.
    if (!hasOperand())
        return success();

    auto inputType = *operand_type_begin();
    auto resultType = results.front();

    // Check that the result type of the function matches the operand type.
    // Unranked types are allowed to pair with anything; the cast/shape
    // inference machinery reconciles them later.
    if (inputType == resultType || llvm::isa<UnrankedTensorType>(inputType) ||
        llvm::isa<UnrankedTensorType>(resultType))
        return success();

    return emitError() << "type of return operand (" << inputType
                       << ") doesn't match function result type (" << resultType
                       << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransposeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a TransposeOp from a single operand. The result type starts as an
/// unranked f64 tensor; shape inference refines it later (see inferShapes).
void TransposeOp::build(OpBuilder& builder, OperationState& state,
                        Value value)
{
    state.addOperands(value);
    state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
}
|
|
|
|
/// Verifier for TransposeOp: when both the operand and result types are
/// ranked, the result shape must be the reverse of the input shape.
/// Unranked operands or results are accepted without checking.
LogicalResult TransposeOp::verify()
{
    auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
    auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
    if (!inputType || !resultType)
        return success();

    auto inputShape = inputType.getShape();
    auto resultShape = resultType.getShape();

    // Check the rank explicitly before comparing element-wise: the
    // three-iterator std::equal below reads inputShape.size() elements from
    // resultShape, which would run past the end when the result has a
    // smaller rank.
    if (inputShape.size() != resultShape.size() ||
        !std::equal(inputShape.begin(), inputShape.end(),
                    resultShape.rbegin()))
    {
        return emitError()
            << "expected result shape to be a transpose of the input";
    }
    return success();
}
|
|
|
|
/// Infer the result shape of TransposeOp by reversing the input shape.
void TransposeOp::inferShapes()
{
    // Transpose will reverse the shape of tensor.
    // NOTE(review): llvm::cast asserts if the operand is still unranked when
    // shape inference runs — presumably the operand has been inferred first;
    // verify against the shape-inference pass ordering.
    auto tensorType = llvm::cast<RankedTensorType>(getOperand().getType());
    // And assume that transpose only applies for matrix.
    SmallVector<int64_t, 2> dimensions(llvm::reverse(tensorType.getShape()));
    getResult().setType(RankedTensorType::get(dimensions, tensorType.getElementType()));
}
|
|
|
|
// CastOp
|
|
|
|
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs)
|
|
{
|
|
if (inputs.size() != 1 || outputs.size() != 1)
|
|
{
|
|
return false;
|
|
}
|
|
|
|
const auto inputTensorType = mlir::dyn_cast<TensorType>(inputs.front());
|
|
const auto outputTensorType = mlir::dyn_cast<TensorType>(outputs.front());
|
|
|
|
if (!inputTensorType || !outputTensorType || inputTensorType.getElementType() != outputTensorType.getElementType())
|
|
{
|
|
return false;
|
|
}
|
|
|
|
// If both have rank, they must be to the size.
|
|
// And the known size can be cast into unknown size.
|
|
return !inputTensorType.hasRank() || !outputTensorType.hasRank() || inputTensorType == outputTensorType;
|
|
}
|
|
|
|
/// Infer the result type of CastOp: the result takes the input's type.
void CastOp::inferShapes()
{
    getResult().setType(getInput().getType());
}
|
|
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "hello/Ops.cpp.inc"
|