//
// Created by ricardo on 28/05/25.
//

#ifndef SYNTEXNODE_H
#define SYNTEXNODE_H

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/StringRef.h>
#include <mlir/IR/Location.h>

namespace hello
|
|
{
|
|
/// Declared type information attached to a variable declaration
/// (see VariableDeclarationExpression::getType).
///
/// Only a shape is recorded.
/// NOTE(review): presumably one extent per dimension, with an empty shape
/// meaning "not specified" — confirm against the parser.
struct ValueType
{
    std::vector<std::int64_t> shape{};
};

/// Source position of an AST node.
///
/// Remains an aggregate, so `Location{file, line, col}` still works. The
/// numeric fields now default to 0 instead of being left indeterminate when a
/// Location is default-constructed.
struct Location
{
    /// File name, held through a shared_ptr so many locations can refer to the
    /// same string without copying it. Null when unknown.
    std::shared_ptr<std::string> file;

    /// Line number; 0 when not set.
    /// NOTE(review): whether real positions are 0- or 1-based is decided by
    /// the lexer, not visible here — confirm.
    int line = 0;

    /// Column number; 0 when not set.
    int col = 0;
};

class ExpressionNodeBase
|
|
{
|
|
public:
|
|
enum ExpressionNodeKind
|
|
{
|
|
VariableDeclaration,
|
|
Return,
|
|
Number,
|
|
Literal,
|
|
Variable,
|
|
BinaryOperation,
|
|
Call,
|
|
Print
|
|
};
|
|
|
|
ExpressionNodeBase(ExpressionNodeKind kind, Location location) : kind(kind), location(location)
|
|
{
|
|
}
|
|
|
|
virtual ~ExpressionNodeBase() = default;
|
|
|
|
const Location& getLocation() const
|
|
{
|
|
return location;
|
|
}
|
|
|
|
ExpressionNodeKind getKind() const
|
|
{
|
|
return kind;
|
|
}
|
|
|
|
private:
|
|
const ExpressionNodeKind kind;
|
|
Location location;
|
|
};
|
|
|
|
using ExpressionList = std::vector<std::unique_ptr<ExpressionNodeBase>>;
|
|
|
|
class NumberExpression : public ExpressionNodeBase
|
|
{
|
|
double value;
|
|
|
|
public:
|
|
NumberExpression(Location location, const double value) : ExpressionNodeBase(Number,
|
|
std::move(location)), value(value)
|
|
{
|
|
}
|
|
|
|
double getValue() const
|
|
{
|
|
return value;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Number;
|
|
}
|
|
};
|
|
|
|
class LiteralExpression : public ExpressionNodeBase
|
|
{
|
|
std::vector<std::unique_ptr<ExpressionNodeBase>> values;
|
|
std::vector<int64_t> dimensions;
|
|
|
|
public:
|
|
LiteralExpression(Location location, std::vector<std::unique_ptr<ExpressionNodeBase>> values,
|
|
std::vector<int64_t> dimensions)
|
|
: ExpressionNodeBase(Literal, std::move(location)), values(std::move(values)),
|
|
dimensions(std::move(dimensions))
|
|
{
|
|
}
|
|
|
|
llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getValues()
|
|
{
|
|
return values;
|
|
}
|
|
|
|
llvm::ArrayRef<int64_t> getDimensions()
|
|
{
|
|
return dimensions;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Literal;
|
|
}
|
|
};
|
|
|
|
class VariableExpression : public ExpressionNodeBase
|
|
{
|
|
std::string name;
|
|
|
|
public:
|
|
VariableExpression(Location location, const llvm::StringRef name)
|
|
: ExpressionNodeBase(Variable, std::move(location)), name(name)
|
|
{
|
|
}
|
|
|
|
llvm::StringRef getName()
|
|
{
|
|
return name;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Variable;
|
|
}
|
|
};
|
|
|
|
class VariableDeclarationExpression : public ExpressionNodeBase
|
|
{
|
|
std::string name;
|
|
ValueType variableType;
|
|
std::unique_ptr<ExpressionNodeBase> initialValue;
|
|
|
|
public:
|
|
VariableDeclarationExpression(Location location, const llvm::StringRef name, ValueType type,
|
|
std::unique_ptr<ExpressionNodeBase> initialValue):
|
|
ExpressionNodeBase(VariableDeclaration, std::move(location)), name(name), variableType(std::move(type)),
|
|
initialValue(std::move(initialValue))
|
|
{
|
|
}
|
|
|
|
llvm::StringRef getName()
|
|
{
|
|
return name;
|
|
}
|
|
|
|
ExpressionNodeBase* getInitialValue() const
|
|
{
|
|
return initialValue.get();
|
|
}
|
|
|
|
const ValueType& getType()
|
|
{
|
|
return variableType;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == VariableDeclaration;
|
|
}
|
|
};
|
|
|
|
class ReturnExpression : public ExpressionNodeBase
|
|
{
|
|
std::optional<std::unique_ptr<ExpressionNodeBase>> returnExpression;
|
|
|
|
public:
|
|
explicit ReturnExpression(Location location) : ExpressionNodeBase(Return, std::move(location))
|
|
{
|
|
}
|
|
|
|
ReturnExpression(Location location, std::unique_ptr<ExpressionNodeBase> expression) : ExpressionNodeBase(
|
|
Return,
|
|
std::move(location)), returnExpression(std::make_optional(std::move(expression)))
|
|
{
|
|
}
|
|
|
|
std::optional<ExpressionNodeBase*> getReturnExpression() const
|
|
{
|
|
if (returnExpression.has_value())
|
|
{
|
|
return returnExpression->get();
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Return;
|
|
}
|
|
};
|
|
|
|
class BinaryExpression : public ExpressionNodeBase
|
|
{
|
|
char binaryOperator;
|
|
std::unique_ptr<ExpressionNodeBase> left, right;
|
|
|
|
public:
|
|
BinaryExpression(Location location, char op, std::unique_ptr<ExpressionNodeBase> left,
|
|
std::unique_ptr<ExpressionNodeBase> right)
|
|
: ExpressionNodeBase(BinaryOperation, std::move(location)), binaryOperator(op), left(std::move(left)),
|
|
right(std::move(right))
|
|
{
|
|
}
|
|
|
|
char getOperator() const
|
|
{
|
|
return binaryOperator;
|
|
}
|
|
|
|
ExpressionNodeBase* getLeft() const
|
|
{
|
|
return left.get();
|
|
}
|
|
|
|
ExpressionNodeBase* getRight() const
|
|
{
|
|
return right.get();
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == BinaryOperation;
|
|
}
|
|
};
|
|
|
|
class CallExpression : public ExpressionNodeBase
|
|
{
|
|
std::string name;
|
|
std::vector<std::unique_ptr<ExpressionNodeBase>> arguments;
|
|
|
|
public:
|
|
CallExpression(Location location, const std::string& callee,
|
|
std::vector<std::unique_ptr<ExpressionNodeBase>> arguments)
|
|
: ExpressionNodeBase(Call, std::move(location)), name(std::move(callee)), arguments(std::move(arguments))
|
|
{
|
|
}
|
|
|
|
llvm::StringRef getName() const
|
|
{
|
|
return name;
|
|
}
|
|
|
|
llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getArguments() const
|
|
{
|
|
return arguments;
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Call;
|
|
}
|
|
};
|
|
|
|
class PrintExpression : public ExpressionNodeBase
|
|
{
|
|
std::unique_ptr<ExpressionNodeBase> argument;
|
|
|
|
public:
|
|
PrintExpression(Location location, std::unique_ptr<ExpressionNodeBase> argument)
|
|
: ExpressionNodeBase(Print, std::move(location)), argument(std::move(argument))
|
|
{
|
|
}
|
|
|
|
ExpressionNodeBase* getArgument() const
|
|
{
|
|
return argument.get();
|
|
}
|
|
|
|
static bool classof(const ExpressionNodeBase* c)
|
|
{
|
|
return c->getKind() == Print;
|
|
}
|
|
};
|
|
|
|
class FunctionPrototype
|
|
{
|
|
Location location;
|
|
std::string name;
|
|
std::vector<std::unique_ptr<VariableExpression>> parameters;
|
|
|
|
public:
|
|
FunctionPrototype(Location location, const std::string& name,
|
|
std::vector<std::unique_ptr<VariableExpression>> parameters): location(std::move(location)),
|
|
name(name), parameters(std::move(parameters))
|
|
{
|
|
}
|
|
|
|
const Location& getLocation() const
|
|
{
|
|
return location;
|
|
}
|
|
|
|
llvm::StringRef getName() const
|
|
{
|
|
return name;
|
|
}
|
|
|
|
llvm::ArrayRef<std::unique_ptr<VariableExpression>> getParameters() const
|
|
{
|
|
return parameters;
|
|
}
|
|
};
|
|
|
|
class Function
|
|
{
|
|
std::unique_ptr<FunctionPrototype> prototype;
|
|
std::unique_ptr<ExpressionList> body;
|
|
|
|
public:
|
|
Function(std::unique_ptr<FunctionPrototype> prototype, std::unique_ptr<ExpressionList> body) :
|
|
prototype(std::move(prototype)),
|
|
body(std::move(body))
|
|
{
|
|
}
|
|
|
|
FunctionPrototype* getPrototype() const
|
|
{
|
|
return prototype.get();
|
|
}
|
|
|
|
ExpressionList* getBody() const
|
|
{
|
|
return body.get();
|
|
}
|
|
};
|
|
|
|
class Module
|
|
{
|
|
std::vector<Function> functions;
|
|
|
|
public:
|
|
explicit Module(std::vector<Function> functions) : functions(std::move(functions))
|
|
{
|
|
}
|
|
|
|
auto begin()
|
|
{
|
|
return functions.begin();
|
|
}
|
|
|
|
auto end()
|
|
{
|
|
return functions.end();
|
|
}
|
|
|
|
void dump();
|
|
};
|
|
}
|
|
|
|
#endif //SYNTEXNODE_H
|