commit 1a64b78ef80f8b1096050c832311b1708b0cc686 Author: jackfiled Date: Thu May 29 15:53:58 2025 +0800 feat: toy tutorial chapter 1. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..518ed79 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +cmake-*/ +build/ +.idea/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..529cabd --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,43 @@ +cmake_minimum_required(VERSION 3.20) +project(hello_mlir) + +set(CMAKE_CXX_STANDARD 17) + +find_package(MLIR REQUIRED CONFIG) +find_package(LLVM REQUIRED CONFIG) +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) +set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) + +message(${MLIR_INCLUDE_DIRS}) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + +include_directories(include) + +add_library(SyntaxNode SHARED lib/SyntaxNode.cpp include/SyntaxNode.h include/Parser.h include/Lexer.h) + +target_link_libraries(SyntaxNode + PRIVATE + MLIRSupport) + +add_executable(hello-mlir main.cpp) + +target_link_libraries(hello-mlir + PRIVATE + SyntaxNode + LLVMSupport + LLVMCore) \ No newline at end of file diff --git a/examples/transpose.hello b/examples/transpose.hello new file mode 100644 index 0000000..755ca53 --- /dev/null +++ b/examples/transpose.hello @@ -0,0 +1,13 @@ +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. 
+ var a = [[1, 2, 3], [4, 5, 6]]; + + # b is identical to a, the literal tensor is implicitly reshaped: defining new + # variables is the way to reshape tensors (element count must match). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # transpose() and print() are the only builtin, the following will transpose + # a and b and perform an element-wise multiplication before printing the result. + print(transpose(a) * transpose(b)); +} \ No newline at end of file diff --git a/include/Lexer.h b/include/Lexer.h new file mode 100644 index 0000000..bce6a16 --- /dev/null +++ b/include/Lexer.h @@ -0,0 +1,250 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +#include "SyntaxNode.h" + +namespace hello +{ + // List of Token returned by the lexer. + enum Token : int + { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, + }; + + /// The Lexer is an abstract base class providing all the facilities that the + /// Parser expects. It goes through the stream one token at a time and keeps + /// track of the location in the file for debugging purposes. + /// It relies on a subclass to provide a `readNextLine()` method. 
The subclass
+    /// can proceed by reading the next line from the standard input or from a
+    /// memory mapped file.
+    class Lexer
+    {
+    public:
+        /// Create a lexer for the given filename. The filename is kept only for
+        /// debugging purposes (attaching a location to a Token).
+        explicit Lexer(std::string filename)
+            : lastLocation(
+                {std::make_shared<std::string>(std::move(filename)), 0, 0})
+        {
+        }
+
+        virtual ~Lexer() = default;
+
+        /// Look at the current token in the stream.
+        Token getCurToken() const { return curTok; }
+
+        /// Move to the next token in the stream and return it.
+        Token getNextToken() { return curTok = getTok(); }
+
+        /// Move to the next token in the stream, asserting on the current token
+        /// matching the expectation. The check is an assert only, so it is the
+        /// caller's job to test the token first when it comes from user input.
+        void consume(Token tok)
+        {
+            assert(tok == curTok && "consume Token mismatch expectation");
+            getNextToken();
+        }
+
+        /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+        llvm::StringRef getId()
+        {
+            assert(curTok == tok_identifier);
+            return identifierStr;
+        }
+
+        /// Return the current number (prereq: getCurToken() == tok_number)
+        double getValue()
+        {
+            assert(curTok == tok_number);
+            return numVal;
+        }
+
+        /// Return the location for the beginning of the current token.
+        Location getLastLocation() { return lastLocation; }
+
+        // Return the current line in the file.
+        int getLine() const { return curLineNum; }
+
+        // Return the current column in the file.
+        int getCol() const { return curCol; }
+
+    private:
+        /// Delegate to a derived class fetching the next line. Returns an empty
+        /// string to signal end of file (EOF). Lines are expected to always finish
+        /// with "\n"
+        virtual llvm::StringRef readNextLine() = 0;
+
+        /// Return the next character from the stream. This manages the buffer for the
+        /// current line and request the next line buffer to the derived class as
+        /// needed.
+        ///
+        /// The result is either EOF or a value in [0, 255]: the byte is widened
+        /// through `unsigned char` so that bytes >= 0x80 never turn into negative
+        /// ints. Negative values other than EOF are undefined behaviour for the
+        /// <cctype> classifiers used in getTok(), and 0xFF would otherwise compare
+        /// equal to EOF.
+        int getNextChar()
+        {
+            // The current line buffer should not be empty unless it is the end of file.
+            if (curLineBuffer.empty())
+                return EOF;
+            ++curCol;
+            const char nextchar = curLineBuffer.front();
+            curLineBuffer = curLineBuffer.drop_front();
+            if (curLineBuffer.empty())
+                curLineBuffer = readNextLine();
+            if (nextchar == '\n')
+            {
+                ++curLineNum;
+                curCol = 0;
+            }
+            return static_cast<unsigned char>(nextchar);
+        }
+
+        /// Return the next token from standard input.
+        ///
+        /// `lastChar` always holds one character of look-ahead; on return it is the
+        /// first character after the token that was just recognised.
+        Token getTok()
+        {
+            // Skip any whitespace.
+            while (isspace(lastChar))
+                lastChar = Token(getNextChar());
+
+            // Save the current location before reading the token characters.
+            lastLocation.line = curLineNum;
+            lastLocation.col = curCol;
+
+            // Identifier: [a-zA-Z][a-zA-Z0-9_]*
+            if (isalpha(lastChar))
+            {
+                identifierStr = static_cast<char>(lastChar);
+                while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
+                    identifierStr += static_cast<char>(lastChar);
+
+                if (identifierStr == "return")
+                    return tok_return;
+                if (identifierStr == "def")
+                    return tok_def;
+                if (identifierStr == "var")
+                    return tok_var;
+                return tok_identifier;
+            }
+
+            // Number: [0-9.]+
+            if (isdigit(lastChar) || lastChar == '.')
+            {
+                std::string numStr;
+                do
+                {
+                    numStr += static_cast<char>(lastChar);
+                    lastChar = Token(getNextChar());
+                }
+                while (isdigit(lastChar) || lastChar == '.');
+
+                // strtod stops at the first character it cannot use, so text such
+                // as "1.2.3" is lexed as one token whose value is 1.2.
+                numVal = strtod(numStr.c_str(), nullptr);
+                return tok_number;
+            }
+
+            if (lastChar == '#')
+            {
+                // Comment until end of line.
+                do
+                {
+                    lastChar = Token(getNextChar());
+                }
+                while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
+
+                // The line terminator is whitespace, the recursive call skips it.
+                if (lastChar != EOF)
+                    return getTok();
+            }
+
+            // Check for end of file. Don't eat the EOF.
+            if (lastChar == EOF)
+                return tok_eof;
+
+            // Otherwise, just return the character as its ascii value.
+            Token thisChar = Token(lastChar);
+            lastChar = Token(getNextChar());
+            return thisChar;
+        }
+
+        /// The last token read from the input.
+        Token curTok = tok_eof;
+
+        /// Location for `curTok`.
+        Location lastLocation;
+
+        /// If the current Token is an identifier, this string contains the value.
+ std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; + }; + + /// A lexer implementation operating on a buffer in memory. + class LexerBuffer final : public Lexer + { + public: + LexerBuffer(const char* begin, const char* end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) + { + } + + private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override + { + auto* begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + + return llvm::StringRef{begin, static_cast(current - begin)}; + } + + const char *current, *end; + }; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/include/Parser.h b/include/Parser.h new file mode 100644 index 0000000..4e3cde8 --- /dev/null +++ b/include/Parser.h @@ -0,0 +1,539 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. 
It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include + +#include "SyntaxNode.h" +#include "Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace hello +{ + /// This is a simple recursive parser for the Toy language. It produces a well + /// formed AST from a stream of Token supplied by the Lexer. No semantic checks + /// or symbol resolution is performed. For example, variables are referenced by + /// string and the code could reference an undeclared variable and the parsing + /// succeeds. + class Parser + { + public: + /// Create a Parser for the supplied lexer. + explicit Parser(Lexer& lexer) : lexer(lexer) + { + } + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() + { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) + { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + + private: + Lexer& lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() + { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + if (lexer.getCurToken() != ';') + { + std::unique_ptr expr = parseExpression(); + if (!expr) + { + return nullptr; + } + + return std::make_unique(std::move(loc), std::move(expr)); + } + return std::make_unique(std::move(loc)); + } + + /// Parse a literal number. 
+        /// numberexpr ::= number
+        ///
+        /// Precondition: the current token is tok_number; it is consumed.
+        std::unique_ptr<ExpressionNodeBase> parseNumberExpr()
+        {
+            auto loc = lexer.getLastLocation();
+            auto result =
+                std::make_unique<NumberExpression>(std::move(loc), lexer.getValue());
+            lexer.consume(tok_number);
+            return result;
+        }
+
+        /// Parse a literal array expression.
+        /// tensorLiteral ::= [ literalList ] | number
+        /// literalList ::= tensorLiteral | tensorLiteral, literalList
+        ///
+        /// Precondition: the current token is '['. Returns nullptr after printing
+        /// a diagnostic when the literal is malformed or its nesting is not
+        /// uniform.
+        std::unique_ptr<ExpressionNodeBase> parseTensorLiteralExpr()
+        {
+            auto loc = lexer.getLastLocation();
+            lexer.consume(Token('['));
+
+            // Hold the list of values at this nesting level.
+            std::vector<std::unique_ptr<ExpressionNodeBase>> values;
+            // Hold the dimensions for all the nesting inside this level.
+            // NOTE(review): element type assumed to be int64_t to match the
+            // dimension list stored by LiteralExpression - confirm.
+            std::vector<int64_t> dims;
+            do
+            {
+                // We can have either another nested array or a number literal.
+                if (lexer.getCurToken() == '[')
+                {
+                    values.push_back(parseTensorLiteralExpr());
+                    if (!values.back())
+                        return nullptr; // parse error in the nested array.
+                }
+                else
+                {
+                    if (lexer.getCurToken() != tok_number)
+                        return parseError<ExpressionNodeBase>("number or [",
+                                                              "in literal expression");
+                    values.push_back(parseNumberExpr());
+                }
+
+                // End of this list on ']'
+                if (lexer.getCurToken() == ']')
+                    break;
+
+                // Elements are separated by a comma.
+                if (lexer.getCurToken() != ',')
+                    return parseError<ExpressionNodeBase>("] or ,", "in literal expression");
+
+                lexer.getNextToken(); // eat ,
+            }
+            while (true);
+            if (values.empty())
+                return parseError<ExpressionNodeBase>("at least one value",
+                                                      "to fill literal expression");
+            lexer.getNextToken(); // eat ]
+
+            /// Fill in the dimensions now. First the current nesting level:
+            dims.push_back(values.size());
+
+            /// If there is any nested array, process all of them and ensure that
+            /// dimensions are uniform.
+            if (llvm::any_of(values, [](std::unique_ptr<ExpressionNodeBase>& expr)
+            {
+                return llvm::isa<LiteralExpression>(expr.get());
+            }))
+            {
+                auto* firstLiteral = llvm::dyn_cast<LiteralExpression>(values.front().get());
+                if (!firstLiteral)
+                    return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
+                                                          "inside literal expression");
+
+                // Append the nested dimensions to the current level
+                auto firstDims = firstLiteral->getDimensions();
+                dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+                // Sanity check that shape is uniform across all elements of the list.
+                for (auto& expr : values)
+                {
+                    // dyn_cast, not cast: a plain number mixed in with nested
+                    // arrays (e.g. [[1, 2], 3]) must be reported as a parse error
+                    // instead of tripping the assertion inside llvm::cast.
+                    auto* exprLiteral = llvm::dyn_cast<LiteralExpression>(expr.get());
+                    if (!exprLiteral)
+                        return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
+                                                              "inside literal expression");
+                    if (exprLiteral->getDimensions() != firstDims)
+                        return parseError<ExpressionNodeBase>("uniform well-nested dimensions",
+                                                              "inside literal expression");
+                }
+            }
+            return std::make_unique<LiteralExpression>(std::move(loc), std::move(values),
+                                                       std::move(dims));
+        }
+
+        /// parenexpr ::= '(' expression ')'
+        std::unique_ptr<ExpressionNodeBase> parseParenExpr()
+        {
+            lexer.getNextToken(); // eat (.
+            auto v = parseExpression();
+            if (!v)
+                return nullptr;
+
+            if (lexer.getCurToken() != ')')
+                return parseError<ExpressionNodeBase>(")", "to close expression with parentheses");
+            lexer.consume(Token(')'));
+            return v;
+        }
+
+        /// identifierexpr
+        ///   ::= identifier
+        ///   ::= identifier '(' expression ')'
+        ///
+        /// A call to `print` with exactly one argument becomes a PrintExpression;
+        /// every other call becomes a CallExpression.
+        std::unique_ptr<ExpressionNodeBase> parseIdentifierExpr()
+        {
+            std::string name(lexer.getId());
+
+            auto loc = lexer.getLastLocation();
+            lexer.getNextToken(); // eat identifier.
+
+            if (lexer.getCurToken() != '(') // Simple variable ref.
+                return std::make_unique<VariableExpression>(std::move(loc), name);
+
+            // This is a function call.
+            lexer.consume(Token('('));
+            std::vector<std::unique_ptr<ExpressionNodeBase>> args;
+            if (lexer.getCurToken() != ')')
+            {
+                while (true)
+                {
+                    if (auto arg = parseExpression())
+                        args.push_back(std::move(arg));
+                    else
+                        return nullptr;
+
+                    if (lexer.getCurToken() == ')')
+                        break;
+
+                    if (lexer.getCurToken() != ',')
+                        return parseError<ExpressionNodeBase>(", or )", "in argument list");
+                    lexer.getNextToken();
+                }
+            }
+            lexer.consume(Token(')'));
+
+            // It can be a builtin call to print
+            if (name == "print")
+            {
+                if (args.size() != 1)
+                    return parseError<ExpressionNodeBase>("exactly one expression",
+                                                          "as argument to print()");
+
+                return std::make_unique<PrintExpression>(std::move(loc), std::move(args[0]));
+            }
+
+            // Call to a user-defined function
+            return std::make_unique<CallExpression>(std::move(loc), name, std::move(args));
+        }
+
+        /// primary
+        ///   ::= identifierexpr
+        ///   ::= numberexpr
+        ///   ::= parenexpr
+        ///   ::= tensorliteral
+        ///
+        /// Returns nullptr without a diagnostic on ';' and '}', and with an
+        /// "unknown token" diagnostic on any other unexpected token.
+        std::unique_ptr<ExpressionNodeBase> parsePrimary()
+        {
+            switch (lexer.getCurToken())
+            {
+            default:
+                llvm::errs() << "unknown token '" << lexer.getCurToken()
+                    << "' when expecting an expression\n";
+                return nullptr;
+            case tok_identifier:
+                return parseIdentifierExpr();
+            case tok_number:
+                return parseNumberExpr();
+            case '(':
+                return parseParenExpr();
+            case '[':
+                return parseTensorLiteralExpr();
+            case ';':
+                return nullptr;
+            case '}':
+                return nullptr;
+            }
+        }
+
+        /// Recursively parse the right hand side of a binary expression, the ExprPrec
+        /// argument indicates the precedence of the current binary operator.
+        ///
+        /// binoprhs ::= ('+' primary)*
+        std::unique_ptr<ExpressionNodeBase> parseBinOpRHS(int exprPrec,
+                                                          std::unique_ptr<ExpressionNodeBase> lhs)
+        {
+            // If this is a binop, find its precedence.
+            while (true)
+            {
+                int tokPrec = getTokPrecedence();
+
+                // If this is a binop that binds at least as tightly as the current binop,
+                // consume it, otherwise we are done.
+                if (tokPrec < exprPrec)
+                    return lhs;
+
+                // Okay, we know this is a binop.
+                int binOp = lexer.getCurToken();
+                lexer.consume(Token(binOp));
+                // Taken after the operator was consumed, so this is the location
+                // of the token that follows the operator.
+                auto loc = lexer.getLastLocation();
+
+                // Parse the primary expression after the binary operator.
+                auto rhs = parsePrimary();
+                if (!rhs)
+                    return parseError<ExpressionNodeBase>("expression", "to complete binary operator");
+
+                // If BinOp binds less tightly with rhs than the operator after rhs, let
+                // the pending operator take rhs as its lhs.
+                int nextPrec = getTokPrecedence();
+                if (tokPrec < nextPrec)
+                {
+                    rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
+                    if (!rhs)
+                        return nullptr;
+                }
+
+                // Merge lhs/RHS.
+                lhs = std::make_unique<BinaryExpression>(std::move(loc), binOp,
+                                                         std::move(lhs), std::move(rhs));
+            }
+        }
+
+        /// expression::= primary binop rhs
+        std::unique_ptr<ExpressionNodeBase> parseExpression()
+        {
+            auto lhs = parsePrimary();
+            if (!lhs)
+                return nullptr;
+
+            return parseBinOpRHS(0, std::move(lhs));
+        }
+
+        /// type ::= < shape_list >
+        /// shape_list ::= num | num , shape_list
+        std::unique_ptr<ValueType> parseType()
+        {
+            if (lexer.getCurToken() != '<')
+                return parseError<ValueType>("<", "to begin type");
+            lexer.getNextToken(); // eat <
+
+            auto type = std::make_unique<ValueType>();
+
+            while (lexer.getCurToken() == tok_number)
+            {
+                type->shape.push_back(lexer.getValue());
+                lexer.getNextToken();
+                if (lexer.getCurToken() == ',')
+                    lexer.getNextToken();
+            }
+
+            if (lexer.getCurToken() != '>')
+                return parseError<ValueType>(">", "to end type");
+            lexer.getNextToken(); // eat >
+            return type;
+        }
+
+        /// Parse a variable declaration, it starts with a `var` keyword followed by
+        /// and identifier and an optional type (shape specification) before the
+        /// initializer.
+        /// decl ::= var identifier [ type ] = expr
+        ///
+        /// Returns nullptr after printing a diagnostic when the identifier, the
+        /// '=' or the initializer expression is missing.
+        std::unique_ptr<VariableDeclarationExpression> parseDeclaration()
+        {
+            if (lexer.getCurToken() != tok_var)
+                return parseError<VariableDeclarationExpression>("var", "to begin declaration");
+            auto loc = lexer.getLastLocation();
+            lexer.getNextToken(); // eat var
+
+            if (lexer.getCurToken() != tok_identifier)
+                return parseError<VariableDeclarationExpression>("identifier",
+                                                                 "after 'var' declaration");
+            std::string id(lexer.getId());
+            lexer.getNextToken(); // eat id
+
+            std::unique_ptr<ValueType> type; // Type is optional, it can be inferred
+            if (lexer.getCurToken() == '<')
+            {
+                type = parseType();
+                if (!type)
+                    return nullptr;
+            }
+
+            if (!type)
+                type = std::make_unique<ValueType>();
+
+            // consume() only asserts, so a missing '=' has to be diagnosed here.
+            if (lexer.getCurToken() != '=')
+                return parseError<VariableDeclarationExpression>("=", "to initialize variable");
+            lexer.consume(Token('='));
+
+            // Never store a null initializer in the AST: consumers such as the
+            // dumper dispatch on it without a null check.
+            auto expr = parseExpression();
+            if (!expr)
+                return parseError<VariableDeclarationExpression>("expression",
+                                                                 "to initialize variable");
+
+            return std::make_unique<VariableDeclarationExpression>(std::move(loc), std::move(id),
+                                                                   std::move(*type), std::move(expr));
+        }
+
+        /// Parse a block: a list of expression separated by semicolons and wrapped in
+        /// curly braces.
+        ///
+        /// block ::= { expression_list }
+        /// expression_list ::= block_expr ; expression_list
+        /// block_expr ::= decl | "return" | expr
+        std::unique_ptr<ExpressionList> parseBlock()
+        {
+            if (lexer.getCurToken() != '{')
+                return parseError<ExpressionList>("{", "to begin block");
+            lexer.consume(Token('{'));
+
+            auto exprList = std::make_unique<ExpressionList>();
+
+            // Ignore empty expressions: swallow sequences of semicolons.
+            while (lexer.getCurToken() == ';')
+                lexer.consume(Token(';'));
+
+            while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof)
+            {
+                if (lexer.getCurToken() == tok_var)
+                {
+                    // Variable declaration
+                    auto varDecl = parseDeclaration();
+                    if (!varDecl)
+                        return nullptr;
+                    exprList->push_back(std::move(varDecl));
+                }
+                else if (lexer.getCurToken() == tok_return)
+                {
+                    // Return statement
+                    auto ret = parseReturn();
+                    if (!ret)
+                        return nullptr;
+                    exprList->push_back(std::move(ret));
+                }
+                else
+                {
+                    // General expression
+                    auto expr = parseExpression();
+                    if (!expr)
+                        return nullptr;
+                    exprList->push_back(std::move(expr));
+                }
+                // Ensure that elements are separated by a semicolon.
+                if (lexer.getCurToken() != ';')
+                    return parseError<ExpressionList>(";", "after expression");
+
+                // Ignore empty expressions: swallow sequences of semicolons.
+                while (lexer.getCurToken() == ';')
+                    lexer.consume(Token(';'));
+            }
+
+            if (lexer.getCurToken() != '}')
+                return parseError<ExpressionList>("}", "to close block");
+
+            lexer.consume(Token('}'));
+            return exprList;
+        }
+
+        /// prototype ::= def id '(' decl_list ')'
+        /// decl_list ::= identifier | identifier, decl_list
+        std::unique_ptr<FunctionPrototype> parsePrototype()
+        {
+            auto loc = lexer.getLastLocation();
+
+            if (lexer.getCurToken() != tok_def)
+                return parseError<FunctionPrototype>("def", "in prototype");
+            lexer.consume(tok_def);
+
+            if (lexer.getCurToken() != tok_identifier)
+                return parseError<FunctionPrototype>("function name", "in prototype");
+
+            std::string fnName(lexer.getId());
+            lexer.consume(tok_identifier);
+
+            if (lexer.getCurToken() != '(')
+                return parseError<FunctionPrototype>("(", "in prototype");
+            lexer.consume(Token('('));
+
+            std::vector<std::unique_ptr<VariableExpression>> args;
+            if (lexer.getCurToken() != ')')
+            {
+                do
+                {
+                    // getId() only asserts on the token kind, so reject anything
+                    // that is not an identifier (e.g. `def f(1)`) explicitly.
+                    if (lexer.getCurToken() != tok_identifier)
+                        return parseError<FunctionPrototype>(
+                            "identifier", "in function parameter list");
+
+                    std::string name(lexer.getId());
+                    auto paramLoc = lexer.getLastLocation();
+                    lexer.consume(tok_identifier);
+                    auto decl = std::make_unique<VariableExpression>(std::move(paramLoc), name);
+                    args.push_back(std::move(decl));
+                    if (lexer.getCurToken() != ',')
+                        break;
+                    lexer.consume(Token(','));
+                    if (lexer.getCurToken() != tok_identifier)
+                        return parseError<FunctionPrototype>(
+                            "identifier", "after ',' in function parameter list");
+                }
+                while (true);
+            }
+            if (lexer.getCurToken() != ')')
+                return parseError<FunctionPrototype>(")", "to end function prototype");
+
+            // success.
+            lexer.consume(Token(')'));
+            return std::make_unique<FunctionPrototype>(std::move(loc), fnName,
+                                                       std::move(args));
+        }
+
+        /// Parse a function definition, we expect a prototype initiated with the
+        /// `def` keyword, followed by a block containing a list of expressions.
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() + { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() + { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) + { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T&& expected, U&& context = "") + { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } + }; +} + +#endif // TOY_PARSER_H diff --git a/include/SyntaxNode.h b/include/SyntaxNode.h new file mode 100644 index 0000000..794a526 --- /dev/null +++ b/include/SyntaxNode.h @@ -0,0 +1,358 @@ +// +// Created by ricardo on 28/05/25. 
+// + +#ifndef SYNTEXNODE_H +#define SYNTEXNODE_H +#include +#include +#include + +namespace hello +{ + struct ValueType + { + std::vector shape; + }; + + struct Location + { + std::shared_ptr file; + int line; + int col; + }; + + class ExpressionNodeBase + { + public: + enum ExpressionNodeKind + { + VariableDeclaration, + Return, + Number, + Literal, + Variable, + BinaryOperation, + Call, + Print + }; + + ExpressionNodeBase(ExpressionNodeKind kind, Location location) : kind(kind), location(location) + { + } + + virtual ~ExpressionNodeBase() = default; + + const Location& getLocation() const + { + return location; + } + + ExpressionNodeKind getKind() const + { + return kind; + } + + private: + const ExpressionNodeKind kind; + Location location; + }; + + using ExpressionList = std::vector>; + + class NumberExpression : public ExpressionNodeBase + { + double value; + + public: + NumberExpression(Location location, const double value) : ExpressionNodeBase(Number, + std::move(location)), value(value) + { + } + + double getValue() const + { + return value; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Number; + } + }; + + class LiteralExpression : public ExpressionNodeBase + { + std::vector> values; + std::vector dimensions; + + public: + LiteralExpression(Location location, std::vector> values, + std::vector dimensions) + : ExpressionNodeBase(Literal, std::move(location)), values(std::move(values)), + dimensions(std::move(dimensions)) + { + } + + llvm::ArrayRef> getValues() + { + return values; + } + + llvm::ArrayRef getDimensions() + { + return dimensions; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Literal; + } + }; + + class VariableExpression : public ExpressionNodeBase + { + std::string name; + + public: + VariableExpression(Location location, const llvm::StringRef name) + : ExpressionNodeBase(Variable, std::move(location)), name(name) + { + } + + llvm::StringRef getName() + { + 
return name; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Variable; + } + }; + + class VariableDeclarationExpression : public ExpressionNodeBase + { + std::string name; + ValueType variableType; + std::unique_ptr initialValue; + + public: + VariableDeclarationExpression(Location location, const llvm::StringRef name, ValueType type, + std::unique_ptr initialValue): + ExpressionNodeBase(VariableDeclaration, std::move(location)), name(name), variableType(std::move(type)), + initialValue(std::move(initialValue)) + { + } + + llvm::StringRef getName() + { + return name; + } + + ExpressionNodeBase* getInitialValue() const + { + return initialValue.get(); + } + + const ValueType& getType() + { + return variableType; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == VariableDeclaration; + } + }; + + class ReturnExpression : public ExpressionNodeBase + { + std::optional> returnExpression; + + public: + explicit ReturnExpression(Location location) : ExpressionNodeBase(Return, std::move(location)) + { + } + + ReturnExpression(Location location, std::unique_ptr expression) : ExpressionNodeBase( + Return, + std::move(location)), returnExpression(std::make_optional(std::move(expression))) + { + } + + std::optional getReturnExpression() const + { + if (returnExpression.has_value()) + { + return returnExpression->get(); + } + + return std::nullopt; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Return; + } + }; + + class BinaryExpression : public ExpressionNodeBase + { + char binaryOperator; + std::unique_ptr left, right; + + public: + BinaryExpression(Location location, char op, std::unique_ptr left, + std::unique_ptr right) + : ExpressionNodeBase(BinaryOperation, std::move(location)), binaryOperator(op), left(std::move(left)), + right(std::move(right)) + { + } + + char getOperator() const + { + return binaryOperator; + } + + ExpressionNodeBase* getLeft() const + { + 
return left.get(); + } + + ExpressionNodeBase* getRight() const + { + return right.get(); + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == BinaryOperation; + } + }; + + class CallExpression : public ExpressionNodeBase + { + std::string name; + std::vector> arguments; + + public: + CallExpression(Location location, const std::string& callee, + std::vector> arguments) + : ExpressionNodeBase(Call, std::move(location)), name(std::move(callee)), arguments(std::move(arguments)) + { + } + + llvm::StringRef getName() const + { + return name; + } + + llvm::ArrayRef> getArguments() const + { + return arguments; + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Call; + } + }; + + class PrintExpression : public ExpressionNodeBase + { + std::unique_ptr argument; + + public: + PrintExpression(Location location, std::unique_ptr argument) + : ExpressionNodeBase(Print, std::move(location)), argument(std::move(argument)) + { + } + + ExpressionNodeBase* getArgument() const + { + return argument.get(); + } + + static bool classof(const ExpressionNodeBase* c) + { + return c->getKind() == Print; + } + }; + + class FunctionPrototype + { + Location location; + std::string name; + std::vector> parameters; + + public: + FunctionPrototype(Location location, const std::string& name, + std::vector> parameters): location(std::move(location)), + name(name), parameters(std::move(parameters)) + { + } + + const Location& getLocation() const + { + return location; + } + + llvm::StringRef getName() const + { + return name; + } + + llvm::ArrayRef> getParameters() const + { + return parameters; + } + }; + + class Function + { + std::unique_ptr prototype; + std::unique_ptr body; + + public: + Function(std::unique_ptr prototype, std::unique_ptr body) : + prototype(std::move(prototype)), + body(std::move(body)) + { + } + + FunctionPrototype* getPrototype() const + { + return prototype.get(); + } + + ExpressionList* getBody() const + 
{ + return body.get(); + } + }; + + class Module + { + std::vector functions; + + public: + explicit Module(std::vector functions) : functions(std::move(functions)) + { + } + + auto begin() + { + return functions.begin(); + } + + auto end() + { + return functions.end(); + } + + void dump(); + }; +} + +#endif //SYNTEXNODE_H diff --git a/lib/SyntaxNode.cpp b/lib/SyntaxNode.cpp new file mode 100644 index 0000000..eae3462 --- /dev/null +++ b/lib/SyntaxNode.cpp @@ -0,0 +1,245 @@ +// +// Created by ricardo on 28/05/25. +// +#include +#include +#include "SyntaxNode.h" + +using namespace hello; + +namespace +{ + struct Indent + { + int& level; + + explicit Indent(int& level) : level(level) + { + ++level; + } + + ~Indent() + { + --level; + } + }; + + class SyntaxNodeDumper + { + public: + void dump(Module* module); + + private: + static void dump(const ValueType& type); + void dump(VariableDeclarationExpression* varDecl); + void dump(ExpressionNodeBase* expr); + void dump(const ExpressionList* exprList); + void dump(NumberExpression* num); + void dump(LiteralExpression* node); + void dump(VariableExpression* node); + void dump(const ReturnExpression* node); + void dump(BinaryExpression* node); + void dump(CallExpression* node); + void dump(PrintExpression* node); + void dump(FunctionPrototype* node); + void dump(const Function* node); + + void indent() const + { + for (int i = 0; i < currentIndent; i++) + { + llvm::outs() << " "; + } + } + + int currentIndent = 0; + }; +} + +template +static std::string formatLocation(T* node) +{ + const auto& location = node->getLocation(); + return (llvm::Twine("@") + *location.file + ":" + llvm::Twine(location.line) + ":" + llvm::Twine(location.col)). 
+ str(); +} + +#define INDENT() Indent level(currentIndent); indent(); + + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void SyntaxNodeDumper::dump(ExpressionNodeBase* expr) +{ + llvm::TypeSwitch(expr) + .Case( + [&](auto* node) { this->dump(node); }) + .Default([&](ExpressionNodeBase*) + { + // No match, fallback to a generic message + INDENT(); + llvm::outs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void SyntaxNodeDumper::dump(VariableDeclarationExpression* varDecl) +{ + INDENT(); + llvm::outs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::outs() << " " << formatLocation(varDecl) << "\n"; + dump(varDecl->getInitialValue()); +} + +/// A "block", or a list of expression +void SyntaxNodeDumper::dump(const ExpressionList* exprList) +{ + INDENT(); + llvm::outs() << "Block {\n"; + for (auto& expr : *exprList) + dump(expr.get()); + indent(); + llvm::outs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void SyntaxNodeDumper::dump(NumberExpression* num) +{ + INDENT(); + llvm::outs() << num->getValue() << " " << formatLocation(num) << "\n"; +} + +/// Helper to print recursively a literal. 
+/// This handles nested arrays like:
+///    [ [ 1, 2 ], [ 3, 4 ] ]
+/// The dimensions are spelled out at every level, schematically:
+///    <2, 2>[ <2>[ 1, 2], <2>[ 3, 4]]
+/// (numbers appear however operator<< renders getValue()).
+void printLitHelper(ExpressionNodeBase* litOrNum)
+{
+    // Inside a literal expression we can have either a number or another literal
+    if (const auto* num = llvm::dyn_cast<NumberExpression>(litOrNum))
+    {
+        llvm::outs() << num->getValue();
+        return;
+    }
+    // Not a number, so it must be a literal; llvm::cast asserts on anything else.
+    auto* literal = llvm::cast<LiteralExpression>(litOrNum);
+
+    // Print the dimension for this literal first
+    llvm::outs() << "<";
+    interleaveComma(literal->getDimensions(), llvm::outs());
+    llvm::outs() << ">";
+
+    // Now print the content, recursing on every element of the list
+    llvm::outs() << "[ ";
+    interleaveComma(literal->getValues(), llvm::outs(),
+                    [&](auto& elt) { printLitHelper(elt.get()); });
+    llvm::outs() << "]";
+}
+
+/// Print a literal on one line: the "Literal: " tag, the nested content
+/// rendered by printLitHelper, then the location.
+void SyntaxNodeDumper::dump(LiteralExpression* node)
+{
+    INDENT();
+    llvm::outs() << "Literal: ";
+    printLitHelper(node);
+    llvm::outs() << " " << formatLocation(node) << "\n";
+}
+
+/// Print a variable reference: its name followed by its location.
+void SyntaxNodeDumper::dump(VariableExpression* node)
+{
+    INDENT();
+    llvm::outs() << "var: " << node->getName() << " " << formatLocation(node) << "\n";
+}
+
+/// Print a return statement: the "Return" line, then either the returned
+/// expression as a nested entry or a nested "(void)" line when there is none.
+void SyntaxNodeDumper::dump(const ReturnExpression* node)
+{
+    INDENT();
+    llvm::outs() << "Return\n";
+    if (node->getReturnExpression().has_value())
+    {
+        dump(*node->getReturnExpression());
+    }
+    else
+    {
+        // The placeholder sits one level below the "Return" line.
+        INDENT();
+        llvm::outs() << "(void)\n";
+    }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void SyntaxNodeDumper::dump(BinaryExpression* node)
+{
+    INDENT();
+    auto& out = llvm::outs();
+    out << "BinOp: " << node->getOperator();
+    out << " " << formatLocation(node) << "\n";
+
+    // Operands become nested entries, left-hand side first.
+    dump(node->getLeft());
+    dump(node->getRight());
+}
+
+/// Dump a call: the callee name and location on the opening line, every
+/// argument as a nested entry, and a closing "]" line.
+void SyntaxNodeDumper::dump(CallExpression* node)
+{
+    INDENT();
+    auto& out = llvm::outs();
+    out << "Call '" << node->getName() << "' [ " << formatLocation(node) << "\n";
+
+    for (auto& argument : node->getArguments())
+    {
+        dump(argument.get());
+    }
+
+    indent();
+    out << "]\n";
+}
+
+/// Dump the builtin print: an opening line with the location, the single
+/// argument as a nested entry, and a closing "]" line.
+void SyntaxNodeDumper::dump(PrintExpression* node)
+{
+    INDENT();
+    auto& out = llvm::outs();
+    out << "Print [ " << formatLocation(node) << "\n";
+
+    dump(node->getArgument());
+
+    indent();
+    out << "]\n";
+}
+
+/// Dump a type: only its shape, comma-separated between '<' and '>'.
+/// Writes no indentation and no trailing newline.
+void SyntaxNodeDumper::dump(const ValueType& type)
+{
+    auto& out = llvm::outs();
+    out << "<";
+    interleaveComma(type.shape, out);
+    out << ">";
+}
+
+/// Dump a function prototype: the function name and location on one line,
+/// then the comma-separated parameter names on the next.
+void SyntaxNodeDumper::dump(FunctionPrototype* node)
+{
+    INDENT();
+    auto& out = llvm::outs();
+    out << "Proto '" << node->getName() << "' " << formatLocation(node) << "\n";
+
+    indent();
+    out << "Params: [";
+    llvm::interleaveComma(node->getParameters(), out,
+                          [&out](auto& parameter) { out << parameter->getName(); });
+    out << "]\n";
+}
+
+/// Dump a function: a header line, then the prototype and the body as
+/// nested entries.
+void SyntaxNodeDumper::dump(const Function* node)
+{
+    INDENT();
+    llvm::outs() << "Function \n";
+
+    auto* prototype = node->getPrototype();
+    dump(prototype);
+
+    auto* body = node->getBody();
+    dump(body);
+}
+
+/// Print a module, actually loop over the functions and print them in sequence.
+void SyntaxNodeDumper::dump(Module* module)
+{
+    INDENT();
+    llvm::outs() << "Module:\n";
+    for (auto& f : *module)
+        dump(&f);
+}
+
+namespace hello
+{
+    /// Prints this module through a freshly constructed SyntaxNodeDumper.
+    void Module::dump()
+    {
+        SyntaxNodeDumper().dump(this);
+    }
+}
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..0d99340
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,56 @@
+#include "Lexer.h"
+#include "Parser.h"
+
+// NOTE(review): the three header names below and the template arguments in
+// this file were lost when the patch was extracted. The template arguments
+// are restored from how each object is used; the header names are presumed
+// from the llvm::cl, llvm::MemoryBuffer and llvm::errs uses -- confirm
+// against the repository.
+#include <llvm/Support/CommandLine.h>
+#include <llvm/Support/MemoryBuffer.h>
+#include <llvm/Support/raw_ostream.h>
+
+/// Positional input path; defaults to "-".
+// NOTE(review): the empty description may be another extraction loss -- confirm.
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc(""),
+                                                llvm::cl::init("-"),
+                                                llvm::cl::value_desc("filename"));
+
+namespace
+{
+    /// Output kinds selectable with -emit; None is the value-initialized default.
+    enum Action { None, DumpSyntaxNode };
+}
+
+/// The -emit option; "ast" selects DumpSyntaxNode.
+static llvm::cl::opt<Action> emitAction("emit", llvm::cl::desc("Select the kind of output desired"),
+                                        llvm::cl::values(clEnumValN(DumpSyntaxNode, "ast", "Dump syntax node")));
+
+/// Loads `filename` through llvm::MemoryBuffer::getFileOrSTDIN and parses it.
+/// Returns nullptr, after reporting on llvm::errs(), when the input cannot be
+/// opened; otherwise returns whatever Parser::parseModule() produces.
+std::unique_ptr<hello::Module> parseInputFile(llvm::StringRef filename)
+{
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+        llvm::MemoryBuffer::getFileOrSTDIN(filename);
+    if (std::error_code ec = fileOrErr.getError())
+    {
+        llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+        return nullptr;
+    }
+    // `buffer` refers to memory owned by fileOrErr, which stays alive until
+    // this function returns, i.e. for the whole parse.
+    auto buffer = fileOrErr.get()->getBuffer();
+    hello::LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename));
+    hello::Parser parser(lexer);
+    return parser.parseModule();
+}
+
+/// Entry point. Returns 0 after dumping the syntax tree for -emit=ast, and 1
+/// when no module could be parsed or no recognised action was selected.
+int main(int argc, char** argv)
+{
+    llvm::cl::ParseCommandLineOptions(argc, argv, "Hello MLIR Compiler\n");
+
+    auto module = parseInputFile(inputFilename);
+
+    if (!module)
+    {
+        return 1;
+    }
+
+    switch (emitAction)
+    {
+    case DumpSyntaxNode:
+        module->dump();
+        return 0;
+    default:
+        llvm::errs() << "Unrecognized action\n";
+        return 1;
+    }
+}