feat: toy tutorial chapter 1.
This commit is contained in:
commit
1a64b78ef8
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
cmake-*/
|
||||||
|
build/
|
||||||
|
.idea/
|
43
CMakeLists.txt
Normal file
43
CMakeLists.txt
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
project(hello_mlir)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
|
find_package(MLIR REQUIRED CONFIG)
|
||||||
|
find_package(LLVM REQUIRED CONFIG)
|
||||||
|
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
|
||||||
|
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
|
||||||
|
|
||||||
|
set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
|
||||||
|
set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
|
||||||
|
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||||
|
|
||||||
|
include(TableGen)
|
||||||
|
include(AddLLVM)
|
||||||
|
include(AddMLIR)
|
||||||
|
include(HandleLLVMOptions)
|
||||||
|
|
||||||
|
message(${MLIR_INCLUDE_DIRS})
|
||||||
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
|
include_directories(${MLIR_INCLUDE_DIRS})
|
||||||
|
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||||
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
|
|
||||||
|
include_directories(include)
|
||||||
|
|
||||||
|
add_library(SyntaxNode SHARED lib/SyntaxNode.cpp include/SyntaxNode.h include/Parser.h include/Lexer.h)
|
||||||
|
|
||||||
|
target_link_libraries(SyntaxNode
|
||||||
|
PRIVATE
|
||||||
|
MLIRSupport)
|
||||||
|
|
||||||
|
add_executable(hello-mlir main.cpp)
|
||||||
|
|
||||||
|
target_link_libraries(hello-mlir
|
||||||
|
PRIVATE
|
||||||
|
SyntaxNode
|
||||||
|
LLVMSupport
|
||||||
|
LLVMCore)
|
13
examples/transpose.hello
Normal file
13
examples/transpose.hello
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
def main() {
|
||||||
|
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
|
||||||
|
# The shape is inferred from the supplied literal.
|
||||||
|
var a = [[1, 2, 3], [4, 5, 6]];
|
||||||
|
|
||||||
|
# b is identical to a, the literal tensor is implicitly reshaped: defining new
|
||||||
|
# variables is the way to reshape tensors (element count must match).
|
||||||
|
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||||
|
|
||||||
|
# transpose() and print() are the only builtin, the following will transpose
|
||||||
|
# a and b and perform an element-wise multiplication before printing the result.
|
||||||
|
print(transpose(a) * transpose(b));
|
||||||
|
}
|
250
include/Lexer.h
Normal file
250
include/Lexer.h
Normal file
|
@ -0,0 +1,250 @@
|
||||||
|
//===- Lexer.h - Lexer for the Toy language -------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements a simple Lexer for the Toy language.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TOY_LEXER_H
|
||||||
|
#define TOY_LEXER_H
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "SyntaxNode.h"
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
// List of Token returned by the lexer.
|
||||||
|
enum Token : int
|
||||||
|
{
|
||||||
|
tok_semicolon = ';',
|
||||||
|
tok_parenthese_open = '(',
|
||||||
|
tok_parenthese_close = ')',
|
||||||
|
tok_bracket_open = '{',
|
||||||
|
tok_bracket_close = '}',
|
||||||
|
tok_sbracket_open = '[',
|
||||||
|
tok_sbracket_close = ']',
|
||||||
|
|
||||||
|
tok_eof = -1,
|
||||||
|
|
||||||
|
// commands
|
||||||
|
tok_return = -2,
|
||||||
|
tok_var = -3,
|
||||||
|
tok_def = -4,
|
||||||
|
|
||||||
|
// primary
|
||||||
|
tok_identifier = -5,
|
||||||
|
tok_number = -6,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// The Lexer is an abstract base class providing all the facilities that the
|
||||||
|
/// Parser expects. It goes through the stream one token at a time and keeps
|
||||||
|
/// track of the location in the file for debugging purposes.
|
||||||
|
/// It relies on a subclass to provide a `readNextLine()` method. The subclass
|
||||||
|
/// can proceed by reading the next line from the standard input or from a
|
||||||
|
/// memory mapped file.
|
||||||
|
class Lexer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
/// Create a lexer for the given filename. The filename is kept only for
|
||||||
|
/// debugging purposes (attaching a location to a Token).
|
||||||
|
explicit Lexer(std::string filename)
|
||||||
|
: lastLocation(
|
||||||
|
{std::make_shared<std::string>(std::move(filename)), 0, 0})
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~Lexer() = default;
|
||||||
|
|
||||||
|
/// Look at the current token in the stream.
|
||||||
|
Token getCurToken() const { return curTok; }
|
||||||
|
|
||||||
|
/// Move to the next token in the stream and return it.
|
||||||
|
Token getNextToken() { return curTok = getTok(); }
|
||||||
|
|
||||||
|
/// Move to the next token in the stream, asserting on the current token
|
||||||
|
/// matching the expectation.
|
||||||
|
void consume(Token tok)
|
||||||
|
{
|
||||||
|
assert(tok == curTok && "consume Token mismatch expectation");
|
||||||
|
getNextToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||||
|
llvm::StringRef getId()
|
||||||
|
{
|
||||||
|
assert(curTok == tok_identifier);
|
||||||
|
return identifierStr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||||
|
double getValue()
|
||||||
|
{
|
||||||
|
assert(curTok == tok_number);
|
||||||
|
return numVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the location for the beginning of the current token.
|
||||||
|
Location getLastLocation() { return lastLocation; }
|
||||||
|
|
||||||
|
// Return the current line in the file.
|
||||||
|
int getLine() const { return curLineNum; }
|
||||||
|
|
||||||
|
// Return the current column in the file.
|
||||||
|
int getCol() const { return curCol; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Delegate to a derived class fetching the next line. Returns an empty
|
||||||
|
/// string to signal end of file (EOF). Lines are expected to always finish
|
||||||
|
/// with "\n"
|
||||||
|
virtual llvm::StringRef readNextLine() = 0;
|
||||||
|
|
||||||
|
/// Return the next character from the stream. This manages the buffer for the
|
||||||
|
/// current line and request the next line buffer to the derived class as
|
||||||
|
/// needed.
|
||||||
|
int getNextChar()
|
||||||
|
{
|
||||||
|
// The current line buffer should not be empty unless it is the end of file.
|
||||||
|
if (curLineBuffer.empty())
|
||||||
|
return EOF;
|
||||||
|
++curCol;
|
||||||
|
auto nextchar = curLineBuffer.front();
|
||||||
|
curLineBuffer = curLineBuffer.drop_front();
|
||||||
|
if (curLineBuffer.empty())
|
||||||
|
curLineBuffer = readNextLine();
|
||||||
|
if (nextchar == '\n')
|
||||||
|
{
|
||||||
|
++curLineNum;
|
||||||
|
curCol = 0;
|
||||||
|
}
|
||||||
|
return nextchar;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the next token from standard input.
|
||||||
|
Token getTok()
|
||||||
|
{
|
||||||
|
// Skip any whitespace.
|
||||||
|
while (isspace(lastChar))
|
||||||
|
lastChar = Token(getNextChar());
|
||||||
|
|
||||||
|
// Save the current location before reading the token characters.
|
||||||
|
lastLocation.line = curLineNum;
|
||||||
|
lastLocation.col = curCol;
|
||||||
|
|
||||||
|
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||||
|
if (isalpha(lastChar))
|
||||||
|
{
|
||||||
|
identifierStr = (char)lastChar;
|
||||||
|
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||||
|
identifierStr += (char)lastChar;
|
||||||
|
|
||||||
|
if (identifierStr == "return")
|
||||||
|
return tok_return;
|
||||||
|
if (identifierStr == "def")
|
||||||
|
return tok_def;
|
||||||
|
if (identifierStr == "var")
|
||||||
|
return tok_var;
|
||||||
|
return tok_identifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number: [0-9.]+
|
||||||
|
if (isdigit(lastChar) || lastChar == '.')
|
||||||
|
{
|
||||||
|
std::string numStr;
|
||||||
|
do
|
||||||
|
{
|
||||||
|
numStr += lastChar;
|
||||||
|
lastChar = Token(getNextChar());
|
||||||
|
}
|
||||||
|
while (isdigit(lastChar) || lastChar == '.');
|
||||||
|
|
||||||
|
numVal = strtod(numStr.c_str(), nullptr);
|
||||||
|
return tok_number;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lastChar == '#')
|
||||||
|
{
|
||||||
|
// Comment until end of line.
|
||||||
|
do
|
||||||
|
{
|
||||||
|
lastChar = Token(getNextChar());
|
||||||
|
}
|
||||||
|
while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||||
|
|
||||||
|
if (lastChar != EOF)
|
||||||
|
return getTok();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for end of file. Don't eat the EOF.
|
||||||
|
if (lastChar == EOF)
|
||||||
|
return tok_eof;
|
||||||
|
|
||||||
|
// Otherwise, just return the character as its ascii value.
|
||||||
|
Token thisChar = Token(lastChar);
|
||||||
|
lastChar = Token(getNextChar());
|
||||||
|
return thisChar;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The last token read from the input.
|
||||||
|
Token curTok = tok_eof;
|
||||||
|
|
||||||
|
/// Location for `curTok`.
|
||||||
|
Location lastLocation;
|
||||||
|
|
||||||
|
/// If the current Token is an identifier, this string contains the value.
|
||||||
|
std::string identifierStr;
|
||||||
|
|
||||||
|
/// If the current Token is a number, this contains the value.
|
||||||
|
double numVal = 0;
|
||||||
|
|
||||||
|
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||||
|
/// always need to read ahead one character to decide when to end a token and
|
||||||
|
/// we can't put it back in the stream after reading from it.
|
||||||
|
Token lastChar = Token(' ');
|
||||||
|
|
||||||
|
/// Keep track of the current line number in the input stream
|
||||||
|
int curLineNum = 0;
|
||||||
|
|
||||||
|
/// Keep track of the current column number in the input stream
|
||||||
|
int curCol = 0;
|
||||||
|
|
||||||
|
/// Buffer supplied by the derived class on calls to `readNextLine()`
|
||||||
|
llvm::StringRef curLineBuffer = "\n";
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A lexer implementation operating on a buffer in memory.
|
||||||
|
class LexerBuffer final : public Lexer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LexerBuffer(const char* begin, const char* end, std::string filename)
|
||||||
|
: Lexer(std::move(filename)), current(begin), end(end)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Provide one line at a time to the Lexer, return an empty string when
|
||||||
|
/// reaching the end of the buffer.
|
||||||
|
llvm::StringRef readNextLine() override
|
||||||
|
{
|
||||||
|
auto* begin = current;
|
||||||
|
while (current <= end && *current && *current != '\n')
|
||||||
|
++current;
|
||||||
|
if (current <= end && *current)
|
||||||
|
++current;
|
||||||
|
|
||||||
|
return llvm::StringRef{begin, static_cast<size_t>(current - begin)};
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *current, *end;
|
||||||
|
};
|
||||||
|
} // namespace toy
|
||||||
|
|
||||||
|
#endif // TOY_LEXER_H
|
539
include/Parser.h
Normal file
539
include/Parser.h
Normal file
|
@ -0,0 +1,539 @@
|
||||||
|
//===- Parser.h - Toy Language Parser -------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements the parser for the Toy language. It processes the Token
|
||||||
|
// provided by the Lexer and returns an AST.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TOY_PARSER_H
|
||||||
|
#define TOY_PARSER_H
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
#include "SyntaxNode.h"
|
||||||
|
#include "Lexer.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
/// This is a simple recursive parser for the Toy language. It produces a well
|
||||||
|
/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
|
||||||
|
/// or symbol resolution is performed. For example, variables are referenced by
|
||||||
|
/// string and the code could reference an undeclared variable and the parsing
|
||||||
|
/// succeeds.
|
||||||
|
class Parser
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
/// Create a Parser for the supplied lexer.
|
||||||
|
explicit Parser(Lexer& lexer) : lexer(lexer)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a full Module. A module is a list of function definitions.
|
||||||
|
std::unique_ptr<Module> parseModule()
|
||||||
|
{
|
||||||
|
lexer.getNextToken(); // prime the lexer
|
||||||
|
|
||||||
|
// Parse functions one at a time and accumulate in this vector.
|
||||||
|
std::vector<Function> functions;
|
||||||
|
while (auto f = parseDefinition())
|
||||||
|
{
|
||||||
|
functions.push_back(std::move(*f));
|
||||||
|
if (lexer.getCurToken() == tok_eof)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// If we didn't reach EOF, there was an error during parsing
|
||||||
|
if (lexer.getCurToken() != tok_eof)
|
||||||
|
return parseError<Module>("nothing", "at end of module");
|
||||||
|
|
||||||
|
return std::make_unique<Module>(std::move(functions));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Lexer& lexer;
|
||||||
|
|
||||||
|
/// Parse a return statement.
|
||||||
|
/// return :== return ; | return expr ;
|
||||||
|
std::unique_ptr<ReturnExpression> parseReturn()
|
||||||
|
{
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
lexer.consume(tok_return);
|
||||||
|
|
||||||
|
// return takes an optional argument
|
||||||
|
if (lexer.getCurToken() != ';')
|
||||||
|
{
|
||||||
|
std::unique_ptr<ExpressionNodeBase> expr = parseExpression();
|
||||||
|
if (!expr)
|
||||||
|
{
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_unique<ReturnExpression>(std::move(loc), std::move(expr));
|
||||||
|
}
|
||||||
|
return std::make_unique<ReturnExpression>(std::move(loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a literal number.
|
||||||
|
/// numberexpr ::= number
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseNumberExpr()
|
||||||
|
{
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
auto result =
|
||||||
|
std::make_unique<NumberExpression>(std::move(loc), lexer.getValue());
|
||||||
|
lexer.consume(tok_number);
|
||||||
|
return std::move(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a literal array expression.
|
||||||
|
/// tensorLiteral ::= [ literalList ] | number
|
||||||
|
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseTensorLiteralExpr()
|
||||||
|
{
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
lexer.consume(Token('['));
|
||||||
|
|
||||||
|
// Hold the list of values at this nesting level.
|
||||||
|
std::vector<std::unique_ptr<ExpressionNodeBase>> values;
|
||||||
|
// Hold the dimensions for all the nesting inside this level.
|
||||||
|
std::vector<int64_t> dims;
|
||||||
|
do
|
||||||
|
{
|
||||||
|
// We can have either another nested array or a number literal.
|
||||||
|
if (lexer.getCurToken() == '[')
|
||||||
|
{
|
||||||
|
values.push_back(parseTensorLiteralExpr());
|
||||||
|
if (!values.back())
|
||||||
|
return nullptr; // parse error in the nested array.
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (lexer.getCurToken() != tok_number)
|
||||||
|
return parseError<ExpressionNodeBase>("<num> or [", "in literal expression");
|
||||||
|
values.push_back(parseNumberExpr());
|
||||||
|
}
|
||||||
|
|
||||||
|
// End of this list on ']'
|
||||||
|
if (lexer.getCurToken() == ']')
|
||||||
|
break;
|
||||||
|
|
||||||
|
// Elements are separated by a comma.
|
||||||
|
if (lexer.getCurToken() != ',')
|
||||||
|
return parseError<ExpressionNodeBase>("] or ,", "in literal expression");
|
||||||
|
|
||||||
|
lexer.getNextToken(); // eat ,
|
||||||
|
}
|
||||||
|
while (true);
|
||||||
|
if (values.empty())
|
||||||
|
return parseError<ExpressionNodeBase>("<something>", "to fill literal expression");
|
||||||
|
lexer.getNextToken(); // eat ]
|
||||||
|
|
||||||
|
/// Fill in the dimensions now. First the current nesting level:
|
||||||
|
dims.push_back(values.size());
|
||||||
|
|
||||||
|
/// If there is any nested array, process all of them and ensure that
|
||||||
|
/// dimensions are uniform.
|
||||||
|
if (llvm::any_of(values, [](std::unique_ptr<ExpressionNodeBase>& expr)
|
||||||
|
{
|
||||||
|
return llvm::isa<LiteralExpression>(expr.get());
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
auto* firstLiteral = llvm::dyn_cast<LiteralExpression>(values.front().get());
|
||||||
|
if (!firstLiteral)
|
||||||
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
||||||
|
"inside literal expression");
|
||||||
|
|
||||||
|
// Append the nested dimensions to the current level
|
||||||
|
auto firstDims = firstLiteral->getDimensions();
|
||||||
|
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||||
|
|
||||||
|
// Sanity check that shape is uniform across all elements of the list.
|
||||||
|
for (auto& expr : values)
|
||||||
|
{
|
||||||
|
auto* exprLiteral = llvm::cast<LiteralExpression>(expr.get());
|
||||||
|
if (!exprLiteral)
|
||||||
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
||||||
|
"inside literal expression");
|
||||||
|
if (exprLiteral->getDimensions() != firstDims)
|
||||||
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
||||||
|
"inside literal expression");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_unique<LiteralExpression>(std::move(loc), std::move(values),
|
||||||
|
std::move(dims));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// parenexpr ::= '(' expression ')'
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseParenExpr()
|
||||||
|
{
|
||||||
|
lexer.getNextToken(); // eat (.
|
||||||
|
auto v = parseExpression();
|
||||||
|
if (!v)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != ')')
|
||||||
|
return parseError<ExpressionNodeBase>(")", "to close expression with parentheses");
|
||||||
|
lexer.consume(Token(')'));
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// identifierexpr
|
||||||
|
/// ::= identifier
|
||||||
|
/// ::= identifier '(' expression ')'
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseIdentifierExpr()
|
||||||
|
{
|
||||||
|
std::string name(lexer.getId());
|
||||||
|
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
lexer.getNextToken(); // eat identifier.
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != '(') // Simple variable ref.
|
||||||
|
return std::make_unique<VariableExpression>(std::move(loc), name);
|
||||||
|
|
||||||
|
// This is a function call.
|
||||||
|
lexer.consume(Token('('));
|
||||||
|
std::vector<std::unique_ptr<ExpressionNodeBase>> args;
|
||||||
|
if (lexer.getCurToken() != ')')
|
||||||
|
{
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
if (auto arg = parseExpression())
|
||||||
|
args.push_back(std::move(arg));
|
||||||
|
else
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
if (lexer.getCurToken() == ')')
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != ',')
|
||||||
|
return parseError<ExpressionNodeBase>(", or )", "in argument list");
|
||||||
|
lexer.getNextToken();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lexer.consume(Token(')'));
|
||||||
|
|
||||||
|
// It can be a builtin call to print
|
||||||
|
if (name == "print")
|
||||||
|
{
|
||||||
|
if (args.size() != 1)
|
||||||
|
return parseError<ExpressionNodeBase>("<single arg>", "as argument to print()");
|
||||||
|
|
||||||
|
return std::make_unique<PrintExpression>(std::move(loc), std::move(args[0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call to a user-defined function
|
||||||
|
return std::make_unique<CallExpression>(std::move(loc), name, std::move(args));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// primary
|
||||||
|
/// ::= identifierexpr
|
||||||
|
/// ::= numberexpr
|
||||||
|
/// ::= parenexpr
|
||||||
|
/// ::= tensorliteral
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parsePrimary()
|
||||||
|
{
|
||||||
|
switch (lexer.getCurToken())
|
||||||
|
{
|
||||||
|
default:
|
||||||
|
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||||
|
<< "' when expecting an expression\n";
|
||||||
|
return nullptr;
|
||||||
|
case tok_identifier:
|
||||||
|
return parseIdentifierExpr();
|
||||||
|
case tok_number:
|
||||||
|
return parseNumberExpr();
|
||||||
|
case '(':
|
||||||
|
return parseParenExpr();
|
||||||
|
case '[':
|
||||||
|
return parseTensorLiteralExpr();
|
||||||
|
case ';':
|
||||||
|
return nullptr;
|
||||||
|
case '}':
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recursively parse the right hand side of a binary expression, the ExprPrec
|
||||||
|
/// argument indicates the precedence of the current binary operator.
|
||||||
|
///
|
||||||
|
/// binoprhs ::= ('+' primary)*
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseBinOpRHS(int exprPrec,
|
||||||
|
std::unique_ptr<ExpressionNodeBase> lhs)
|
||||||
|
{
|
||||||
|
// If this is a binop, find its precedence.
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
int tokPrec = getTokPrecedence();
|
||||||
|
|
||||||
|
// If this is a binop that binds at least as tightly as the current binop,
|
||||||
|
// consume it, otherwise we are done.
|
||||||
|
if (tokPrec < exprPrec)
|
||||||
|
return lhs;
|
||||||
|
|
||||||
|
// Okay, we know this is a binop.
|
||||||
|
int binOp = lexer.getCurToken();
|
||||||
|
lexer.consume(Token(binOp));
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
|
||||||
|
// Parse the primary expression after the binary operator.
|
||||||
|
auto rhs = parsePrimary();
|
||||||
|
if (!rhs)
|
||||||
|
return parseError<ExpressionNodeBase>("expression", "to complete binary operator");
|
||||||
|
|
||||||
|
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||||
|
// the pending operator take rhs as its lhs.
|
||||||
|
int nextPrec = getTokPrecedence();
|
||||||
|
if (tokPrec < nextPrec)
|
||||||
|
{
|
||||||
|
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||||
|
if (!rhs)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge lhs/RHS.
|
||||||
|
lhs = std::make_unique<BinaryExpression>(std::move(loc), binOp,
|
||||||
|
std::move(lhs), std::move(rhs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// expression::= primary binop rhs
|
||||||
|
std::unique_ptr<ExpressionNodeBase> parseExpression()
|
||||||
|
{
|
||||||
|
auto lhs = parsePrimary();
|
||||||
|
if (!lhs)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
return parseBinOpRHS(0, std::move(lhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// type ::= < shape_list >
|
||||||
|
/// shape_list ::= num | num , shape_list
|
||||||
|
std::unique_ptr<ValueType> parseType()
|
||||||
|
{
|
||||||
|
if (lexer.getCurToken() != '<')
|
||||||
|
return parseError<ValueType>("<", "to begin type");
|
||||||
|
lexer.getNextToken(); // eat <
|
||||||
|
|
||||||
|
auto type = std::make_unique<ValueType>();
|
||||||
|
|
||||||
|
while (lexer.getCurToken() == tok_number)
|
||||||
|
{
|
||||||
|
type->shape.push_back(lexer.getValue());
|
||||||
|
lexer.getNextToken();
|
||||||
|
if (lexer.getCurToken() == ',')
|
||||||
|
lexer.getNextToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != '>')
|
||||||
|
return parseError<ValueType>(">", "to end type");
|
||||||
|
lexer.getNextToken(); // eat >
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a variable declaration, it starts with a `var` keyword followed by
|
||||||
|
/// and identifier and an optional type (shape specification) before the
|
||||||
|
/// initializer.
|
||||||
|
/// decl ::= var identifier [ type ] = expr
|
||||||
|
std::unique_ptr<VariableDeclarationExpression> parseDeclaration()
|
||||||
|
{
|
||||||
|
if (lexer.getCurToken() != tok_var)
|
||||||
|
return parseError<VariableDeclarationExpression>("var", "to begin declaration");
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
lexer.getNextToken(); // eat var
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != tok_identifier)
|
||||||
|
return parseError<VariableDeclarationExpression>("identified",
|
||||||
|
"after 'var' declaration");
|
||||||
|
std::string id(lexer.getId());
|
||||||
|
lexer.getNextToken(); // eat id
|
||||||
|
|
||||||
|
std::unique_ptr<ValueType> type; // Type is optional, it can be inferred
|
||||||
|
if (lexer.getCurToken() == '<')
|
||||||
|
{
|
||||||
|
type = parseType();
|
||||||
|
if (!type)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!type)
|
||||||
|
type = std::make_unique<ValueType>();
|
||||||
|
lexer.consume(Token('='));
|
||||||
|
auto expr = parseExpression();
|
||||||
|
return std::make_unique<VariableDeclarationExpression>(std::move(loc), std::move(id),
|
||||||
|
std::move(*type), std::move(expr));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a block: a list of expression separated by semicolons and wrapped in
|
||||||
|
/// curly braces.
|
||||||
|
///
|
||||||
|
/// block ::= { expression_list }
|
||||||
|
/// expression_list ::= block_expr ; expression_list
|
||||||
|
/// block_expr ::= decl | "return" | expr
|
||||||
|
std::unique_ptr<ExpressionList> parseBlock()
|
||||||
|
{
|
||||||
|
if (lexer.getCurToken() != '{')
|
||||||
|
return parseError<ExpressionList>("{", "to begin block");
|
||||||
|
lexer.consume(Token('{'));
|
||||||
|
|
||||||
|
auto exprList = std::make_unique<ExpressionList>();
|
||||||
|
|
||||||
|
// Ignore empty expressions: swallow sequences of semicolons.
|
||||||
|
while (lexer.getCurToken() == ';')
|
||||||
|
lexer.consume(Token(';'));
|
||||||
|
|
||||||
|
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof)
|
||||||
|
{
|
||||||
|
if (lexer.getCurToken() == tok_var)
|
||||||
|
{
|
||||||
|
// Variable declaration
|
||||||
|
auto varDecl = parseDeclaration();
|
||||||
|
if (!varDecl)
|
||||||
|
return nullptr;
|
||||||
|
exprList->push_back(std::move(varDecl));
|
||||||
|
}
|
||||||
|
else if (lexer.getCurToken() == tok_return)
|
||||||
|
{
|
||||||
|
// Return statement
|
||||||
|
auto ret = parseReturn();
|
||||||
|
if (!ret)
|
||||||
|
return nullptr;
|
||||||
|
exprList->push_back(std::move(ret));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// General expression
|
||||||
|
auto expr = parseExpression();
|
||||||
|
if (!expr)
|
||||||
|
return nullptr;
|
||||||
|
exprList->push_back(std::move(expr));
|
||||||
|
}
|
||||||
|
// Ensure that elements are separated by a semicolon.
|
||||||
|
if (lexer.getCurToken() != ';')
|
||||||
|
return parseError<ExpressionList>(";", "after expression");
|
||||||
|
|
||||||
|
// Ignore empty expressions: swallow sequences of semicolons.
|
||||||
|
while (lexer.getCurToken() == ';')
|
||||||
|
lexer.consume(Token(';'));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != '}')
|
||||||
|
return parseError<ExpressionList>("}", "to close block");
|
||||||
|
|
||||||
|
lexer.consume(Token('}'));
|
||||||
|
return exprList;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// prototype ::= def id '(' decl_list ')'
|
||||||
|
/// decl_list ::= identifier | identifier, decl_list
|
||||||
|
std::unique_ptr<FunctionPrototype> parsePrototype()
|
||||||
|
{
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != tok_def)
|
||||||
|
return parseError<FunctionPrototype>("def", "in prototype");
|
||||||
|
lexer.consume(tok_def);
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != tok_identifier)
|
||||||
|
return parseError<FunctionPrototype>("function name", "in prototype");
|
||||||
|
|
||||||
|
std::string fnName(lexer.getId());
|
||||||
|
lexer.consume(tok_identifier);
|
||||||
|
|
||||||
|
if (lexer.getCurToken() != '(')
|
||||||
|
return parseError<FunctionPrototype>("(", "in prototype");
|
||||||
|
lexer.consume(Token('('));
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<VariableExpression>> args;
|
||||||
|
if (lexer.getCurToken() != ')')
|
||||||
|
{
|
||||||
|
do
|
||||||
|
{
|
||||||
|
std::string name(lexer.getId());
|
||||||
|
auto loc = lexer.getLastLocation();
|
||||||
|
lexer.consume(tok_identifier);
|
||||||
|
auto decl = std::make_unique<VariableExpression>(std::move(loc), name);
|
||||||
|
args.push_back(std::move(decl));
|
||||||
|
if (lexer.getCurToken() != ',')
|
||||||
|
break;
|
||||||
|
lexer.consume(Token(','));
|
||||||
|
if (lexer.getCurToken() != tok_identifier)
|
||||||
|
return parseError<FunctionPrototype>(
|
||||||
|
"identifier", "after ',' in function parameter list");
|
||||||
|
}
|
||||||
|
while (true);
|
||||||
|
}
|
||||||
|
if (lexer.getCurToken() != ')')
|
||||||
|
return parseError<FunctionPrototype>(")", "to end function prototype");
|
||||||
|
|
||||||
|
// success.
|
||||||
|
lexer.consume(Token(')'));
|
||||||
|
return std::make_unique<FunctionPrototype>(std::move(loc), fnName,
|
||||||
|
std::move(args));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a function definition, we expect a prototype initiated with the
|
||||||
|
/// `def` keyword, followed by a block containing a list of expressions.
|
||||||
|
///
|
||||||
|
/// definition ::= prototype block
|
||||||
|
std::unique_ptr<Function> parseDefinition()
|
||||||
|
{
|
||||||
|
auto proto = parsePrototype();
|
||||||
|
if (!proto)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
if (auto block = parseBlock())
|
||||||
|
return std::make_unique<Function>(std::move(proto), std::move(block));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the precedence of the pending binary operator token.
|
||||||
|
int getTokPrecedence()
|
||||||
|
{
|
||||||
|
if (!isascii(lexer.getCurToken()))
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
// 1 is lowest precedence.
|
||||||
|
switch (static_cast<char>(lexer.getCurToken()))
|
||||||
|
{
|
||||||
|
case '-':
|
||||||
|
return 20;
|
||||||
|
case '+':
|
||||||
|
return 20;
|
||||||
|
case '*':
|
||||||
|
return 40;
|
||||||
|
default:
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to signal errors while parsing, it takes an argument
|
||||||
|
/// indicating the expected token and another argument giving more context.
|
||||||
|
/// Location is retrieved from the lexer to enrich the error message.
|
||||||
|
template <typename R, typename T, typename U = const char*>
|
||||||
|
std::unique_ptr<R> parseError(T&& expected, U&& context = "")
|
||||||
|
{
|
||||||
|
auto curToken = lexer.getCurToken();
|
||||||
|
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
|
||||||
|
<< lexer.getLastLocation().col << "): expected '" << expected
|
||||||
|
<< "' " << context << " but has Token " << curToken;
|
||||||
|
if (isprint(curToken))
|
||||||
|
llvm::errs() << " '" << (char)curToken << "'";
|
||||||
|
llvm::errs() << "\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TOY_PARSER_H
|
358
include/SyntaxNode.h
Normal file
358
include/SyntaxNode.h
Normal file
|
@ -0,0 +1,358 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 28/05/25.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SYNTEXNODE_H
|
||||||
|
#define SYNTEXNODE_H
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <mlir/IR/Location.h>
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
struct ValueType
|
||||||
|
{
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Location
|
||||||
|
{
|
||||||
|
std::shared_ptr<std::string> file;
|
||||||
|
int line;
|
||||||
|
int col;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ExpressionNodeBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
enum ExpressionNodeKind
|
||||||
|
{
|
||||||
|
VariableDeclaration,
|
||||||
|
Return,
|
||||||
|
Number,
|
||||||
|
Literal,
|
||||||
|
Variable,
|
||||||
|
BinaryOperation,
|
||||||
|
Call,
|
||||||
|
Print
|
||||||
|
};
|
||||||
|
|
||||||
|
ExpressionNodeBase(ExpressionNodeKind kind, Location location) : kind(kind), location(location)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~ExpressionNodeBase() = default;
|
||||||
|
|
||||||
|
const Location& getLocation() const
|
||||||
|
{
|
||||||
|
return location;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionNodeKind getKind() const
|
||||||
|
{
|
||||||
|
return kind;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ExpressionNodeKind kind;
|
||||||
|
Location location;
|
||||||
|
};
|
||||||
|
|
||||||
|
using ExpressionList = std::vector<std::unique_ptr<ExpressionNodeBase>>;
|
||||||
|
|
||||||
|
class NumberExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
double value;
|
||||||
|
|
||||||
|
public:
|
||||||
|
NumberExpression(Location location, const double value) : ExpressionNodeBase(Number,
|
||||||
|
std::move(location)), value(value)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
double getValue() const
|
||||||
|
{
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Number;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class LiteralExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::vector<std::unique_ptr<ExpressionNodeBase>> values;
|
||||||
|
std::vector<int64_t> dimensions;
|
||||||
|
|
||||||
|
public:
|
||||||
|
LiteralExpression(Location location, std::vector<std::unique_ptr<ExpressionNodeBase>> values,
|
||||||
|
std::vector<int64_t> dimensions)
|
||||||
|
: ExpressionNodeBase(Literal, std::move(location)), values(std::move(values)),
|
||||||
|
dimensions(std::move(dimensions))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getValues()
|
||||||
|
{
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::ArrayRef<int64_t> getDimensions()
|
||||||
|
{
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Literal;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class VariableExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
public:
|
||||||
|
VariableExpression(Location location, const llvm::StringRef name)
|
||||||
|
: ExpressionNodeBase(Variable, std::move(location)), name(name)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef getName()
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Variable;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class VariableDeclarationExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::string name;
|
||||||
|
ValueType variableType;
|
||||||
|
std::unique_ptr<ExpressionNodeBase> initialValue;
|
||||||
|
|
||||||
|
public:
|
||||||
|
VariableDeclarationExpression(Location location, const llvm::StringRef name, ValueType type,
|
||||||
|
std::unique_ptr<ExpressionNodeBase> initialValue):
|
||||||
|
ExpressionNodeBase(VariableDeclaration, std::move(location)), name(name), variableType(std::move(type)),
|
||||||
|
initialValue(std::move(initialValue))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef getName()
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionNodeBase* getInitialValue() const
|
||||||
|
{
|
||||||
|
return initialValue.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
const ValueType& getType()
|
||||||
|
{
|
||||||
|
return variableType;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == VariableDeclaration;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReturnExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::optional<std::unique_ptr<ExpressionNodeBase>> returnExpression;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit ReturnExpression(Location location) : ExpressionNodeBase(Return, std::move(location))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
ReturnExpression(Location location, std::unique_ptr<ExpressionNodeBase> expression) : ExpressionNodeBase(
|
||||||
|
Return,
|
||||||
|
std::move(location)), returnExpression(std::make_optional(std::move(expression)))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ExpressionNodeBase*> getReturnExpression() const
|
||||||
|
{
|
||||||
|
if (returnExpression.has_value())
|
||||||
|
{
|
||||||
|
return returnExpression->get();
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class BinaryExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
char binaryOperator;
|
||||||
|
std::unique_ptr<ExpressionNodeBase> left, right;
|
||||||
|
|
||||||
|
public:
|
||||||
|
BinaryExpression(Location location, char op, std::unique_ptr<ExpressionNodeBase> left,
|
||||||
|
std::unique_ptr<ExpressionNodeBase> right)
|
||||||
|
: ExpressionNodeBase(BinaryOperation, std::move(location)), binaryOperator(op), left(std::move(left)),
|
||||||
|
right(std::move(right))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
char getOperator() const
|
||||||
|
{
|
||||||
|
return binaryOperator;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionNodeBase* getLeft() const
|
||||||
|
{
|
||||||
|
return left.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionNodeBase* getRight() const
|
||||||
|
{
|
||||||
|
return right.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == BinaryOperation;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class CallExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::string name;
|
||||||
|
std::vector<std::unique_ptr<ExpressionNodeBase>> arguments;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CallExpression(Location location, const std::string& callee,
|
||||||
|
std::vector<std::unique_ptr<ExpressionNodeBase>> arguments)
|
||||||
|
: ExpressionNodeBase(Call, std::move(location)), name(std::move(callee)), arguments(std::move(arguments))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef getName() const
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::ArrayRef<std::unique_ptr<ExpressionNodeBase>> getArguments() const
|
||||||
|
{
|
||||||
|
return arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Call;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class PrintExpression : public ExpressionNodeBase
|
||||||
|
{
|
||||||
|
std::unique_ptr<ExpressionNodeBase> argument;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PrintExpression(Location location, std::unique_ptr<ExpressionNodeBase> argument)
|
||||||
|
: ExpressionNodeBase(Print, std::move(location)), argument(std::move(argument))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionNodeBase* getArgument() const
|
||||||
|
{
|
||||||
|
return argument.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool classof(const ExpressionNodeBase* c)
|
||||||
|
{
|
||||||
|
return c->getKind() == Print;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class FunctionPrototype
|
||||||
|
{
|
||||||
|
Location location;
|
||||||
|
std::string name;
|
||||||
|
std::vector<std::unique_ptr<VariableExpression>> parameters;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FunctionPrototype(Location location, const std::string& name,
|
||||||
|
std::vector<std::unique_ptr<VariableExpression>> parameters): location(std::move(location)),
|
||||||
|
name(name), parameters(std::move(parameters))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
const Location& getLocation() const
|
||||||
|
{
|
||||||
|
return location;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef getName() const
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::ArrayRef<std::unique_ptr<VariableExpression>> getParameters() const
|
||||||
|
{
|
||||||
|
return parameters;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Function
|
||||||
|
{
|
||||||
|
std::unique_ptr<FunctionPrototype> prototype;
|
||||||
|
std::unique_ptr<ExpressionList> body;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Function(std::unique_ptr<FunctionPrototype> prototype, std::unique_ptr<ExpressionList> body) :
|
||||||
|
prototype(std::move(prototype)),
|
||||||
|
body(std::move(body))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionPrototype* getPrototype() const
|
||||||
|
{
|
||||||
|
return prototype.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpressionList* getBody() const
|
||||||
|
{
|
||||||
|
return body.get();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Module
|
||||||
|
{
|
||||||
|
std::vector<Function> functions;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit Module(std::vector<Function> functions) : functions(std::move(functions))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
auto begin()
|
||||||
|
{
|
||||||
|
return functions.begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end()
|
||||||
|
{
|
||||||
|
return functions.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void dump();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //SYNTEXNODE_H
|
245
lib/SyntaxNode.cpp
Normal file
245
lib/SyntaxNode.cpp
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
//
|
||||||
|
// Created by ricardo on 28/05/25.
|
||||||
|
//
|
||||||
|
#include <llvm/Support/raw_ostream.h>
|
||||||
|
#include <llvm/ADT/TypeSwitch.h>
|
||||||
|
#include "SyntaxNode.h"
|
||||||
|
|
||||||
|
using namespace hello;
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
struct Indent
|
||||||
|
{
|
||||||
|
int& level;
|
||||||
|
|
||||||
|
explicit Indent(int& level) : level(level)
|
||||||
|
{
|
||||||
|
++level;
|
||||||
|
}
|
||||||
|
|
||||||
|
~Indent()
|
||||||
|
{
|
||||||
|
--level;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SyntaxNodeDumper
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void dump(Module* module);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void dump(const ValueType& type);
|
||||||
|
void dump(VariableDeclarationExpression* varDecl);
|
||||||
|
void dump(ExpressionNodeBase* expr);
|
||||||
|
void dump(const ExpressionList* exprList);
|
||||||
|
void dump(NumberExpression* num);
|
||||||
|
void dump(LiteralExpression* node);
|
||||||
|
void dump(VariableExpression* node);
|
||||||
|
void dump(const ReturnExpression* node);
|
||||||
|
void dump(BinaryExpression* node);
|
||||||
|
void dump(CallExpression* node);
|
||||||
|
void dump(PrintExpression* node);
|
||||||
|
void dump(FunctionPrototype* node);
|
||||||
|
void dump(const Function* node);
|
||||||
|
|
||||||
|
void indent() const
|
||||||
|
{
|
||||||
|
for (int i = 0; i < currentIndent; i++)
|
||||||
|
{
|
||||||
|
llvm::outs() << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int currentIndent = 0;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static std::string formatLocation(T* node)
|
||||||
|
{
|
||||||
|
const auto& location = node->getLocation();
|
||||||
|
return (llvm::Twine("@") + *location.file + ":" + llvm::Twine(location.line) + ":" + llvm::Twine(location.col)).
|
||||||
|
str();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define INDENT() Indent level(currentIndent); indent();
|
||||||
|
|
||||||
|
|
||||||
|
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
|
||||||
|
void SyntaxNodeDumper::dump(ExpressionNodeBase* expr)
|
||||||
|
{
|
||||||
|
llvm::TypeSwitch<ExpressionNodeBase*>(expr)
|
||||||
|
.Case<BinaryExpression, CallExpression, LiteralExpression, NumberExpression,
|
||||||
|
PrintExpression, ReturnExpression, VariableDeclarationExpression, VariableExpression>(
|
||||||
|
[&](auto* node) { this->dump(node); })
|
||||||
|
.Default([&](ExpressionNodeBase*)
|
||||||
|
{
|
||||||
|
// No match, fallback to a generic message
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A variable declaration is printing the variable name, the type, and then
|
||||||
|
/// recurse in the initializer value.
|
||||||
|
void SyntaxNodeDumper::dump(VariableDeclarationExpression* varDecl)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "VarDecl " << varDecl->getName();
|
||||||
|
dump(varDecl->getType());
|
||||||
|
llvm::outs() << " " << formatLocation(varDecl) << "\n";
|
||||||
|
dump(varDecl->getInitialValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A "block", or a list of expression
|
||||||
|
void SyntaxNodeDumper::dump(const ExpressionList* exprList)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Block {\n";
|
||||||
|
for (auto& expr : *exprList)
|
||||||
|
dump(expr.get());
|
||||||
|
indent();
|
||||||
|
llvm::outs() << "} // Block\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A literal number, just print the value.
|
||||||
|
void SyntaxNodeDumper::dump(NumberExpression* num)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << num->getValue() << " " << formatLocation(num) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to print recursively a literal. This handles nested array like:
|
||||||
|
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||||
|
/// We print out such array with the dimensions spelled out at every level:
|
||||||
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
|
void printLitHelper(ExpressionNodeBase* litOrNum)
|
||||||
|
{
|
||||||
|
// Inside a literal expression we can have either a number or another literal
|
||||||
|
if (const auto* num = llvm::dyn_cast<NumberExpression>(litOrNum))
|
||||||
|
{
|
||||||
|
llvm::outs() << num->getValue();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto* literal = llvm::cast<LiteralExpression>(litOrNum);
|
||||||
|
|
||||||
|
// Print the dimension for this literal first
|
||||||
|
llvm::outs() << "<";
|
||||||
|
interleaveComma(literal->getDimensions(), llvm::outs());
|
||||||
|
llvm::outs() << ">";
|
||||||
|
|
||||||
|
// Now print the content, recursing on every element of the list
|
||||||
|
llvm::outs() << "[ ";
|
||||||
|
interleaveComma(literal->getValues(), llvm::outs(),
|
||||||
|
[&](auto& elt) { printLitHelper(elt.get()); });
|
||||||
|
llvm::outs() << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a literal, see the recursive helper above for the implementation.
|
||||||
|
void SyntaxNodeDumper::dump(LiteralExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Literal: ";
|
||||||
|
printLitHelper(node);
|
||||||
|
llvm::outs() << " " << formatLocation(node) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a variable reference (just a name).
|
||||||
|
void SyntaxNodeDumper::dump(VariableExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "var: " << node->getName() << " " << formatLocation(node) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return statement print the return and its (optional) argument.
|
||||||
|
void SyntaxNodeDumper::dump(const ReturnExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Return\n";
|
||||||
|
if (node->getReturnExpression().has_value())
|
||||||
|
return dump(*node->getReturnExpression());
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "(void)\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||||
|
void SyntaxNodeDumper::dump(BinaryExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "BinOp: " << node->getOperator() << " " << formatLocation(node) << "\n";
|
||||||
|
dump(node->getLeft());
|
||||||
|
dump(node->getRight());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a call expression, first the callee name and the list of args by
|
||||||
|
/// recursing into each individual argument.
|
||||||
|
void SyntaxNodeDumper::dump(CallExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Call '" << node->getName() << "' [ " << formatLocation(node) << "\n";
|
||||||
|
for (auto& arg : node->getArguments())
|
||||||
|
dump(arg.get());
|
||||||
|
indent();
|
||||||
|
llvm::outs() << "]\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a builtin print call, first the builtin name and then the argument.
|
||||||
|
void SyntaxNodeDumper::dump(PrintExpression* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Print [ " << formatLocation(node) << "\n";
|
||||||
|
dump(node->getArgument());
|
||||||
|
indent();
|
||||||
|
llvm::outs() << "]\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print type: only the shape is printed in between '<' and '>'
|
||||||
|
void SyntaxNodeDumper::dump(const ValueType& type)
|
||||||
|
{
|
||||||
|
llvm::outs() << "<";
|
||||||
|
interleaveComma(type.shape, llvm::outs());
|
||||||
|
llvm::outs() << ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a function prototype, first the function name, and then the list of
|
||||||
|
/// parameters names.
|
||||||
|
void SyntaxNodeDumper::dump(FunctionPrototype* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Proto '" << node->getName() << "' " << formatLocation(node) << "\n";
|
||||||
|
indent();
|
||||||
|
llvm::outs() << "Params: [";
|
||||||
|
llvm::interleaveComma(node->getParameters(), llvm::outs(),
|
||||||
|
[](auto& arg) { llvm::outs() << arg->getName(); });
|
||||||
|
llvm::outs() << "]\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a function, first the prototype and then the body.
|
||||||
|
void SyntaxNodeDumper::dump(const Function* node)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Function \n";
|
||||||
|
dump(node->getPrototype());
|
||||||
|
dump(node->getBody());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a module, actually loop over the functions and print them in sequence.
|
||||||
|
void SyntaxNodeDumper::dump(Module* module)
|
||||||
|
{
|
||||||
|
INDENT();
|
||||||
|
llvm::outs() << "Module:\n";
|
||||||
|
for (auto& f : *module)
|
||||||
|
dump(&f);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace hello
|
||||||
|
{
|
||||||
|
void Module::dump()
|
||||||
|
{
|
||||||
|
SyntaxNodeDumper().dump(this);
|
||||||
|
}
|
||||||
|
}
|
56
main.cpp
Normal file
56
main.cpp
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
#include "Lexer.h"
|
||||||
|
#include "Parser.h"
|
||||||
|
|
||||||
|
#include <llvm/Support/CommandLine.h>
|
||||||
|
#include <llvm/Support/ErrorOr.h>
|
||||||
|
#include <llvm/Support/MemoryBuffer.h>
|
||||||
|
|
||||||
|
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||||
|
llvm::cl::desc("<input hello file>"),
|
||||||
|
llvm::cl::init("-"),
|
||||||
|
llvm::cl::value_desc("filename"));
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
enum Action { None, DumpSyntaxNode };
|
||||||
|
}
|
||||||
|
|
||||||
|
static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
|
||||||
|
llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")));
|
||||||
|
|
||||||
|
std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
|
||||||
|
{
|
||||||
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
|
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||||
|
if (std::error_code ec = fileOrErr.getError())
|
||||||
|
{
|
||||||
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto buffer = fileOrErr.get()->getBuffer();
|
||||||
|
hello::LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename));
|
||||||
|
hello::Parser parser(lexer);
|
||||||
|
return parser.parseModule();
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
|
||||||
|
|
||||||
|
auto module = parseInputFile(inputFilename);
|
||||||
|
|
||||||
|
if (!module)
|
||||||
|
{
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (emitAction)
|
||||||
|
{
|
||||||
|
case DumpSyntaxNode:
|
||||||
|
module->dump();
|
||||||
|
return 0;
|
||||||
|
default:
|
||||||
|
llvm::errs() << "Unrecognized action\n";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user