hello-mlir/include/SyntaxNode.h

359 lines
8.7 KiB
C++

//
// Created by ricardo on 28/05/25.
//
#ifndef SYNTEXNODE_H
#define SYNTEXNODE_H
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include <mlir/IR/Location.h>
namespace hello
{
/// Type descriptor attached to a declared variable.
struct ValueType
{
    /// Shape of the value; presumably one extent per dimension
    /// (TODO confirm against the code that fills it in).
    std::vector<std::int64_t> shape;
};
/// Position of a syntax element: a file handle plus a line and a column.
///
/// The struct remains an aggregate, so `Location{file, line, col}` keeps
/// working. The default member initializers only matter for objects that are
/// default-initialized, which would otherwise leave `line` and `col` with
/// indeterminate values.
struct Location
{
    /// Shared string identifying the file (presumably its name or path);
    /// null when the location was default-constructed.
    std::shared_ptr<std::string> file;
    /// Line component; 0 unless set by the creator.
    int line = 0;
    /// Column component; 0 unless set by the creator.
    int col = 0;
};
/// Common base class of every expression node.
///
/// A node stores a kind tag, fixed at construction, and the Location it was
/// created with. The tag is what the derived classes' static `classof`
/// helpers compare against (presumably for LLVM-style isa/dyn_cast).
class ExpressionNodeBase
{
public:
    /// Discriminator naming the concrete node class.
    enum ExpressionNodeKind
    {
        VariableDeclaration,
        Return,
        Number,
        Literal,
        Variable,
        BinaryOperation,
        Call,
        Print
    };

    /// @param kind     tag of the concrete node being constructed.
    /// @param location source position of the node; taken by value and moved
    ///                 into the member so the shared file handle inside it is
    ///                 not copied a second time.
    ExpressionNodeBase(ExpressionNodeKind kind, Location location)
        : kind(kind), location(std::move(location))
    {
    }

    /// Virtual so nodes can be destroyed through a base pointer.
    virtual ~ExpressionNodeBase() = default;

    /// @return the location passed at construction.
    const Location& getLocation() const
    {
        return location;
    }

    /// @return the kind tag passed at construction.
    ExpressionNodeKind getKind() const
    {
        return kind;
    }

private:
    const ExpressionNodeKind kind;
    Location location;
};
/// Owning sequence of expression nodes; used as the body of a Function.
using ExpressionList = std::vector<std::unique_ptr<ExpressionNodeBase>>;
class NumberExpression : public ExpressionNodeBase
{
double value;
public:
NumberExpression(Location location, const double value) : ExpressionNodeBase(Number,
std::move(location)), value(value)
{
}
double getValue() const
{
return value;
}
static bool classof(const ExpressionNodeBase* c)
{
return c->getKind() == Number;
}
};
/// Expression node for a literal: an owned list of element expressions plus
/// a list of dimensions (presumably the literal's shape — confirm with the
/// parser that builds it).
class LiteralExpression : public ExpressionNodeBase
{
    std::vector<std::unique_ptr<ExpressionNodeBase>> values;
    std::vector<int64_t> dimensions;

public:
    /// @param location   source position of the node.
    /// @param values     element expressions; ownership moves into the node.
    /// @param dimensions dimension list stored alongside the values.
    LiteralExpression(Location location, std::vector<std::unique_ptr<ExpressionNodeBase>> values,
                      std::vector<int64_t> dimensions)
        : ExpressionNodeBase(Literal, std::move(location)), values(std::move(values)),
          dimensions(std::move(dimensions))
    {
    }

    /// @return a non-owning view of the element expressions, valid while
    ///         this node is alive. Const so it can be called on a const node.
    llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getValues() const
    {
        return values;
    }

    /// @return a non-owning view of the stored dimensions, valid while this
    ///         node is alive. Const so it can be called on a const node.
    llvm::ArrayRef<int64_t> getDimensions() const
    {
        return dimensions;
    }

    /// Kind check: true when `c` is tagged `Literal`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == Literal;
    }
};
/// Expression node that refers to a variable by name.
class VariableExpression : public ExpressionNodeBase
{
    std::string name;

public:
    /// @param location source position of the node.
    /// @param name     variable name; copied into the node.
    VariableExpression(Location location, const llvm::StringRef name)
        : ExpressionNodeBase(Variable, std::move(location)), name(name)
    {
    }

    /// @return a view of the stored name, valid while this node is alive.
    ///         Const, matching the other `getName` accessors in this header.
    llvm::StringRef getName() const
    {
        return name;
    }

    /// Kind check: true when `c` is tagged `Variable`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == Variable;
    }
};
/// Expression node declaring a variable: its name, its ValueType and an
/// owned initial-value expression.
class VariableDeclarationExpression : public ExpressionNodeBase
{
    std::string name;
    ValueType variableType;
    std::unique_ptr<ExpressionNodeBase> initialValue;

public:
    /// @param location     source position of the declaration.
    /// @param name         name of the variable; copied into the node.
    /// @param type         declared type; moved into the node.
    /// @param initialValue initializer expression; ownership moves into the node.
    VariableDeclarationExpression(Location location, const llvm::StringRef name, ValueType type,
                                  std::unique_ptr<ExpressionNodeBase> initialValue)
        : ExpressionNodeBase(VariableDeclaration, std::move(location)), name(name),
          variableType(std::move(type)), initialValue(std::move(initialValue))
    {
    }

    /// @return a view of the variable name, valid while this node is alive.
    llvm::StringRef getName() const
    {
        return name;
    }

    /// @return non-owning pointer to the initializer; null if a null pointer
    ///         was passed to the constructor.
    ExpressionNodeBase* getInitialValue() const
    {
        return initialValue.get();
    }

    /// @return the declared type. Const so it can be called on a const node.
    const ValueType& getType() const
    {
        return variableType;
    }

    /// Kind check: true when `c` is tagged `VariableDeclaration`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == VariableDeclaration;
    }
};
class ReturnExpression : public ExpressionNodeBase
{
std::optional<std::unique_ptr<ExpressionNodeBase>> returnExpression;
public:
explicit ReturnExpression(Location location) : ExpressionNodeBase(Return, std::move(location))
{
}
ReturnExpression(Location location, std::unique_ptr<ExpressionNodeBase> expression) : ExpressionNodeBase(
Return,
std::move(location)), returnExpression(std::make_optional(std::move(expression)))
{
}
std::optional<ExpressionNodeBase*> getReturnExpression() const
{
if (returnExpression.has_value())
{
return returnExpression->get();
}
return std::nullopt;
}
static bool classof(const ExpressionNodeBase* c)
{
return c->getKind() == Return;
}
};
/// Expression node for a binary operation: an operator character and two
/// owned operand expressions.
class BinaryExpression : public ExpressionNodeBase
{
public:
    /// @param location source position of the node.
    /// @param op       operator character, stored as given.
    /// @param left     left operand; ownership moves into the node.
    /// @param right    right operand; ownership moves into the node.
    BinaryExpression(Location location, char op, std::unique_ptr<ExpressionNodeBase> left,
                     std::unique_ptr<ExpressionNodeBase> right)
        : ExpressionNodeBase(BinaryOperation, std::move(location)),
          binaryOperator(op),
          left(std::move(left)),
          right(std::move(right))
    {
    }

    /// @return the operator character passed at construction.
    char getOperator() const { return binaryOperator; }

    /// @return non-owning pointer to the left operand.
    ExpressionNodeBase* getLeft() const { return left.get(); }

    /// @return non-owning pointer to the right operand.
    ExpressionNodeBase* getRight() const { return right.get(); }

    /// Kind check: true when `c` is tagged `BinaryOperation`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == BinaryOperation;
    }

private:
    char binaryOperator;
    std::unique_ptr<ExpressionNodeBase> left;
    std::unique_ptr<ExpressionNodeBase> right;
};
/// Expression node for a call: the callee name and the owned argument
/// expressions.
class CallExpression : public ExpressionNodeBase
{
    std::string name;
    std::vector<std::unique_ptr<ExpressionNodeBase>> arguments;

public:
    /// @param location  source position of the call.
    /// @param callee    name of the called function; copied into the node.
    ///                  It is a const reference, so it cannot be moved from —
    ///                  the copy is written out explicitly rather than hidden
    ///                  behind a `std::move` that would silently copy anyway.
    /// @param arguments argument expressions; ownership moves into the node.
    CallExpression(Location location, const std::string& callee,
                   std::vector<std::unique_ptr<ExpressionNodeBase>> arguments)
        : ExpressionNodeBase(Call, std::move(location)), name(callee), arguments(std::move(arguments))
    {
    }

    /// @return a view of the callee name, valid while this node is alive.
    llvm::StringRef getName() const
    {
        return name;
    }

    /// @return a non-owning view of the argument expressions.
    llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getArguments() const
    {
        return arguments;
    }

    /// Kind check: true when `c` is tagged `Call`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == Call;
    }
};
/// Expression node for a print with one owned argument expression.
class PrintExpression : public ExpressionNodeBase
{
public:
    /// @param location source position of the node.
    /// @param argument expression to print; ownership moves into the node.
    PrintExpression(Location location, std::unique_ptr<ExpressionNodeBase> argument)
        : ExpressionNodeBase(Print, std::move(location)),
          argument(std::move(argument))
    {
    }

    /// @return non-owning pointer to the argument expression.
    ExpressionNodeBase* getArgument() const { return argument.get(); }

    /// Kind check: true when `c` is tagged `Print`.
    static bool classof(const ExpressionNodeBase* c)
    {
        return c->getKind() == Print;
    }

private:
    std::unique_ptr<ExpressionNodeBase> argument;
};
/// Function signature: where it was declared, its name, and its parameters
/// as owned VariableExpression nodes.
class FunctionPrototype
{
public:
    /// @param location   source position of the prototype.
    /// @param name       function name; copied into the prototype.
    /// @param parameters parameter nodes; ownership moves into the prototype.
    FunctionPrototype(Location location, const std::string& name,
                      std::vector<std::unique_ptr<VariableExpression>> parameters)
        : location(std::move(location)),
          name(name),
          parameters(std::move(parameters))
    {
    }

    /// @return the location passed at construction.
    const Location& getLocation() const { return location; }

    /// @return a view of the function name, valid while this object is alive.
    llvm::StringRef getName() const { return name; }

    /// @return a non-owning view of the parameter nodes.
    llvm::ArrayRef<std::unique_ptr<VariableExpression>> getParameters() const
    {
        return parameters;
    }

private:
    Location location;
    std::string name;
    std::vector<std::unique_ptr<VariableExpression>> parameters;
};
/// A function definition: an owned prototype and an owned body, the body
/// being a list of expressions.
class Function
{
public:
    /// Takes ownership of both the prototype and the body.
    Function(std::unique_ptr<FunctionPrototype> prototype, std::unique_ptr<ExpressionList> body)
        : prototype(std::move(prototype)),
          body(std::move(body))
    {
    }

    /// @return non-owning pointer to the prototype.
    FunctionPrototype* getPrototype() const { return prototype.get(); }

    /// @return non-owning pointer to the body expression list.
    ExpressionList* getBody() const { return body.get(); }

private:
    std::unique_ptr<FunctionPrototype> prototype;
    std::unique_ptr<ExpressionList> body;
};
class Module
{
std::vector<Function> functions;
public:
explicit Module(std::vector<Function> functions) : functions(std::move(functions))
{
}
auto begin()
{
return functions.begin();
}
auto end()
{
return functions.end();
}
void dump();
};
}
#endif //SYNTEXNODE_H