// hello-mlir/lib/SyntaxNode.cpp
//
// 246 lines
// 6.9 KiB
// C++

//
// Created by ricardo on 28/05/25.
//
#include <llvm/Support/raw_ostream.h>
#include <llvm/ADT/TypeSwitch.h>
#include "SyntaxNode.h"
using namespace hello;
namespace
{
/// Scope guard for a nesting-depth counter.
///
/// Construction increments the referenced counter and destruction decrements
/// it again, so the counter is restored when the guard leaves scope.
struct Indent
{
    /// The counter being guarded; owned by the caller and must outlive the guard.
    int& level;

    /// Enters one nesting level by incrementing @p level.
    explicit Indent(int& level) : level(level)
    {
        ++level;
    }

    // A copy would not increment the counter but its destructor would still
    // decrement it, leaving the depth unbalanced. Forbid copying outright.
    Indent(const Indent&) = delete;
    Indent& operator=(const Indent&) = delete;

    /// Leaves the nesting level by decrementing the guarded counter.
    ~Indent()
    {
        --level;
    }
};
/// Writes a textual rendering of a syntax tree to llvm::outs(), one overload
/// of dump() per node type.
class SyntaxNodeDumper
{
public:
    /// Entry point: dumps a whole module.
    void dump(Module* module);

private:
    /// Nesting depth used by indent(); starts at zero.
    int currentIndent = 0;

    /// Writes one " " to llvm::outs() per unit of currentIndent.
    void indent() const
    {
        for (int remaining = currentIndent; remaining > 0; --remaining)
        {
            llvm::outs() << " ";
        }
    }

    // Type printing does not touch the dumper state, hence static.
    static void dump(const ValueType& type);

    // Generic expression entry point and the expression-list ("block") dumper.
    void dump(ExpressionNodeBase* expr);
    void dump(const ExpressionList* exprList);

    // One overload per concrete expression type.
    void dump(VariableDeclarationExpression* varDecl);
    void dump(NumberExpression* num);
    void dump(LiteralExpression* node);
    void dump(VariableExpression* node);
    void dump(const ReturnExpression* node);
    void dump(BinaryExpression* node);
    void dump(CallExpression* node);
    void dump(PrintExpression* node);

    // Function-level nodes.
    void dump(FunctionPrototype* node);
    void dump(const Function* node);
};
}
template <typename T>
static std::string formatLocation(T* node)
{
const auto& location = node->getLocation();
return (llvm::Twine("@") + *location.file + ":" + llvm::Twine(location.line) + ":" + llvm::Twine(location.col)).
str();
}
/// Enters one nesting level and prints the indentation for the current line.
///
/// Expands to TWO statements: a local `Indent` guard named `level` bound to
/// `currentIndent`, followed by a call to `indent()`. Consequently it must be
/// used where `currentIndent` and `indent()` are accessible, at most once per
/// scope (the guard's name is fixed), and never as the unbraced body of an
/// `if`/`for`. The nesting level is left again when the enclosing scope ends.
#define INDENT() Indent level(currentIndent); indent();
/// Dispatches a generic expression to the dump() overload of its concrete
/// subclass, using llvm::TypeSwitch (LLVM-style RTTI).
void SyntaxNodeDumper::dump(ExpressionNodeBase* expr)
{
    // Shared handler for every known subclass: overload resolution on the
    // concrete pointer type selects the matching dump().
    const auto dumpConcrete = [this](auto* concrete) { this->dump(concrete); };

    // No match: fall back to a generic message carrying the node kind.
    const auto dumpUnknown = [this](ExpressionNodeBase* unknown)
    {
        INDENT();
        llvm::outs() << "<unknown Expr, kind " << unknown->getKind() << ">\n";
    };

    llvm::TypeSwitch<ExpressionNodeBase*>(expr)
        .Case<BinaryExpression, CallExpression, LiteralExpression, NumberExpression,
              PrintExpression, ReturnExpression, VariableDeclarationExpression,
              VariableExpression>(dumpConcrete)
        .Default(dumpUnknown);
}
/// Dumps a variable declaration: the variable name, immediately followed by
/// its type and the source location, then the initializer expression on the
/// following line(s).
void SyntaxNodeDumper::dump(VariableDeclarationExpression* varDecl)
{
    INDENT();
    auto& out = llvm::outs();

    out << "VarDecl " << varDecl->getName();
    dump(varDecl->getType());
    out << " " << formatLocation(varDecl) << "\n";

    // Recurse into the initializer.
    dump(varDecl->getInitialValue());
}
/// Dumps a "block": every expression of the list in order, wrapped between a
/// "Block {" header line and a "} // Block" footer line.
void SyntaxNodeDumper::dump(const ExpressionList* exprList)
{
    INDENT();
    auto& out = llvm::outs();

    out << "Block {\n";
    for (auto& expression : *exprList)
    {
        dump(expression.get());
    }

    // The footer is aligned with the header.
    indent();
    out << "} // Block\n";
}
/// Dumps a number: its value followed by the source location.
void SyntaxNodeDumper::dump(NumberExpression* num)
{
    INDENT();
    auto& out = llvm::outs();
    out << num->getValue();
    out << " " << formatLocation(num) << "\n";
}
/// Recursively prints a literal, which may be a nested array such as:
///   [ [ 1, 2 ], [ 3, 4 ] ]
/// Every nesting level is prefixed with its dimensions:
///   <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
///
/// @param litOrNum either a NumberExpression or a LiteralExpression; any
///                 other node type trips the llvm::cast check.
void printLitHelper(ExpressionNodeBase* litOrNum)
{
    auto& out = llvm::outs();

    if (const auto* number = llvm::dyn_cast<NumberExpression>(litOrNum))
    {
        // Leaf: a plain number prints as its value.
        out << number->getValue();
    }
    else
    {
        auto* nested = llvm::cast<LiteralExpression>(litOrNum);

        // Dimensions of this level first ...
        out << "<";
        interleaveComma(nested->getDimensions(), out);
        out << ">[ ";

        // ... then the elements, recursing into each one.
        interleaveComma(nested->getValues(), out,
                        [](auto& element) { printLitHelper(element.get()); });
        out << "]";
    }
}
/// Dumps a literal as "Literal: <rendering> <location>"; the rendering itself
/// is produced by printLitHelper.
void SyntaxNodeDumper::dump(LiteralExpression* node)
{
    INDENT();
    auto& out = llvm::outs();

    out << "Literal: ";
    printLitHelper(node);
    out << " " << formatLocation(node) << "\n";
}
/// Dumps a variable reference: its name followed by the source location.
void SyntaxNodeDumper::dump(VariableExpression* node)
{
    INDENT();
    auto& out = llvm::outs();
    out << "var: " << node->getName();
    out << " " << formatLocation(node) << "\n";
}
/// Dumps a return statement: a "Return" line followed by either the returned
/// expression or, when there is none, a nested "(void)" line.
void SyntaxNodeDumper::dump(const ReturnExpression* node)
{
    INDENT();
    llvm::outs() << "Return\n";

    const auto& returned = node->getReturnExpression();
    if (returned.has_value())
    {
        dump(*returned);
    }
    else
    {
        // No operand: print a placeholder one level deeper.
        INDENT();
        llvm::outs() << "(void)\n";
    }
}
/// Dumps a binary operation: the operator and location on one line, then the
/// left-hand side followed by the right-hand side.
void SyntaxNodeDumper::dump(BinaryExpression* node)
{
    INDENT();
    auto& out = llvm::outs();
    out << "BinOp: " << node->getOperator();
    out << " " << formatLocation(node) << "\n";

    dump(node->getLeft());
    dump(node->getRight());
}
/// Dumps a call: the quoted callee name and location on an opening "[" line,
/// each argument in order, then a closing "]" line.
void SyntaxNodeDumper::dump(CallExpression* node)
{
    INDENT();
    auto& out = llvm::outs();

    out << "Call '" << node->getName() << "' [ " << formatLocation(node) << "\n";
    for (auto& argument : node->getArguments())
    {
        dump(argument.get());
    }

    // Closing bracket is aligned with the opening line.
    indent();
    out << "]\n";
}
/// Dumps the builtin print: an opening "Print [" line with the location, the
/// single argument, then a closing "]" line.
void SyntaxNodeDumper::dump(PrintExpression* node)
{
    INDENT();
    auto& out = llvm::outs();

    out << "Print [ " << formatLocation(node) << "\n";
    dump(node->getArgument());

    // Closing bracket is aligned with the opening line.
    indent();
    out << "]\n";
}
/// Dumps a type: only its shape, comma-separated between '<' and '>'.
/// No indentation and no trailing newline are written.
void SyntaxNodeDumper::dump(const ValueType& type)
{
    auto& out = llvm::outs();
    out << "<";
    interleaveComma(type.shape, out);
    out << ">";
}
/// Dumps a function prototype: a "Proto '<name>' <location>" line followed by
/// a "Params: [...]" line listing the parameter names, comma-separated.
void SyntaxNodeDumper::dump(FunctionPrototype* node)
{
    INDENT();
    auto& out = llvm::outs();

    out << "Proto '" << node->getName() << "' " << formatLocation(node) << "\n";

    indent();
    out << "Params: [";
    const auto printParameterName = [&out](auto& parameter)
    {
        out << parameter->getName();
    };
    llvm::interleaveComma(node->getParameters(), out, printParameterName);
    out << "]\n";
}
/// Dumps a function: a "Function " header line, then its prototype, then its
/// body.
void SyntaxNodeDumper::dump(const Function* node)
{
    INDENT();
    llvm::outs() << "Function \n";

    auto* prototype = node->getPrototype();
    dump(prototype);

    auto* body = node->getBody();
    dump(body);
}
/// Dumps a module: a "Module:" header line followed by every function the
/// module iterates over, in iteration order.
void SyntaxNodeDumper::dump(Module* module)
{
    INDENT();
    llvm::outs() << "Module:\n";

    for (auto& function : *module)
    {
        dump(&function);
    }
}
namespace hello
{
void Module::dump()
{
SyntaxNodeDumper().dump(this);
}
}