diff --git a/evilang_lib/src/lib/interpreter/environment/mod.rs b/evilang_lib/src/lib/interpreter/environment/mod.rs index 63dcf20..2264113 100644 --- a/evilang_lib/src/lib/interpreter/environment/mod.rs +++ b/evilang_lib/src/lib/interpreter/environment/mod.rs @@ -142,7 +142,7 @@ impl Environment { self.setup_and_eval_statements(&parse(input)?) } - pub fn setup_scope_for_statement( + pub fn hoist_scope_for_statement( &mut self, statement: &Statement ) -> ResultWithError { @@ -153,23 +153,54 @@ impl Environment { } } Statement::FunctionDeclarationStatement(fdecl) => { - self.declare((&fdecl.name).into(), Function::new_closure(self, fdecl.clone()).into())?; + self.hoist_identifier((&fdecl.name).into())?; + } + Statement::ClassDeclarationStatement(cdecl) => { + self.hoist_identifier((&cdecl.name).into())?; + } + _ => {} + } + return Ok(StatementMetaGeneration::NormalGeneration); + } + + pub fn init_scope_for_statement( + &mut self, + statement: &Statement + ) -> ResultWithError { + match statement { + Statement::FunctionDeclarationStatement(fdecl) => { + self.assign_locally( + (&fdecl.name).into(), + Function::new_closure(self, fdecl.clone()).into() + ); } Statement::ClassDeclarationStatement(cdecl) => { let class = RuntimeObject::new_class_decl(self, cdecl)?; - self.declare((&cdecl.name).into(), class.into())?; + self.assign_locally((&cdecl.name).into(), class.into()); } _ => {} } return Ok(StatementMetaGeneration::NormalGeneration); } + pub fn setup_scope_for_statement( + &mut self, + statement: &Statement + ) -> ResultWithError { + self.hoist_scope_for_statement(statement)?; + self.init_scope_for_statement(statement)?; + return Ok(StatementMetaGeneration::NormalGeneration); + } + pub fn setup_scope( &mut self, statements: &StatementList ) -> ResultWithError { for statement in statements.iter() { - self.setup_scope_for_statement(statement)?; + self.hoist_scope_for_statement(statement)?; + } + for statement in statements.iter() { + self.init_scope_for_statement(statement)?; } return Ok(StatementMetaGeneration::NormalGeneration); } diff --git a/evilang_lib/src/lib/interpreter/runtime_values/functions/mod.rs b/evilang_lib/src/lib/interpreter/runtime_values/functions/mod.rs index da76203..d9a93ff 100644 --- a/evilang_lib/src/lib/interpreter/runtime_values/functions/mod.rs +++ b/evilang_lib/src/lib/interpreter/runtime_values/functions/mod.rs @@ -1,4 +1,5 @@ use std::fmt::{ Display, Formatter }; +use std::ops::Deref; use gc::{ Finalize, Trace }; @@ -12,6 +13,10 @@ use crate::interpreter::runtime_values::functions::types::{ FunctionParameters, FunctionReturnValue, }; +use crate::interpreter::variables_containers::{ VariableScope, VariablesMap }; +use crate::interpreter::variables_containers::map::IVariablesMapConstMembers; +use crate::interpreter::variables_containers::scope::IGenericVariablesScope; +use crate::semantic::captured_variables::analyze_captured_variables; use crate::types::cell_ref::{ gc_clone, GcPtr }; pub mod closure; @@ -48,7 +53,23 @@ impl Display for Function { impl Function { pub fn new_closure(env: &Environment, decl: FunctionDeclaration) -> GcPtrToFunction { - let closure = Closure::new(decl, gc_clone(&env.scope)); + let captured_vars = analyze_captured_variables(&decl); + let global_scope_ptr = gc_clone(&env.global_scope.borrow().scope); + let mut captured_map = VariablesMap::new(); + + for var in captured_vars { + if let Some(val) = env.scope.get_actual(var.as_str().into()) { + // Check if this variable comes from the global scope + let var_scope = env.scope.resolve_variable_scope(var.as_str().into()); + if !GcPtr::ptr_eq(&var_scope, &global_scope_ptr.variables) { + // It's not global, so we capture it + captured_map.variables.insert(var.into(), val.into_owned()); + } + } + } + + let closure_scope = VariableScope::new_gc_from_map(captured_map, Some(global_scope_ptr)); + let closure = Closure::new(decl, closure_scope); let function_closure = Function::Closure(closure); return GcPtr::new(function_closure); } diff --git a/evilang_lib/src/lib/lib.rs b/evilang_lib/src/lib/lib.rs index 9b16593..ef5ed77 100644 --- a/evilang_lib/src/lib/lib.rs +++ b/evilang_lib/src/lib/lib.rs @@ -7,3 +7,4 @@ pub mod interpreter; pub mod parser; pub mod tokenizer; pub mod types; +pub mod semantic; diff --git a/evilang_lib/src/lib/semantic/captured_variables.rs b/evilang_lib/src/lib/semantic/captured_variables.rs new file mode 100644 index 0000000..5bb5d4d --- /dev/null +++ b/evilang_lib/src/lib/semantic/captured_variables.rs @@ -0,0 +1,239 @@ +use std::collections::HashSet; + +use crate::ast::expression::{ + Expression, + IdentifierT, + MemberIndexer, +}; +use crate::ast::statement::{ Statement, StatementList }; +use crate::ast::structs::{ + FunctionDeclaration, +}; + +pub fn analyze_captured_variables( + function_decl: &FunctionDeclaration +) -> HashSet { + let mut analyzer = CapturedVariablesAnalyzer::new(function_decl); + analyzer.analyze(); + return analyzer.captured_variables; +} + +struct CapturedVariablesAnalyzer<'a> { + root_function: &'a FunctionDeclaration, + captured_variables: HashSet, + scope_stack: Vec>, +} + +impl<'a> CapturedVariablesAnalyzer<'a> { + fn new(root_function: &'a FunctionDeclaration) -> Self { + let mut root_scope = HashSet::new(); + for param in &root_function.parameters { + root_scope.insert(param.identifier.clone()); + } + Self { + root_function, + captured_variables: HashSet::new(), + scope_stack: vec![root_scope], + } + } + + fn analyze(&mut self) { + self.collect_hoisted_declarations(&self.root_function.body); + self.visit_statement(&self.root_function.body); + } + + fn push_scope(&mut self) { + self.scope_stack.push(HashSet::new()); + } + + fn pop_scope(&mut self) { + self.scope_stack.pop(); + } + + fn declare_variable(&mut self, name: IdentifierT) { + if let Some(top) = self.scope_stack.last_mut() { + top.insert(name); + } + } + + fn is_variable_declared(&self, name: &IdentifierT) -> bool { + for scope in self.scope_stack.iter().rev() { + if scope.contains(name) { + return true; + } + } + return false; + } + + fn collect_hoisted_declarations(&mut self, stmt: &Statement) { + match stmt { + Statement::VariableDeclarations(decls) => { + for decl in decls { + self.declare_variable(decl.identifier.clone()); + } + } + Statement::FunctionDeclarationStatement(fdecl) => { + self.declare_variable(fdecl.name.clone()); + } + Statement::ClassDeclarationStatement(cdecl) => { + self.declare_variable(cdecl.name.clone()); + } + Statement::BlockStatement(_stmts) => { + } + _ => {} + } + } + + fn collect_hoisted_declarations_list(&mut self, list: &StatementList) { + for stmt in list { + self.collect_hoisted_declarations(stmt); + } + } + + fn visit_statement_list(&mut self, list: &StatementList) { + for stmt in list { + self.visit_statement(stmt); + } + } + + fn visit_statement(&mut self, stmt: &Statement) { + match stmt { + Statement::BlockStatement(list) => { + self.push_scope(); + self.collect_hoisted_declarations_list(list); + self.visit_statement_list(list); + self.pop_scope(); + } + Statement::ExpressionStatement(expr) => self.visit_expression(expr), + Statement::ReturnStatement(opt_expr) => { + if let Some(expr) = opt_expr { + self.visit_expression(expr); + } + } + Statement::VariableDeclarations(decls) => { + for decl in decls { + if let Some(init) = &decl.initializer { + self.visit_expression(init); + } + } + } + Statement::IfStatement { condition, if_branch, else_branch } => { + self.visit_expression(condition); + self.visit_statement(if_branch); + if let Some(else_b) = else_branch { + self.visit_statement(else_b); + } + } + Statement::WhileLoop { condition, body } => { + self.visit_expression(condition); + self.visit_statement(body); + } + Statement::DoWhileLoop { condition, body } => { + self.visit_expression(condition); + self.visit_statement(body); + } + Statement::ForLoop { initialization, condition, increment, body } => { + self.push_scope(); + self.collect_hoisted_declarations(initialization); + + self.visit_statement(initialization); + self.visit_expression(condition); + self.visit_statement(increment); + self.visit_statement(body); + + self.pop_scope(); + } + Statement::FunctionDeclarationStatement(func_decl) => { + self.push_scope(); // Inner function scope + for param in &func_decl.parameters { + self.declare_variable(param.identifier.clone()); + } + self.collect_hoisted_declarations(&func_decl.body); + self.visit_statement(&func_decl.body); + self.pop_scope(); + } + Statement::ClassDeclarationStatement(class_decl) => { + for method in &class_decl.methods { + self.push_scope(); + for param in &method.parameters { + self.declare_variable(param.identifier.clone()); + } + self.collect_hoisted_declarations(&method.body); + self.visit_statement(&method.body); + self.pop_scope(); + } + } + Statement::NamespaceStatement { body, .. } => { + self.push_scope(); + self.collect_hoisted_declarations_list(body); + self.visit_statement_list(body); + self.pop_scope(); + } + Statement::ImportStatement { file_name, .. } => { + self.visit_expression(file_name); + } + Statement::EmptyStatement => {} + Statement::BreakStatement(_) => {} + Statement::ContinueStatement(_) => {} + } + } + + fn visit_expression(&mut self, expr: &Expression) { + match expr { + Expression::Identifier(name) => { + if !self.is_variable_declared(name) { + self.captured_variables.insert(name.clone()); + } + } + Expression::BinaryExpression { left, right, .. } | + Expression::AssignmentExpression { left, right, .. } => { + self.visit_expression(left); + self.visit_expression(right); + } + Expression::UnaryExpression { argument, .. } => { + self.visit_expression(argument); + } + Expression::ParenthesizedExpression(expr) => self.visit_expression(expr), + Expression::FunctionCall(call) | Expression::NewObjectExpression(call) => { + self.visit_expression(&call.callee); + for arg in &call.arguments { + self.visit_expression(arg); + } + } + Expression::MemberAccess { object, member } => { + self.visit_expression(object); + if let MemberIndexer::SubscriptExpression(expr) = member { + self.visit_expression(expr); + } + } + Expression::FunctionExpression(func_decl) => { + self.push_scope(); + for param in &func_decl.parameters { + self.declare_variable(param.identifier.clone()); + } + self.collect_hoisted_declarations(&func_decl.body); + self.visit_statement(&func_decl.body); + self.pop_scope(); + } + Expression::ClassDeclarationExpression(class_decl) => { + for method in &class_decl.methods { + self.push_scope(); + for param in &method.parameters { + self.declare_variable(param.identifier.clone()); + } + self.collect_hoisted_declarations(&method.body); + self.visit_statement(&method.body); + self.pop_scope(); + } + } + Expression::DottedIdentifiers(dotted) => { + if let Some(first) = dotted.identifiers.first() { + if !self.is_variable_declared(first) { + self.captured_variables.insert(first.clone()); + } + } + } + _ => {} + } + } +} diff --git a/evilang_lib/src/lib/semantic/mod.rs b/evilang_lib/src/lib/semantic/mod.rs new file mode 100644 index 0000000..863d3d4 --- /dev/null +++ b/evilang_lib/src/lib/semantic/mod.rs @@ -0,0 +1 @@ +pub mod captured_variables; \ No newline at end of file diff --git a/tests/closure_capture_tests.rs b/tests/closure_capture_tests.rs new file mode 100644 index 0000000..fe82331 --- /dev/null +++ b/tests/closure_capture_tests.rs @@ -0,0 +1,52 @@ + +use std::ops::Deref; +use evilang_lib::interpreter::environment::Environment; +use evilang_lib::interpreter::runtime_values::PrimitiveValue; +use evilang_lib::interpreter::runtime_values::functions::Function; +use evilang_lib::interpreter::variables_containers::map::IVariablesMapConstMembers; + +#[test] +fn test_closure_capture_optimization() { + let code = r#" + fn outer() { + let unused = 100; + let used = 200; + fn inner() { + return used; + } + return inner; + } + let closure = outer(); + "#; + + let mut env = Environment::new().unwrap(); + let _ = env.eval_program_string(code.to_string()).unwrap(); + + // Get 'closure' variable + let closure_var = env.scope.get_actual("closure".into()).expect("closure variable should exist"); + let closure_val = closure_var.borrow(); + + match closure_val.deref() { + PrimitiveValue::Function(func_ptr) => { + match func_ptr.deref() { + Function::Closure(closure) => { + // Check captured variables + let scope = &closure.parent_scope; + let vars = scope.variables.borrow(); + + // "used" should be captured + assert!(vars.contains_key("used".into()), "Captured scope should contain 'used'"); + + // "unused" should NOT be captured (memory leak check) + assert!(!vars.contains_key("unused".into()), "Captured scope should NOT contain 'unused'"); + + // Check value + let used_val = vars.get_actual("used".into()).unwrap(); + assert_eq!(*used_val.borrow().deref(), PrimitiveValue::Number(200.into())); + }, + _ => panic!("Expected a closure"), + } + }, + _ => panic!("Expected a function"), + } +}