//===- Parser.h - Toy Language Parser -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the Toy language. It processes the Token
// provided by the Lexer and returns an AST.
//
//===----------------------------------------------------------------------===//
|
|
#ifndef TOY_PARSER_H
|
|
#define TOY_PARSER_H
|
|
|
|
#include <complex>
|
|
|
|
#include "SyntaxNode.h"
|
|
#include "Lexer.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include <utility>
|
|
#include <vector>
|
|
#include <optional>
|
|
|
|
namespace hello
|
|
{
|
|
/// This is a simple recursive parser for the Toy language. It produces a well
|
|
/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
|
|
/// or symbol resolution is performed. For example, variables are referenced by
|
|
/// string and the code could reference an undeclared variable and the parsing
|
|
/// succeeds.
|
|
class Parser
|
|
{
|
|
public:
|
|
/// Create a Parser for the supplied lexer.
|
|
explicit Parser(Lexer& lexer) : lexer(lexer)
|
|
{
|
|
}
|
|
|
|
/// Parse a full Module. A module is a list of function definitions.
|
|
std::unique_ptr<Module> parseModule()
|
|
{
|
|
lexer.getNextToken(); // prime the lexer
|
|
|
|
// Parse functions one at a time and accumulate in this vector.
|
|
std::vector<Function> functions;
|
|
while (auto f = parseDefinition())
|
|
{
|
|
functions.push_back(std::move(*f));
|
|
if (lexer.getCurToken() == tok_eof)
|
|
break;
|
|
}
|
|
// If we didn't reach EOF, there was an error during parsing
|
|
if (lexer.getCurToken() != tok_eof)
|
|
return parseError<Module>("nothing", "at end of module");
|
|
|
|
return std::make_unique<Module>(std::move(functions));
|
|
}
|
|
|
|
private:
|
|
Lexer& lexer;
|
|
|
|
/// Parse a return statement.
|
|
/// return :== return ; | return expr ;
|
|
std::unique_ptr<ReturnExpression> parseReturn()
|
|
{
|
|
auto loc = lexer.getLastLocation();
|
|
lexer.consume(tok_return);
|
|
|
|
// return takes an optional argument
|
|
if (lexer.getCurToken() != ';')
|
|
{
|
|
std::unique_ptr<ExpressionNodeBase> expr = parseExpression();
|
|
if (!expr)
|
|
{
|
|
return nullptr;
|
|
}
|
|
|
|
return std::make_unique<ReturnExpression>(std::move(loc), std::move(expr));
|
|
}
|
|
return std::make_unique<ReturnExpression>(std::move(loc));
|
|
}
|
|
|
|
/// Parse a literal number.
|
|
/// numberexpr ::= number
|
|
std::unique_ptr<ExpressionNodeBase> parseNumberExpr()
|
|
{
|
|
auto loc = lexer.getLastLocation();
|
|
auto result =
|
|
std::make_unique<NumberExpression>(std::move(loc), lexer.getValue());
|
|
lexer.consume(tok_number);
|
|
return std::move(result);
|
|
}
|
|
|
|
/// Parse a literal array expression.
|
|
/// tensorLiteral ::= [ literalList ] | number
|
|
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
|
std::unique_ptr<ExpressionNodeBase> parseTensorLiteralExpr()
|
|
{
|
|
auto loc = lexer.getLastLocation();
|
|
lexer.consume(Token('['));
|
|
|
|
// Hold the list of values at this nesting level.
|
|
std::vector<std::unique_ptr<ExpressionNodeBase>> values;
|
|
// Hold the dimensions for all the nesting inside this level.
|
|
std::vector<int64_t> dims;
|
|
do
|
|
{
|
|
// We can have either another nested array or a number literal.
|
|
if (lexer.getCurToken() == '[')
|
|
{
|
|
values.push_back(parseTensorLiteralExpr());
|
|
if (!values.back())
|
|
return nullptr; // parse error in the nested array.
|
|
}
|
|
else
|
|
{
|
|
if (lexer.getCurToken() != tok_number)
|
|
return parseError<ExpressionNodeBase>("<num> or [", "in literal expression");
|
|
values.push_back(parseNumberExpr());
|
|
}
|
|
|
|
// End of this list on ']'
|
|
if (lexer.getCurToken() == ']')
|
|
break;
|
|
|
|
// Elements are separated by a comma.
|
|
if (lexer.getCurToken() != ',')
|
|
return parseError<ExpressionNodeBase>("] or ,", "in literal expression");
|
|
|
|
lexer.getNextToken(); // eat ,
|
|
}
|
|
while (true);
|
|
if (values.empty())
|
|
return parseError<ExpressionNodeBase>("<something>", "to fill literal expression");
|
|
lexer.getNextToken(); // eat ]
|
|
|
|
/// Fill in the dimensions now. First the current nesting level:
|
|
dims.push_back(values.size());
|
|
|
|
/// If there is any nested array, process all of them and ensure that
|
|
/// dimensions are uniform.
|
|
if (llvm::any_of(values, [](std::unique_ptr<ExpressionNodeBase>& expr)
|
|
{
|
|
return llvm::isa<LiteralExpression>(expr.get());
|
|
}))
|
|
{
|
|
auto* firstLiteral = llvm::dyn_cast<LiteralExpression>(values.front().get());
|
|
if (!firstLiteral)
|
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
|
"inside literal expression");
|
|
|
|
// Append the nested dimensions to the current level
|
|
auto firstDims = firstLiteral->getDimensions();
|
|
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
|
|
|
// Sanity check that shape is uniform across all elements of the list.
|
|
for (auto& expr : values)
|
|
{
|
|
auto* exprLiteral = llvm::cast<LiteralExpression>(expr.get());
|
|
if (!exprLiteral)
|
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
|
"inside literal expression");
|
|
if (exprLiteral->getDimensions() != firstDims)
|
|
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
|
|
"inside literal expression");
|
|
}
|
|
}
|
|
return std::make_unique<LiteralExpression>(std::move(loc), std::move(values),
|
|
std::move(dims));
|
|
}
|
|
|
|
/// parenexpr ::= '(' expression ')'
|
|
std::unique_ptr<ExpressionNodeBase> parseParenExpr()
|
|
{
|
|
lexer.getNextToken(); // eat (.
|
|
auto v = parseExpression();
|
|
if (!v)
|
|
return nullptr;
|
|
|
|
if (lexer.getCurToken() != ')')
|
|
return parseError<ExpressionNodeBase>(")", "to close expression with parentheses");
|
|
lexer.consume(Token(')'));
|
|
return v;
|
|
}
|
|
|
|
/// identifierexpr
|
|
/// ::= identifier
|
|
/// ::= identifier '(' expression ')'
|
|
std::unique_ptr<ExpressionNodeBase> parseIdentifierExpr()
|
|
{
|
|
std::string name(lexer.getId());
|
|
|
|
auto loc = lexer.getLastLocation();
|
|
lexer.getNextToken(); // eat identifier.
|
|
|
|
if (lexer.getCurToken() != '(') // Simple variable ref.
|
|
return std::make_unique<VariableExpression>(std::move(loc), name);
|
|
|
|
// This is a function call.
|
|
lexer.consume(Token('('));
|
|
std::vector<std::unique_ptr<ExpressionNodeBase>> args;
|
|
if (lexer.getCurToken() != ')')
|
|
{
|
|
while (true)
|
|
{
|
|
if (auto arg = parseExpression())
|
|
args.push_back(std::move(arg));
|
|
else
|
|
return nullptr;
|
|
|
|
if (lexer.getCurToken() == ')')
|
|
break;
|
|
|
|
if (lexer.getCurToken() != ',')
|
|
return parseError<ExpressionNodeBase>(", or )", "in argument list");
|
|
lexer.getNextToken();
|
|
}
|
|
}
|
|
lexer.consume(Token(')'));
|
|
|
|
// It can be a builtin call to print
|
|
if (name == "print")
|
|
{
|
|
if (args.size() != 1)
|
|
return parseError<ExpressionNodeBase>("<single arg>", "as argument to print()");
|
|
|
|
return std::make_unique<PrintExpression>(std::move(loc), std::move(args[0]));
|
|
}
|
|
|
|
// Call to a user-defined function
|
|
return std::make_unique<CallExpression>(std::move(loc), name, std::move(args));
|
|
}
|
|
|
|
/// primary
|
|
/// ::= identifierexpr
|
|
/// ::= numberexpr
|
|
/// ::= parenexpr
|
|
/// ::= tensorliteral
|
|
std::unique_ptr<ExpressionNodeBase> parsePrimary()
|
|
{
|
|
switch (lexer.getCurToken())
|
|
{
|
|
default:
|
|
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
|
<< "' when expecting an expression\n";
|
|
return nullptr;
|
|
case tok_identifier:
|
|
return parseIdentifierExpr();
|
|
case tok_number:
|
|
return parseNumberExpr();
|
|
case '(':
|
|
return parseParenExpr();
|
|
case '[':
|
|
return parseTensorLiteralExpr();
|
|
case ';':
|
|
return nullptr;
|
|
case '}':
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
/// Recursively parse the right hand side of a binary expression, the ExprPrec
|
|
/// argument indicates the precedence of the current binary operator.
|
|
///
|
|
/// binoprhs ::= ('+' primary)*
|
|
std::unique_ptr<ExpressionNodeBase> parseBinOpRHS(int exprPrec,
|
|
std::unique_ptr<ExpressionNodeBase> lhs)
|
|
{
|
|
// If this is a binop, find its precedence.
|
|
while (true)
|
|
{
|
|
int tokPrec = getTokPrecedence();
|
|
|
|
// If this is a binop that binds at least as tightly as the current binop,
|
|
// consume it, otherwise we are done.
|
|
if (tokPrec < exprPrec)
|
|
return lhs;
|
|
|
|
// Okay, we know this is a binop.
|
|
int binOp = lexer.getCurToken();
|
|
lexer.consume(Token(binOp));
|
|
auto loc = lexer.getLastLocation();
|
|
|
|
// Parse the primary expression after the binary operator.
|
|
auto rhs = parsePrimary();
|
|
if (!rhs)
|
|
return parseError<ExpressionNodeBase>("expression", "to complete binary operator");
|
|
|
|
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
|
// the pending operator take rhs as its lhs.
|
|
int nextPrec = getTokPrecedence();
|
|
if (tokPrec < nextPrec)
|
|
{
|
|
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
|
if (!rhs)
|
|
return nullptr;
|
|
}
|
|
|
|
// Merge lhs/RHS.
|
|
lhs = std::make_unique<BinaryExpression>(std::move(loc), binOp,
|
|
std::move(lhs), std::move(rhs));
|
|
}
|
|
}
|
|
|
|
/// expression::= primary binop rhs
|
|
std::unique_ptr<ExpressionNodeBase> parseExpression()
|
|
{
|
|
auto lhs = parsePrimary();
|
|
if (!lhs)
|
|
return nullptr;
|
|
|
|
return parseBinOpRHS(0, std::move(lhs));
|
|
}
|
|
|
|
/// type ::= < shape_list >
|
|
/// shape_list ::= num | num , shape_list
|
|
std::unique_ptr<ValueType> parseType()
|
|
{
|
|
if (lexer.getCurToken() != '<')
|
|
return parseError<ValueType>("<", "to begin type");
|
|
lexer.getNextToken(); // eat <
|
|
|
|
auto type = std::make_unique<ValueType>();
|
|
|
|
while (lexer.getCurToken() == tok_number)
|
|
{
|
|
type->shape.push_back(lexer.getValue());
|
|
lexer.getNextToken();
|
|
if (lexer.getCurToken() == ',')
|
|
lexer.getNextToken();
|
|
}
|
|
|
|
if (lexer.getCurToken() != '>')
|
|
return parseError<ValueType>(">", "to end type");
|
|
lexer.getNextToken(); // eat >
|
|
return type;
|
|
}
|
|
|
|
/// Parse a variable declaration, it starts with a `var` keyword followed by
|
|
/// and identifier and an optional type (shape specification) before the
|
|
/// initializer.
|
|
/// decl ::= var identifier [ type ] = expr
|
|
std::unique_ptr<VariableDeclarationExpression> parseDeclaration()
|
|
{
|
|
if (lexer.getCurToken() != tok_var)
|
|
return parseError<VariableDeclarationExpression>("var", "to begin declaration");
|
|
auto loc = lexer.getLastLocation();
|
|
lexer.getNextToken(); // eat var
|
|
|
|
if (lexer.getCurToken() != tok_identifier)
|
|
return parseError<VariableDeclarationExpression>("identified",
|
|
"after 'var' declaration");
|
|
std::string id(lexer.getId());
|
|
lexer.getNextToken(); // eat id
|
|
|
|
std::unique_ptr<ValueType> type; // Type is optional, it can be inferred
|
|
if (lexer.getCurToken() == '<')
|
|
{
|
|
type = parseType();
|
|
if (!type)
|
|
return nullptr;
|
|
}
|
|
|
|
if (!type)
|
|
type = std::make_unique<ValueType>();
|
|
lexer.consume(Token('='));
|
|
auto expr = parseExpression();
|
|
return std::make_unique<VariableDeclarationExpression>(std::move(loc), std::move(id),
|
|
std::move(*type), std::move(expr));
|
|
}
|
|
|
|
/// Parse a block: a list of expression separated by semicolons and wrapped in
|
|
/// curly braces.
|
|
///
|
|
/// block ::= { expression_list }
|
|
/// expression_list ::= block_expr ; expression_list
|
|
/// block_expr ::= decl | "return" | expr
|
|
std::unique_ptr<ExpressionList> parseBlock()
|
|
{
|
|
if (lexer.getCurToken() != '{')
|
|
return parseError<ExpressionList>("{", "to begin block");
|
|
lexer.consume(Token('{'));
|
|
|
|
auto exprList = std::make_unique<ExpressionList>();
|
|
|
|
// Ignore empty expressions: swallow sequences of semicolons.
|
|
while (lexer.getCurToken() == ';')
|
|
lexer.consume(Token(';'));
|
|
|
|
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof)
|
|
{
|
|
if (lexer.getCurToken() == tok_var)
|
|
{
|
|
// Variable declaration
|
|
auto varDecl = parseDeclaration();
|
|
if (!varDecl)
|
|
return nullptr;
|
|
exprList->push_back(std::move(varDecl));
|
|
}
|
|
else if (lexer.getCurToken() == tok_return)
|
|
{
|
|
// Return statement
|
|
auto ret = parseReturn();
|
|
if (!ret)
|
|
return nullptr;
|
|
exprList->push_back(std::move(ret));
|
|
}
|
|
else
|
|
{
|
|
// General expression
|
|
auto expr = parseExpression();
|
|
if (!expr)
|
|
return nullptr;
|
|
exprList->push_back(std::move(expr));
|
|
}
|
|
// Ensure that elements are separated by a semicolon.
|
|
if (lexer.getCurToken() != ';')
|
|
return parseError<ExpressionList>(";", "after expression");
|
|
|
|
// Ignore empty expressions: swallow sequences of semicolons.
|
|
while (lexer.getCurToken() == ';')
|
|
lexer.consume(Token(';'));
|
|
}
|
|
|
|
if (lexer.getCurToken() != '}')
|
|
return parseError<ExpressionList>("}", "to close block");
|
|
|
|
lexer.consume(Token('}'));
|
|
return exprList;
|
|
}
|
|
|
|
/// prototype ::= def id '(' decl_list ')'
|
|
/// decl_list ::= identifier | identifier, decl_list
|
|
std::unique_ptr<FunctionPrototype> parsePrototype()
|
|
{
|
|
auto loc = lexer.getLastLocation();
|
|
|
|
if (lexer.getCurToken() != tok_def)
|
|
return parseError<FunctionPrototype>("def", "in prototype");
|
|
lexer.consume(tok_def);
|
|
|
|
if (lexer.getCurToken() != tok_identifier)
|
|
return parseError<FunctionPrototype>("function name", "in prototype");
|
|
|
|
std::string fnName(lexer.getId());
|
|
lexer.consume(tok_identifier);
|
|
|
|
if (lexer.getCurToken() != '(')
|
|
return parseError<FunctionPrototype>("(", "in prototype");
|
|
lexer.consume(Token('('));
|
|
|
|
std::vector<std::unique_ptr<VariableExpression>> args;
|
|
if (lexer.getCurToken() != ')')
|
|
{
|
|
do
|
|
{
|
|
std::string name(lexer.getId());
|
|
auto loc = lexer.getLastLocation();
|
|
lexer.consume(tok_identifier);
|
|
auto decl = std::make_unique<VariableExpression>(std::move(loc), name);
|
|
args.push_back(std::move(decl));
|
|
if (lexer.getCurToken() != ',')
|
|
break;
|
|
lexer.consume(Token(','));
|
|
if (lexer.getCurToken() != tok_identifier)
|
|
return parseError<FunctionPrototype>(
|
|
"identifier", "after ',' in function parameter list");
|
|
}
|
|
while (true);
|
|
}
|
|
if (lexer.getCurToken() != ')')
|
|
return parseError<FunctionPrototype>(")", "to end function prototype");
|
|
|
|
// success.
|
|
lexer.consume(Token(')'));
|
|
return std::make_unique<FunctionPrototype>(std::move(loc), fnName,
|
|
std::move(args));
|
|
}
|
|
|
|
/// Parse a function definition, we expect a prototype initiated with the
|
|
/// `def` keyword, followed by a block containing a list of expressions.
|
|
///
|
|
/// definition ::= prototype block
|
|
std::unique_ptr<Function> parseDefinition()
|
|
{
|
|
auto proto = parsePrototype();
|
|
if (!proto)
|
|
return nullptr;
|
|
|
|
if (auto block = parseBlock())
|
|
return std::make_unique<Function>(std::move(proto), std::move(block));
|
|
return nullptr;
|
|
}
|
|
|
|
/// Get the precedence of the pending binary operator token.
|
|
int getTokPrecedence()
|
|
{
|
|
if (!isascii(lexer.getCurToken()))
|
|
return -1;
|
|
|
|
// 1 is lowest precedence.
|
|
switch (static_cast<char>(lexer.getCurToken()))
|
|
{
|
|
case '-':
|
|
return 20;
|
|
case '+':
|
|
return 20;
|
|
case '*':
|
|
return 40;
|
|
default:
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
/// Helper function to signal errors while parsing, it takes an argument
|
|
/// indicating the expected token and another argument giving more context.
|
|
/// Location is retrieved from the lexer to enrich the error message.
|
|
template <typename R, typename T, typename U = const char*>
|
|
std::unique_ptr<R> parseError(T&& expected, U&& context = "")
|
|
{
|
|
auto curToken = lexer.getCurToken();
|
|
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
|
|
<< lexer.getLastLocation().col << "): expected '" << expected
|
|
<< "' " << context << " but has Token " << curToken;
|
|
if (isprint(curToken))
|
|
llvm::errs() << " '" << (char)curToken << "'";
|
|
llvm::errs() << "\n";
|
|
return nullptr;
|
|
}
|
|
};
|
|
}
|
|
|
|
#endif // TOY_PARSER_H
|