//===- Parser.h - Toy Language Parser -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. // //===----------------------------------------------------------------------===// #ifndef TOY_PARSER_H #define TOY_PARSER_H #include #include "SyntaxNode.h" #include "Lexer.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" #include #include #include namespace hello { /// This is a simple recursive parser for the Toy language. It produces a well /// formed AST from a stream of Token supplied by the Lexer. No semantic checks /// or symbol resolution is performed. For example, variables are referenced by /// string and the code could reference an undeclared variable and the parsing /// succeeds. class Parser { public: /// Create a Parser for the supplied lexer. explicit Parser(Lexer& lexer) : lexer(lexer) { } /// Parse a full Module. A module is a list of function definitions. std::unique_ptr parseModule() { lexer.getNextToken(); // prime the lexer // Parse functions one at a time and accumulate in this vector. std::vector functions; while (auto f = parseDefinition()) { functions.push_back(std::move(*f)); if (lexer.getCurToken() == tok_eof) break; } // If we didn't reach EOF, there was an error during parsing if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); return std::make_unique(std::move(functions)); } private: Lexer& lexer; /// Parse a return statement. /// return :== return ; | return expr ; std::unique_ptr parseReturn() { auto loc = lexer.getLastLocation(); lexer.consume(tok_return); // return takes an optional argument if (lexer.getCurToken() != ';') { std::unique_ptr expr = parseExpression(); if (!expr) { return nullptr; } return std::make_unique(std::move(loc), std::move(expr)); } return std::make_unique(std::move(loc)); } /// Parse a literal number. /// numberexpr ::= number std::unique_ptr parseNumberExpr() { auto loc = lexer.getLastLocation(); auto result = std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(result); } /// Parse a literal array expression. /// tensorLiteral ::= [ literalList ] | number /// literalList ::= tensorLiteral | tensorLiteral, literalList std::unique_ptr parseTensorLiteralExpr() { auto loc = lexer.getLastLocation(); lexer.consume(Token('[')); // Hold the list of values at this nesting level. std::vector> values; // Hold the dimensions for all the nesting inside this level. std::vector dims; do { // We can have either another nested array or a number literal. if (lexer.getCurToken() == '[') { values.push_back(parseTensorLiteralExpr()); if (!values.back()) return nullptr; // parse error in the nested array. } else { if (lexer.getCurToken() != tok_number) return parseError(" or [", "in literal expression"); values.push_back(parseNumberExpr()); } // End of this list on ']' if (lexer.getCurToken() == ']') break; // Elements are separated by a comma. if (lexer.getCurToken() != ',') return parseError("] or ,", "in literal expression"); lexer.getNextToken(); // eat , } while (true); if (values.empty()) return parseError("", "to fill literal expression"); lexer.getNextToken(); // eat ] /// Fill in the dimensions now. First the current nesting level: dims.push_back(values.size()); /// If there is any nested array, process all of them and ensure that /// dimensions are uniform. if (llvm::any_of(values, [](std::unique_ptr& expr) { return llvm::isa(expr.get()); })) { auto* firstLiteral = llvm::dyn_cast(values.front().get()); if (!firstLiteral) return parseError("uniform well-nested dimensions", "inside literal expression"); // Append the nested dimensions to the current level auto firstDims = firstLiteral->getDimensions(); dims.insert(dims.end(), firstDims.begin(), firstDims.end()); // Sanity check that shape is uniform across all elements of the list. for (auto& expr : values) { auto* exprLiteral = llvm::cast(expr.get()); if (!exprLiteral) return parseError("uniform well-nested dimensions", "inside literal expression"); if (exprLiteral->getDimensions() != firstDims) return parseError("uniform well-nested dimensions", "inside literal expression"); } } return std::make_unique(std::move(loc), std::move(values), std::move(dims)); } /// parenexpr ::= '(' expression ')' std::unique_ptr parseParenExpr() { lexer.getNextToken(); // eat (. auto v = parseExpression(); if (!v) return nullptr; if (lexer.getCurToken() != ')') return parseError(")", "to close expression with parentheses"); lexer.consume(Token(')')); return v; } /// identifierexpr /// ::= identifier /// ::= identifier '(' expression ')' std::unique_ptr parseIdentifierExpr() { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. return std::make_unique(std::move(loc), name); // This is a function call. lexer.consume(Token('(')); std::vector> args; if (lexer.getCurToken() != ')') { while (true) { if (auto arg = parseExpression()) args.push_back(std::move(arg)); else return nullptr; if (lexer.getCurToken() == ')') break; if (lexer.getCurToken() != ',') return parseError(", or )", "in argument list"); lexer.getNextToken(); } } lexer.consume(Token(')')); // It can be a builtin call to print if (name == "print") { if (args.size() != 1) return parseError("", "as argument to print()"); return std::make_unique(std::move(loc), std::move(args[0])); } // Call to a user-defined function return std::make_unique(std::move(loc), name, std::move(args)); } /// primary /// ::= identifierexpr /// ::= numberexpr /// ::= parenexpr /// ::= tensorliteral std::unique_ptr parsePrimary() { switch (lexer.getCurToken()) { default: llvm::errs() << "unknown token '" << lexer.getCurToken() << "' when expecting an expression\n"; return nullptr; case tok_identifier: return parseIdentifierExpr(); case tok_number: return parseNumberExpr(); case '(': return parseParenExpr(); case '[': return parseTensorLiteralExpr(); case ';': return nullptr; case '}': return nullptr; } } /// Recursively parse the right hand side of a binary expression, the ExprPrec /// argument indicates the precedence of the current binary operator. /// /// binoprhs ::= ('+' primary)* std::unique_ptr parseBinOpRHS(int exprPrec, std::unique_ptr lhs) { // If this is a binop, find its precedence. while (true) { int tokPrec = getTokPrecedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. if (tokPrec < exprPrec) return lhs; // Okay, we know this is a binop. int binOp = lexer.getCurToken(); lexer.consume(Token(binOp)); auto loc = lexer.getLastLocation(); // Parse the primary expression after the binary operator. auto rhs = parsePrimary(); if (!rhs) return parseError("expression", "to complete binary operator"); // If BinOp binds less tightly with rhs than the operator after rhs, let // the pending operator take rhs as its lhs. int nextPrec = getTokPrecedence(); if (tokPrec < nextPrec) { rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); if (!rhs) return nullptr; } // Merge lhs/RHS. lhs = std::make_unique(std::move(loc), binOp, std::move(lhs), std::move(rhs)); } } /// expression::= primary binop rhs std::unique_ptr parseExpression() { auto lhs = parsePrimary(); if (!lhs) return nullptr; return parseBinOpRHS(0, std::move(lhs)); } /// type ::= < shape_list > /// shape_list ::= num | num , shape_list std::unique_ptr parseType() { if (lexer.getCurToken() != '<') return parseError("<", "to begin type"); lexer.getNextToken(); // eat < auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); lexer.getNextToken(); if (lexer.getCurToken() == ',') lexer.getNextToken(); } if (lexer.getCurToken() != '>') return parseError(">", "to end type"); lexer.getNextToken(); // eat > return type; } /// Parse a variable declaration, it starts with a `var` keyword followed by /// and identifier and an optional type (shape specification) before the /// initializer. /// decl ::= var identifier [ type ] = expr std::unique_ptr parseDeclaration() { if (lexer.getCurToken() != tok_var) return parseError("var", "to begin declaration"); auto loc = lexer.getLastLocation(); lexer.getNextToken(); // eat var if (lexer.getCurToken() != tok_identifier) return parseError("identified", "after 'var' declaration"); std::string id(lexer.getId()); lexer.getNextToken(); // eat id std::unique_ptr type; // Type is optional, it can be inferred if (lexer.getCurToken() == '<') { type = parseType(); if (!type) return nullptr; } if (!type) type = std::make_unique(); lexer.consume(Token('=')); auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in /// curly braces. /// /// block ::= { expression_list } /// expression_list ::= block_expr ; expression_list /// block_expr ::= decl | "return" | expr std::unique_ptr parseBlock() { if (lexer.getCurToken() != '{') return parseError("{", "to begin block"); lexer.consume(Token('{')); auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') lexer.consume(Token(';')); while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { if (lexer.getCurToken() == tok_var) { // Variable declaration auto varDecl = parseDeclaration(); if (!varDecl) return nullptr; exprList->push_back(std::move(varDecl)); } else if (lexer.getCurToken() == tok_return) { // Return statement auto ret = parseReturn(); if (!ret) return nullptr; exprList->push_back(std::move(ret)); } else { // General expression auto expr = parseExpression(); if (!expr) return nullptr; exprList->push_back(std::move(expr)); } // Ensure that elements are separated by a semicolon. if (lexer.getCurToken() != ';') return parseError(";", "after expression"); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') lexer.consume(Token(';')); } if (lexer.getCurToken() != '}') return parseError("}", "to close block"); lexer.consume(Token('}')); return exprList; } /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list std::unique_ptr parsePrototype() { auto loc = lexer.getLastLocation(); if (lexer.getCurToken() != tok_def) return parseError("def", "in prototype"); lexer.consume(tok_def); if (lexer.getCurToken() != tok_identifier) return parseError("function name", "in prototype"); std::string fnName(lexer.getId()); lexer.consume(tok_identifier); if (lexer.getCurToken() != '(') return parseError("(", "in prototype"); lexer.consume(Token('(')); std::vector> args; if (lexer.getCurToken() != ')') { do { std::string name(lexer.getId()); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; lexer.consume(Token(',')); if (lexer.getCurToken() != tok_identifier) return parseError( "identifier", "after ',' in function parameter list"); } while (true); } if (lexer.getCurToken() != ')') return parseError(")", "to end function prototype"); // success. lexer.consume(Token(')')); return std::make_unique(std::move(loc), fnName, std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the /// `def` keyword, followed by a block containing a list of expressions. /// /// definition ::= prototype block std::unique_ptr parseDefinition() { auto proto = parsePrototype(); if (!proto) return nullptr; if (auto block = parseBlock()) return std::make_unique(std::move(proto), std::move(block)); return nullptr; } /// Get the precedence of the pending binary operator token. int getTokPrecedence() { if (!isascii(lexer.getCurToken())) return -1; // 1 is lowest precedence. switch (static_cast(lexer.getCurToken())) { case '-': return 20; case '+': return 20; case '*': return 40; default: return -1; } } /// Helper function to signal errors while parsing, it takes an argument /// indicating the expected token and another argument giving more context. /// Location is retrieved from the lexer to enrich the error message. template std::unique_ptr parseError(T&& expected, U&& context = "") { auto curToken = lexer.getCurToken(); llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; return nullptr; } }; } #endif // TOY_PARSER_H