From cbf31fa513c3a5da76c36d09d4671b2d35265bb7 Mon Sep 17 00:00:00 2001
From: Kai-Philipp Nosper
Date: Fri, 4 Feb 2022 17:06:38 +0100
Subject: [PATCH] Implement simple AST optimizer

- Precalculate operations only containing literals
- Leave operations with undefined results (overflow, division by zero) for the runtime
---
 src/astoptimizer.rs | 131 ++++++++++++++++++++++++++++++++++++++++++++
 src/interpreter.rs  |  34 ++++++++++-----
 src/lib.rs          |  21 ++++-----
 src/main.rs         |  24 ++++++++--
 4 files changed, 186 insertions(+), 24 deletions(-)
 create mode 100644 src/astoptimizer.rs

diff --git a/src/astoptimizer.rs b/src/astoptimizer.rs
new file mode 100644
index 0000000..39775a5
--- /dev/null
+++ b/src/astoptimizer.rs
@@ -0,0 +1,131 @@
+use crate::ast::{Ast, BlockScope, Expression, If, Loop, Statement, BinOpType, UnOpType};
+
+/// An optimization pass that consumes an [`Ast`] and returns the rewritten tree.
+pub trait AstOptimizer {
+    /// Returns the optimized form of `ast`.
+    fn optimize(ast: Ast) -> Ast;
+}
+
+/// Optimizer that precalculates operations whose operands are all integer literals.
+pub struct SimpleAstOptimizer;
+
+impl AstOptimizer for SimpleAstOptimizer {
+    fn optimize(mut ast: Ast) -> Ast {
+        Self::optimize_block(&mut ast.main);
+        ast
+    }
+}
+
+impl SimpleAstOptimizer {
+    /// Optimizes every expression held by the statements of `block`, recursing into the
+    /// bodies of loops and conditionals.
+    fn optimize_block(block: &mut BlockScope) {
+        for stmt in block {
+            match stmt {
+                Statement::Expr(expr) => Self::optimize_expr(expr),
+                Statement::Loop(Loop {
+                    condition,
+                    advancement,
+                    body,
+                }) => {
+                    Self::optimize_expr(condition);
+                    if let Some(advancement) = advancement {
+                        Self::optimize_expr(advancement)
+                    }
+                    Self::optimize_block(body);
+                }
+                Statement::If(If {
+                    condition,
+                    body_true,
+                    body_false,
+                }) => {
+                    Self::optimize_expr(condition);
+                    Self::optimize_block(body_true);
+                    Self::optimize_block(body_false);
+                }
+                Statement::Print(expr) => Self::optimize_expr(expr),
+            }
+        }
+    }
+
+    /// Optimizes `expr` in place. Operands are optimized first so that nested literal-only
+    /// operations collapse from the inside out.
+    fn optimize_expr(expr: &mut Expression) {
+        // Precalculate operations that consist only of literals. No need to do this at
+        // runtime, as all parts of the calculation are known at compile time / parse time.
+        let folded = match expr {
+            Expression::I64(_) | Expression::String(_) | Expression::Var(_, _) => None,
+            Expression::BinOp(bo, lhs, rhs) => {
+                Self::optimize_expr(lhs);
+                Self::optimize_expr(rhs);
+
+                match (lhs.as_mut(), rhs.as_mut()) {
+                    (Expression::I64(lhs), Expression::I64(rhs)) => {
+                        Self::fold_binop(bo, *lhs, *rhs)
+                    }
+                    _ => None,
+                }
+            }
+            Expression::UnOp(uo, operand) => {
+                Self::optimize_expr(operand);
+
+                match operand.as_mut() {
+                    Expression::I64(val) => Self::fold_unop(uo, *val),
+                    _ => None,
+                }
+            }
+        };
+
+        // Only replace the node when folding produced a value; otherwise the operation is
+        // kept untouched so that it is evaluated (or fails) at runtime, not in the optimizer.
+        if let Some(value) = folded {
+            *expr = Expression::I64(value);
+        }
+    }
+
+    /// Evaluates the binary operation `bo` on two integer literals.
+    ///
+    /// Returns `None` when the result is not well defined: arithmetic overflow, division or
+    /// remainder by zero, or a shift amount outside `0..64`. Computing those with the plain
+    /// operators would panic inside the optimizer (always for a zero divisor, and for the
+    /// other cases when overflow checks are enabled), even if the code is never executed.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `bo` is `Declare` or `Assign`, which are treated as unreachable here.
+    fn fold_binop(bo: &BinOpType, lhs: i64, rhs: i64) -> Option<i64> {
+        match bo {
+            BinOpType::Add => lhs.checked_add(rhs),
+            BinOpType::Mul => lhs.checked_mul(rhs),
+            BinOpType::Sub => lhs.checked_sub(rhs),
+            BinOpType::Div => lhs.checked_div(rhs),
+            BinOpType::Mod => lhs.checked_rem(rhs),
+            BinOpType::BOr => Some(lhs | rhs),
+            BinOpType::BAnd => Some(lhs & rhs),
+            BinOpType::BXor => Some(lhs ^ rhs),
+            BinOpType::LAnd => Some(i64::from(lhs != 0 && rhs != 0)),
+            BinOpType::LOr => Some(i64::from(lhs != 0 || rhs != 0)),
+            BinOpType::Shr => (0..64).contains(&rhs).then(|| lhs >> rhs),
+            BinOpType::Shl => (0..64).contains(&rhs).then(|| lhs << rhs),
+            BinOpType::EquEqu => Some(i64::from(lhs == rhs)),
+            BinOpType::NotEqu => Some(i64::from(lhs != rhs)),
+            BinOpType::Less => Some(i64::from(lhs < rhs)),
+            BinOpType::LessEqu => Some(i64::from(lhs <= rhs)),
+            BinOpType::Greater => Some(i64::from(lhs > rhs)),
+            BinOpType::GreaterEqu => Some(i64::from(lhs >= rhs)),
+
+            BinOpType::Declare | BinOpType::Assign => unreachable!(),
+        }
+    }
+
+    /// Evaluates the unary operation `uo` on an integer literal.
+    ///
+    /// Returns `None` when the result is not well defined, i.e. negating `i64::MIN`.
+    fn fold_unop(uo: &UnOpType, val: i64) -> Option<i64> {
+        match uo {
+            UnOpType::Negate => val.checked_neg(),
+            UnOpType::BNot => Some(!val),
+            UnOpType::LNot => Some(i64::from(val == 0)),
+        }
+    }
+}
diff --git a/src/interpreter.rs b/src/interpreter.rs
index d4b6886..0096a14 100644
--- a/src/interpreter.rs
+++ 
b/src/interpreter.rs @@ -1,7 +1,7 @@ use crate::{ - ast::{BlockScope, BinOpType, Expression, If, Statement, UnOpType}, + ast::{BlockScope, BinOpType, Expression, If, Statement, UnOpType, Ast}, lexer::lex, - parser::parse, stringstore::{Sid, StringStore}, + parser::parse, stringstore::{Sid, StringStore}, astoptimizer::{SimpleAstOptimizer, AstOptimizer}, }; #[derive(Debug, PartialEq, Eq, Clone)] @@ -12,8 +12,15 @@ pub enum Value { #[derive(Default)] pub struct Interpreter { - capture_output: bool, + pub optimize_ast: bool, + + pub print_tokens: bool, + pub print_ast: bool, + + pub capture_output: bool, output: Vec, + + // Variable table stores the runtime values of variables vartable: Vec, @@ -22,11 +29,7 @@ pub struct Interpreter { impl Interpreter { pub fn new() -> Self { - Self::default() - } - - pub fn set_capture_output(&mut self, enabled: bool) { - self.capture_output = enabled; + Self { optimize_ast: true, ..Self::default() } } pub fn output(&self) -> &[Value] { @@ -41,14 +44,23 @@ impl Interpreter { self.vartable.get_mut(idx) } - pub fn run_str(&mut self, code: &str, print_tokens: bool, print_ast: bool) { + pub fn run_str(&mut self, code: &str) { let tokens = lex(code).unwrap(); - if print_tokens { + if self.print_tokens { println!("Tokens: {:?}", tokens); } let ast = parse(tokens); - if print_ast { + + self.run_ast(ast); + } + + pub fn run_ast(&mut self, mut ast: Ast) { + if self.optimize_ast { + ast = SimpleAstOptimizer::optimize(ast); + } + + if self.print_ast { println!("{:#?}", ast.main); } diff --git a/src/lib.rs b/src/lib.rs index 6344b98..2afaf75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod lexer; pub mod parser; pub mod token; pub mod stringstore; +pub mod astoptimizer; #[cfg(test)] mod tests { @@ -16,11 +17,11 @@ mod tests { let correct_result = 233168; let mut interpreter = Interpreter::new(); - interpreter.set_capture_output(true); + interpreter.capture_output = true; let code = 
read_to_string(format!("examples/{filename}")).unwrap(); - interpreter.run_str(&code, false, false); + interpreter.run_str(&code); let expected_output = [Value::I64(correct_result)]; @@ -33,11 +34,11 @@ mod tests { let correct_result = 4613732; let mut interpreter = Interpreter::new(); - interpreter.set_capture_output(true); + interpreter.capture_output = true; let code = read_to_string(format!("examples/{filename}")).unwrap(); - interpreter.run_str(&code, false, false); + interpreter.run_str(&code); let expected_output = [Value::I64(correct_result)]; @@ -50,11 +51,11 @@ mod tests { let correct_result = 6857; let mut interpreter = Interpreter::new(); - interpreter.set_capture_output(true); + interpreter.capture_output = true; let code = read_to_string(format!("examples/{filename}")).unwrap(); - interpreter.run_str(&code, false, false); + interpreter.run_str(&code); let expected_output = [Value::I64(correct_result)]; @@ -67,11 +68,11 @@ mod tests { let correct_result = 906609; let mut interpreter = Interpreter::new(); - interpreter.set_capture_output(true); + interpreter.capture_output = true; let code = read_to_string(format!("examples/{filename}")).unwrap(); - interpreter.run_str(&code, false, false); + interpreter.run_str(&code); let expected_output = [Value::I64(correct_result)]; @@ -84,11 +85,11 @@ mod tests { let correct_result = 232792560; let mut interpreter = Interpreter::new(); - interpreter.set_capture_output(true); + interpreter.capture_output = true; let code = read_to_string(format!("examples/{filename}")).unwrap(); - interpreter.run_str(&code, false, false); + interpreter.run_str(&code); let expected_output = [Value::I64(correct_result)]; diff --git a/src/main.rs b/src/main.rs index 0efe66a..3210ace 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::{ env::args, fs, - io::{stdin, stdout, Write}, + io::{stdin, stdout, Write}, process::exit, }; use nek_lang::interpreter::Interpreter; @@ -10,6 +10,7 @@ use 
nek_lang::interpreter::Interpreter; struct CliConfig { print_tokens: bool, print_ast: bool, + no_optimizations: bool, interactive: bool, file: Option, } @@ -22,7 +23,9 @@ fn main() { match arg.as_str() { "--token" | "-t" => conf.print_tokens = true, "--ast" | "-a" => conf.print_ast = true, + "--no-opt" | "-n" => conf.no_optimizations = true, "--interactive" | "-i" => conf.interactive = true, + "--help" | "-h" => print_help(), file if conf.file.is_none() => conf.file = Some(file.to_string()), _ => panic!("Invalid argument: '{}'", arg), } @@ -30,9 +33,13 @@ fn main() { let mut interpreter = Interpreter::new(); + interpreter.print_tokens = conf.print_tokens; + interpreter.print_ast = conf.print_ast; + interpreter.optimize_ast = !conf.no_optimizations; + if let Some(file) = &conf.file { let code = fs::read_to_string(file).expect(&format!("File not found: '{}'", file)); - interpreter.run_str(&code, conf.print_tokens, conf.print_ast); + interpreter.run_str(&code); } if conf.interactive || conf.file.is_none() { @@ -49,7 +56,18 @@ fn main() { break; } - interpreter.run_str(&code, conf.print_tokens, conf.print_ast); + interpreter.run_str(&code); } } } + +fn print_help() { + println!("Usage nek-lang [FLAGS] [FILE]"); + println!("FLAGS: "); + println!("-t, --token Print the lexed tokens"); + println!("-a, --ast Print the abstract syntax tree"); + println!("-n, --no-opt Disable the AST optimizations"); + println!("-i, --interactive Interactive mode"); + println!("-h, --help Show this help screen"); + exit(0); +}