// hello-mlir/include/Parser.h
//
// 540 lines
// 20 KiB
// C++

//===- Parser.h - Toy Language Parser -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the Toy language. It processes the tokens
// provided by the Lexer and returns an AST.
//
//===----------------------------------------------------------------------===//
#ifndef TOY_PARSER_H
#define TOY_PARSER_H
#include "SyntaxNode.h"
#include "Lexer.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cctype>
#include <complex>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
namespace hello
{
/// This is a simple recursive parser for the Toy language. It produces a
/// well-formed AST from the stream of tokens supplied by the Lexer. No semantic
/// checks or symbol resolution are performed. For example, variables are
/// referenced by name, so the code can reference an undeclared variable and
/// parsing still succeeds.
class Parser
{
public:
/// Construct a parser that pulls its tokens from `lexer`. Only a reference to
/// the lexer is stored, so the lexer has to outlive the parser.
explicit Parser(Lexer& lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
std::unique_ptr<Module> parseModule()
{
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<Function> functions;
while (auto f = parseDefinition())
{
functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
// If we didn't reach EOF, there was an error during parsing
if (lexer.getCurToken() != tok_eof)
return parseError<Module>("nothing", "at end of module");
return std::make_unique<Module>(std::move(functions));
}
private:
Lexer& lexer;
/// Parse a return statement. The terminating ';' is not consumed here.
/// return ::= return ; | return expr ;
std::unique_ptr<ReturnExpression> parseReturn()
{
    auto loc = lexer.getLastLocation();
    lexer.consume(tok_return);

    // A bare `return` carries no value.
    if (lexer.getCurToken() == ';')
        return std::make_unique<ReturnExpression>(std::move(loc));

    // Anything else has to be the returned expression.
    std::unique_ptr<ExpressionNodeBase> value = parseExpression();
    if (!value)
        return nullptr;
    return std::make_unique<ReturnExpression>(std::move(loc), std::move(value));
}
/// Parse a literal number. The caller must have checked that the current
/// token is `tok_number`.
/// numberexpr ::= number
std::unique_ptr<ExpressionNodeBase> parseNumberExpr()
{
    auto loc = lexer.getLastLocation();
    // The value is read from the lexer before the number token is consumed.
    auto result =
        std::make_unique<NumberExpression>(std::move(loc), lexer.getValue());
    lexer.consume(tok_number);
    // A local returned by name is moved implicitly, also across the
    // unique_ptr<Derived> -> unique_ptr<Base> conversion, so an explicit
    // std::move is redundant here.
    return result;
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
std::unique_ptr<ExpressionNodeBase> parseTensorLiteralExpr()
{
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
// Hold the list of values at this nesting level.
std::vector<std::unique_ptr<ExpressionNodeBase>> values;
// Hold the dimensions for all the nesting inside this level.
std::vector<int64_t> dims;
do
{
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[')
{
values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
}
else
{
if (lexer.getCurToken() != tok_number)
return parseError<ExpressionNodeBase>("<num> or [", "in literal expression");
values.push_back(parseNumberExpr());
}
// End of this list on ']'
if (lexer.getCurToken() == ']')
break;
// Elements are separated by a comma.
if (lexer.getCurToken() != ',')
return parseError<ExpressionNodeBase>("] or ,", "in literal expression");
lexer.getNextToken(); // eat ,
}
while (true);
if (values.empty())
return parseError<ExpressionNodeBase>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExpressionNodeBase>& expr)
{
return llvm::isa<LiteralExpression>(expr.get());
}))
{
auto* firstLiteral = llvm::dyn_cast<LiteralExpression>(values.front().get());
if (!firstLiteral)
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
"inside literal expression");
// Append the nested dimensions to the current level
auto firstDims = firstLiteral->getDimensions();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
for (auto& expr : values)
{
auto* exprLiteral = llvm::cast<LiteralExpression>(expr.get());
if (!exprLiteral)
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
"inside literal expression");
if (exprLiteral->getDimensions() != firstDims)
return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
"inside literal expression");
}
}
return std::make_unique<LiteralExpression>(std::move(loc), std::move(values),
std::move(dims));
}
/// Parse an expression wrapped in parentheses and return the inner expression.
/// The current token is skipped as the opening '(' without being checked.
/// parenexpr ::= '(' expression ')'
std::unique_ptr<ExpressionNodeBase> parseParenExpr()
{
    lexer.getNextToken(); // step over '('

    auto inner = parseExpression();
    if (!inner)
        return nullptr;

    const bool closed = lexer.getCurToken() == ')';
    if (!closed)
        return parseError<ExpressionNodeBase>(")", "to close expression with parentheses");

    lexer.consume(Token(')'));
    return inner;
}
/// Parse an expression that starts with an identifier: a variable reference,
/// a call to a user-defined function, or a call to the builtin `print`.
/// identifierexpr
///   ::= identifier
///   ::= identifier '(' [ expression { ',' expression } ] ')'
std::unique_ptr<ExpressionNodeBase> parseIdentifierExpr()
{
    std::string name(lexer.getId());
    auto loc = lexer.getLastLocation();
    lexer.getNextToken(); // step over the identifier

    // Without an opening parenthesis this is a plain variable reference.
    const bool isCall = lexer.getCurToken() == '(';
    if (!isCall)
        return std::make_unique<VariableExpression>(std::move(loc), name);

    // Otherwise collect the comma-separated call arguments.
    lexer.consume(Token('('));
    std::vector<std::unique_ptr<ExpressionNodeBase>> args;
    if (lexer.getCurToken() != ')')
    {
        for (;;)
        {
            auto arg = parseExpression();
            if (!arg)
                return nullptr;
            args.push_back(std::move(arg));

            if (lexer.getCurToken() == ')')
                break;
            if (lexer.getCurToken() != ',')
                return parseError<ExpressionNodeBase>(", or )", "in argument list");
            lexer.getNextToken(); // step over ','
        }
    }
    lexer.consume(Token(')'));

    // Every name other than `print` is a call to a user-defined function.
    if (name != "print")
        return std::make_unique<CallExpression>(std::move(loc), name, std::move(args));

    // `print` is a builtin and takes exactly one argument.
    if (args.size() != 1)
        return parseError<ExpressionNodeBase>("<single arg>", "as argument to print()");
    return std::make_unique<PrintExpression>(std::move(loc), std::move(args[0]));
}
/// Parse a primary expression by dispatching on the current token.
/// primary
///   ::= identifierexpr
///   ::= numberexpr
///   ::= parenexpr
///   ::= tensorliteral
///
/// Returns nullptr when no primary expression starts at the current token.
std::unique_ptr<ExpressionNodeBase> parsePrimary()
{
    switch (lexer.getCurToken())
    {
    case tok_identifier:
        return parseIdentifierExpr();
    case tok_number:
        return parseNumberExpr();
    case '(':
        return parseParenExpr();
    case '[':
        return parseTensorLiteralExpr();
    case ';':
    case '}':
        // These two tokens fail without printing a diagnostic here.
        return nullptr;
    default:
        llvm::errs() << "unknown token '" << lexer.getCurToken()
                     << "' when expecting an expression\n";
        return nullptr;
    }
}
/// Parse the sequence of (operator, primary) pairs that follows an already
/// parsed left-hand side, folding them into BinaryExpression nodes.
///
/// `exprPrec` is the minimal precedence an operator needs in order to be
/// consumed by this call; `lhs` is the expression parsed so far. Returns the
/// combined expression, or nullptr on error.
///
/// binoprhs ::= (binop primary)*
std::unique_ptr<ExpressionNodeBase> parseBinOpRHS(int exprPrec,
                                                  std::unique_ptr<ExpressionNodeBase> lhs)
{
    for (;;)
    {
        // Stop as soon as the pending token binds less tightly than required.
        // getTokPrecedence() returns -1 for tokens that are not operators.
        const int opPrec = getTokPrecedence();
        if (opPrec < exprPrec)
            return lhs;

        // The pending token is a binary operator: take it.
        int binOp = lexer.getCurToken();
        lexer.consume(Token(binOp));
        auto loc = lexer.getLastLocation();

        // The operand right after the operator.
        auto rhs = parsePrimary();
        if (!rhs)
            return parseError<ExpressionNodeBase>("expression", "to complete binary operator");

        // If the operator after `rhs` binds tighter than this one, it gets
        // `rhs` as its own left-hand side first.
        const int followingPrec = getTokPrecedence();
        if (followingPrec > opPrec)
        {
            rhs = parseBinOpRHS(opPrec + 1, std::move(rhs));
            if (!rhs)
                return nullptr;
        }

        // Fold both sides into the new left-hand side and keep going.
        lhs = std::make_unique<BinaryExpression>(std::move(loc), binOp,
                                                 std::move(lhs), std::move(rhs));
    }
}
/// Parse a full expression: a primary followed by any number of binary
/// operator / operand pairs. Returns nullptr on error.
/// expression ::= primary binoprhs
std::unique_ptr<ExpressionNodeBase> parseExpression()
{
    if (auto first = parsePrimary())
        return parseBinOpRHS(0, std::move(first));
    return nullptr;
}
/// Parse a shape specification enclosed in angle brackets.
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
///
/// Returns the parsed type, or nullptr after reporting a missing '<' or '>'.
std::unique_ptr<ValueType> parseType()
{
    if (lexer.getCurToken() != '<')
        return parseError<ValueType>("<", "to begin type");
    lexer.getNextToken(); // step over '<'

    auto shapeType = std::make_unique<ValueType>();
    while (true)
    {
        if (lexer.getCurToken() != tok_number)
            break;
        shapeType->shape.push_back(lexer.getValue());
        lexer.getNextToken(); // step over the number
        // A ',' following a dimension is skipped when present; this loop does
        // not insist on it.
        if (lexer.getCurToken() == ',')
            lexer.getNextToken();
    }

    if (lexer.getCurToken() == '>')
    {
        lexer.getNextToken(); // step over '>'
        return shapeType;
    }
    return parseError<ValueType>(">", "to end type");
}
/// Parse a variable declaration. It starts with the `var` keyword, followed by
/// an identifier and an optional type (shape specification) before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
///
/// Returns the declaration node, or nullptr when any part of the declaration
/// (keyword, name, type, '=' or initializer) fails to parse. When no type is
/// given, a default-constructed ValueType is stored.
std::unique_ptr<VariableDeclarationExpression> parseDeclaration()
{
    if (lexer.getCurToken() != tok_var)
        return parseError<VariableDeclarationExpression>("var", "to begin declaration");
    auto loc = lexer.getLastLocation();
    lexer.getNextToken(); // eat var

    if (lexer.getCurToken() != tok_identifier)
        return parseError<VariableDeclarationExpression>("identifier",
                                                         "after 'var' declaration");
    std::string id(lexer.getId());
    lexer.getNextToken(); // eat id

    std::unique_ptr<ValueType> type; // Type is optional, it can be inferred
    if (lexer.getCurToken() == '<')
    {
        type = parseType();
        if (!type)
            return nullptr;
    }
    else
    {
        type = std::make_unique<ValueType>();
    }

    // Check for '=' explicitly, the way the other parse functions check a
    // token before consuming it. A missing initializer is then reported as an
    // ordinary parse error and does not depend on how lexer.consume() reacts
    // to an unexpected token.
    if (lexer.getCurToken() != '=')
        return parseError<VariableDeclarationExpression>("=",
                                                         "to initialize variable declaration");
    lexer.consume(Token('='));

    // Never wrap a failed (null) initializer into the declaration node.
    auto expr = parseExpression();
    if (!expr)
        return nullptr;
    return std::make_unique<VariableDeclarationExpression>(std::move(loc), std::move(id),
                                                           std::move(*type), std::move(expr));
}
/// Parse a block: a list of expression separated by semicolons and wrapped in
/// curly braces.
///
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
std::unique_ptr<ExpressionList> parseBlock()
{
if (lexer.getCurToken() != '{')
return parseError<ExpressionList>("{", "to begin block");
lexer.consume(Token('{'));
auto exprList = std::make_unique<ExpressionList>();
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof)
{
if (lexer.getCurToken() == tok_var)
{
// Variable declaration
auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
}
else if (lexer.getCurToken() == tok_return)
{
// Return statement
auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
}
else
{
// General expression
auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
}
// Ensure that elements are separated by a semicolon.
if (lexer.getCurToken() != ';')
return parseError<ExpressionList>(";", "after expression");
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
}
if (lexer.getCurToken() != '}')
return parseError<ExpressionList>("}", "to close block");
lexer.consume(Token('}'));
return exprList;
}
/// Parse a function prototype: the `def` keyword, the function name and a
/// parenthesised, comma-separated list of parameter names (possibly empty).
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
///
/// Returns the prototype, or nullptr after reporting a parse error.
std::unique_ptr<FunctionPrototype> parsePrototype()
{
    auto loc = lexer.getLastLocation();

    if (lexer.getCurToken() != tok_def)
        return parseError<FunctionPrototype>("def", "in prototype");
    lexer.consume(tok_def);

    if (lexer.getCurToken() != tok_identifier)
        return parseError<FunctionPrototype>("function name", "in prototype");
    std::string fnName(lexer.getId());
    lexer.consume(tok_identifier);

    if (lexer.getCurToken() != '(')
        return parseError<FunctionPrototype>("(", "in prototype");
    lexer.consume(Token('('));

    std::vector<std::unique_ptr<VariableExpression>> args;
    if (lexer.getCurToken() != ')')
    {
        // The first parameter has to be an identifier as well. Without this
        // check, input such as `def f(1)` would reach getId() and
        // consume(tok_identifier) while the current token is not an identifier
        // at all.
        if (lexer.getCurToken() != tok_identifier)
            return parseError<FunctionPrototype>("identifier or )",
                                                 "in function parameter list");
        do
        {
            std::string name(lexer.getId());
            // Separate name: the prototype's own `loc` above must stay intact.
            auto argLoc = lexer.getLastLocation();
            lexer.consume(tok_identifier);
            args.push_back(std::make_unique<VariableExpression>(std::move(argLoc), name));

            if (lexer.getCurToken() != ',')
                break;
            lexer.consume(Token(','));
            if (lexer.getCurToken() != tok_identifier)
                return parseError<FunctionPrototype>(
                    "identifier", "after ',' in function parameter list");
        } while (true);
    }

    if (lexer.getCurToken() != ')')
        return parseError<FunctionPrototype>(")", "to end function prototype");

    // success.
    lexer.consume(Token(')'));
    return std::make_unique<FunctionPrototype>(std::move(loc), fnName,
                                               std::move(args));
}
/// Parse a function definition, we expect a prototype initiated with the
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
std::unique_ptr<Function> parseDefinition()
{
auto proto = parsePrototype();
if (!proto)
return nullptr;
if (auto block = parseBlock())
return std::make_unique<Function>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
///
/// Returns 20 for '+' and '-', 40 for '*', and -1 for every token that is not
/// a binary operator.
int getTokPrecedence()
{
    // Only tokens in the 7-bit ASCII range can be operators. A plain range
    // check is used in place of isascii(), which is a POSIX extension and not
    // part of standard C++; the accepted range (0..127) is the same.
    const int tok = lexer.getCurToken();
    if (tok < 0 || tok > 127)
        return -1;

    // 1 is lowest precedence.
    switch (static_cast<char>(tok))
    {
    case '-':
    case '+':
        return 20;
    case '*':
        return 40;
    default:
        return -1;
    }
}
/// Helper function to signal errors while parsing, it takes an argument
/// indicating the expected token and another argument giving more context.
/// Location is retrieved from the lexer to enrich the error message.
template <typename R, typename T, typename U = const char*>
std::unique_ptr<R> parseError(T&& expected, U&& context = "")
{
auto curToken = lexer.getCurToken();
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
<< lexer.getLastLocation().col << "): expected '" << expected
<< "' " << context << " but has Token " << curToken;
if (isprint(curToken))
llvm::errs() << " '" << (char)curToken << "'";
llvm::errs() << "\n";
return nullptr;
}
};
}
#endif // TOY_PARSER_H