diff --git a/baml_language/.gitignore b/baml_language/.gitignore index de2768d625..451131ec5f 100644 --- a/baml_language/.gitignore +++ b/baml_language/.gitignore @@ -11,3 +11,6 @@ rpi/ # Riptide artifacts (cloud-synced) .humanlayer/tasks/ + +# insta pending snapshot scratch files (resolved by `cargo insta review`) +*.pending-snap diff --git a/baml_language/crates/baml_compiler2_hir/src/builder.rs b/baml_language/crates/baml_compiler2_hir/src/builder.rs index abe6684322..2cf2523e1c 100644 --- a/baml_language/crates/baml_compiler2_hir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_hir/src/builder.rs @@ -12,7 +12,7 @@ use baml_base::{Name, SourceFile}; use baml_compiler_diagnostics::diagnostic::DiagnosticId; use baml_compiler2_ast::{self as ast, LoweringDiagnostic}; use rustc_hash::FxHashMap; -use text_size::TextRange; +use text_size::{TextRange, TextSize}; use crate::{ contributions::{Contribution, Definition, DefinitionKind, FileSymbolContributions}, @@ -26,10 +26,19 @@ use crate::{ }, scope::{FileScopeId, Scope, ScopeId, ScopeKind}, semantic_index::{ - DefinitionSite, FileSemanticIndex, PathResolution, ScopeBindings, SemanticIndexExtra, + BindingId, DefinitionSite, FileSemanticIndex, LocalBinding, PathResolution, ScopeBindings, + SemanticIndexExtra, visible_binding_at_in_scopes, }, }; +#[derive(Debug, Clone)] +struct PathRootReference { + name: Name, + use_scope: FileScopeId, + use_offset: TextSize, + owner_lambda: Option, +} + pub struct SemanticIndexBuilder<'db> { db: &'db dyn crate::Db, file: SourceFile, @@ -48,6 +57,11 @@ pub struct SemanticIndexBuilder<'db> { /// Path root resolutions for multi-segment `Path` expressions. /// Collected during `walk_expr_body`, sorted by `ExprId` at the end. path_resolutions: Vec<(ast::ExprId, PathResolution)>, + /// Path-root references collected while walking source order. Unlike + /// `expr_scopes`, this carries the scope and innermost lambda context at + /// collection time so capture analysis does not rely on arena-local `ExprId`s. + path_root_references: Vec, + lambda_stack: Vec, item_tree: ItemTree, item_tree_source_map: crate::item_tree::ItemTreeSourceMap, @@ -68,6 +82,8 @@ impl<'db> SemanticIndexBuilder<'db> { class_depth: 0, expr_scopes: Vec::new(), path_resolutions: Vec::new(), + path_root_references: Vec::new(), + lambda_stack: Vec::new(), item_tree: ItemTree::new(), item_tree_source_map: crate::item_tree::ItemTreeSourceMap::default(), type_contributions: Vec::new(), @@ -234,312 +250,507 @@ impl<'db> SemanticIndexBuilder<'db> { } } - /// Walk an `ExprBody` arena, recording each expression in the current scope. - /// Block expressions with let-bindings push a Block scope. + /// Walk an `ExprBody` arena in source order, recording expression ownership + /// and local bindings in the lexical scope that owns each expression. fn walk_expr_body(&mut self, body: &ast::ExprBody, source_map: &ast::AstSourceMap) { - for (expr_id, expr) in body.exprs.iter() { - self.record_expr_scope(expr_id); - let _ = expr; + if let Some(root_expr) = body.root_expr { + self.walk_expr(root_expr, body, source_map, false); } - // Collect let-bindings and for-loop bindings, detecting duplicates within the scope. - let mut seen: FxHashMap> = FxHashMap::default(); - for (stmt_id, stmt) in body.stmts.iter() { - let binding_pattern = match stmt { - ast::Stmt::Let { pattern, .. } => Some(*pattern), - ast::Stmt::For { binding, .. } => Some(*binding), - _ => None, - }; - if let Some(pattern) = binding_pattern { - let scope_id = self.current_scope_id(); - if let Some(name) = body.patterns[pattern].binding_name() { - let name_range = source_map.pattern_span(pattern); - - seen.entry(name.clone()).or_default().push(MemberSite { - range: name_range, - kind: DefinitionKind::Binding, - }); + } - self.scope_bindings[scope_id.index() as usize] - .bindings - .push((name.clone(), DefinitionSite::Statement(stmt_id), name_range)); + /// Walk an expression, recording its `FileScopeId` and (for `Block`s) + /// optionally pushing a `ScopeKind::Block` scope around the contents. + /// + /// `push_block_scope`: pass `true` for nested expressions; pass `false` + /// when walking the root body of a function/lambda (the function/lambda + /// scope is already on the stack — pushing another `Block` scope would + /// double-wrap the body). + fn walk_expr( + &mut self, + expr_id: ast::ExprId, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + push_block_scope: bool, + ) { + match &body.exprs[expr_id] { + ast::Expr::Block { stmts, tail_expr } => { + if push_block_scope { + self.push_scope(ScopeKind::Block, None, source_map.expr_span(expr_id)); } + self.record_expr_scope(expr_id); + self.walk_block_contents(stmts, *tail_expr, body, source_map); + if push_block_scope { + self.pop_scope(); + } + } + ast::Expr::Lambda(func_def) => { + self.record_expr_scope(expr_id); + self.walk_lambda_expr(expr_id, func_def, source_map); + } + _ => { + self.record_expr_scope(expr_id); + self.walk_expr_children(expr_id, body, source_map); } } + } - self.emit_duplicate_diagnostics(seen); + fn walk_block_contents( + &mut self, + stmts: &[ast::StmtId], + tail_expr: Option, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + ) { + for &stmt_id in stmts { + self.walk_stmt(stmt_id, body, source_map); + } + if let Some(tail_expr) = tail_expr { + self.walk_expr(tail_expr, body, source_map, true); + } + } - // Register match-arm pattern bindings in child scopes. - // The MatchArm scope's TextRange covers the arm span, so - // scope_at_offset will find it for names used inside the arm body. - for (arm_id, arm) in body.match_arms.iter() { - let arm_span = source_map.match_arm_span(arm_id); - self.push_scope(ScopeKind::MatchArm, None, arm_span); - - if let Some(name) = Self::pattern_binding_name(&body.patterns, arm.pattern) { - let name_range = source_map.pattern_span(arm.pattern); - let scope_id = self.current_scope_id(); - self.scope_bindings[scope_id.index() as usize] - .bindings - .push(( - name.clone(), - DefinitionSite::PatternBinding(arm.pattern), - name_range, - )); + fn walk_stmt( + &mut self, + stmt_id: ast::StmtId, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + ) { + match &body.stmts[stmt_id] { + ast::Stmt::Expr(expr) => self.walk_expr(*expr, body, source_map, true), + ast::Stmt::Let { + pattern, + initializer, + .. + } => { + if let Some(initializer) = initializer { + self.walk_expr(*initializer, body, source_map, true); + } + self.register_local_pattern( + *pattern, + DefinitionSite::Statement(stmt_id), + body, + source_map, + source_map.stmt_span(stmt_id).end(), + ); } - - self.pop_scope(); + ast::Stmt::For { + binding, + collection, + body: loop_body, + } => { + self.walk_expr(*collection, body, source_map, true); + self.push_scope(ScopeKind::Block, None, source_map.stmt_span(stmt_id)); + self.register_local_pattern( + *binding, + DefinitionSite::Statement(stmt_id), + body, + source_map, + source_map.pattern_span(*binding).start(), + ); + self.walk_expr(*loop_body, body, source_map, true); + self.pop_scope(); + } + ast::Stmt::While { + condition, + body: loop_body, + after, + .. + } => { + self.walk_expr(*condition, body, source_map, true); + // Push a Block scope around the body and the C-style for + // `after` step, mirroring `Stmt::For`. While the body is + // itself an `Expr::Block` (which pushes its own scope), the + // wrapping scope here gives the while-statement its own + // identity in the scope tree, so downstream consumers (LSP + // find-references, capture analysis, MIR `binding_locals` + // lookup) can anchor on the while-statement boundary + // symmetrically with for-statements. + // + // The `after` step (set by C-style `for (init; cond; after)` + // desugaring) runs at the same level as the body, not inside + // it — it must be able to see the surrounding-scope locals + // declared by the for-init, so it stays within this wrapping + // scope but outside the body's own block scope. + self.push_scope(ScopeKind::Block, None, source_map.stmt_span(stmt_id)); + self.walk_expr(*loop_body, body, source_map, true); + if let Some(after) = after { + self.walk_stmt(*after, body, source_map); + } + self.pop_scope(); + } + ast::Stmt::Return(expr) => { + if let Some(expr) = expr { + self.walk_expr(*expr, body, source_map, true); + } + } + ast::Stmt::Throw { value } => { + self.walk_expr(*value, body, source_map, true); + } + ast::Stmt::Assign { target, value } => { + self.walk_expr(*target, body, source_map, true); + self.walk_expr(*value, body, source_map, true); + } + ast::Stmt::AssignOp { target, value, .. } => { + self.walk_expr(*target, body, source_map, true); + self.walk_expr(*value, body, source_map, true); + } + ast::Stmt::Break + | ast::Stmt::Continue + | ast::Stmt::Missing + | ast::Stmt::HeaderComment { .. } => {} } + } - // Register catch clause and catch arm pattern bindings in child scopes. - // Two-level scoping: CatchClause (holds clause binding) → CatchArm (holds arm pattern). - for (expr_id, expr) in body.exprs.iter() { - let ast::Expr::Catch { clauses, .. } = expr else { - continue; - }; - let catch_span = source_map.expr_span(expr_id); - - for clause in clauses { - // Push CatchClause scope — clause binding visible to all arms. - self.push_scope(ScopeKind::CatchClause, None, catch_span); - - if let Some(name) = Self::pattern_binding_name(&body.patterns, clause.binding) { - let name_range = source_map.pattern_span(clause.binding); - let scope_id = self.current_scope_id(); - self.scope_bindings[scope_id.index() as usize] - .bindings - .push(( - name.clone(), - DefinitionSite::PatternBinding(clause.binding), - name_range, - )); + fn walk_expr_children( + &mut self, + expr_id: ast::ExprId, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + ) { + match &body.exprs[expr_id] { + ast::Expr::If { + condition, + then_branch, + else_branch, + } => { + self.walk_expr(*condition, body, source_map, true); + self.walk_expr(*then_branch, body, source_map, true); + if let Some(else_branch) = else_branch { + self.walk_expr(*else_branch, body, source_map, true); } - - // Register optional stack trace binding in the same CatchClause scope. - if let Some(st_pat) = clause.stack_trace_binding { - if let Some(name) = Self::pattern_binding_name(&body.patterns, st_pat) { - let name_range = source_map.pattern_span(st_pat); - let scope_id = self.current_scope_id(); - self.scope_bindings[scope_id.index() as usize] - .bindings - .push(( - name.clone(), - DefinitionSite::PatternBinding(st_pat), - name_range, - )); - } + } + ast::Expr::Match { + scrutinee, arms, .. + } => { + self.walk_expr(*scrutinee, body, source_map, true); + for &arm_id in arms { + self.walk_match_arm(arm_id, body, source_map); } - - // Push CatchArm child scopes — arm pattern visible only in arm body. - for &arm_id in &clause.arms { - let arm = &body.catch_arms[arm_id]; - let arm_span = source_map.catch_arm_span(arm_id); - self.push_scope(ScopeKind::CatchArm, None, arm_span); - - if let Some(name) = Self::pattern_binding_name(&body.patterns, arm.pattern) { - let name_range = source_map.pattern_span(arm.pattern); - let scope_id = self.current_scope_id(); - self.scope_bindings[scope_id.index() as usize] - .bindings - .push(( - name.clone(), - DefinitionSite::PatternBinding(arm.pattern), - name_range, - )); + } + ast::Expr::Catch { base, clauses } => { + self.walk_expr(*base, body, source_map, true); + for clause in clauses { + self.walk_catch_clause( + clause, + body, + source_map, + Self::catch_clause_scope_span(clause, source_map), + ); + } + } + ast::Expr::Throw { value } => { + self.walk_expr(*value, body, source_map, true); + } + ast::Expr::Binary { lhs, rhs, .. } => { + self.walk_expr(*lhs, body, source_map, true); + self.walk_expr(*rhs, body, source_map, true); + } + ast::Expr::Unary { expr, .. } | ast::Expr::OptionalChain { expr } => { + self.walk_expr(*expr, body, source_map, true); + } + ast::Expr::Call { callee, args } | ast::Expr::OptionalCall { callee, args } => { + self.walk_expr(*callee, body, source_map, true); + for &arg in args { + self.walk_expr(arg, body, source_map, true); + } + } + ast::Expr::Object { + fields, spreads, .. + } => { + for (_, field_expr) in fields { + self.walk_expr(*field_expr, body, source_map, true); + } + for spread in spreads { + self.walk_expr(spread.expr, body, source_map, true); + } + } + ast::Expr::Array { elements } => { + for &element in elements { + self.walk_expr(element, body, source_map, true); + } + } + ast::Expr::Map { entries } => { + for &(key, value) in entries { + self.walk_expr(key, body, source_map, true); + self.walk_expr(value, body, source_map, true); + } + } + ast::Expr::MemberAccess { base, .. } | ast::Expr::OptionalMemberAccess { base, .. } => { + self.walk_expr(*base, body, source_map, true); + } + ast::Expr::Index { base, index } | ast::Expr::OptionalIndex { base, index } => { + self.walk_expr(*base, body, source_map, true); + self.walk_expr(*index, body, source_map, true); + } + ast::Expr::Path(segments) => { + if let Some(root) = segments.first() { + let use_scope = self.current_scope_id(); + let use_offset = source_map.expr_span(expr_id).start(); + self.record_path_root_reference(root, use_scope, use_offset); + if segments.len() >= 2 { + self.classify_path_expr(expr_id, segments, use_scope, use_offset); } - - self.pop_scope(); // CatchArm } - - self.pop_scope(); // CatchClause } + ast::Expr::Literal(_) + | ast::Expr::ByteStringLiteral(_) + | ast::Expr::Null + | ast::Expr::Block { .. } + | ast::Expr::Lambda(_) + | ast::Expr::Missing => {} } + } - // Pass 5 — Lambda scopes: register lambda params in child scopes. - for (expr_id, expr) in body.exprs.iter() { - let ast::Expr::Lambda(ref func_def) = *expr else { - continue; - }; - let lambda_span = source_map.expr_span(expr_id); - - self.push_scope(ScopeKind::Lambda, None, lambda_span); + fn register_local_pattern( + &mut self, + pat_id: ast::PatId, + site: DefinitionSite, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + visible_from: TextSize, + ) { + if let Some(name) = Self::local_binding_name(&body.patterns, pat_id) { + let name_range = source_map.pattern_span(pat_id); let scope_id = self.current_scope_id(); + self.scope_bindings[scope_id.index() as usize] + .bindings + .push(LocalBinding { + name: name.clone(), + site, + pattern: pat_id, + name_range, + visible_from, + }); + } + } - // Seed params into the lambda scope's bindings - for (idx, param) in func_def.params.iter().enumerate() { - self.scope_bindings[scope_id.index() as usize] - .params - .push((param.name.clone(), idx)); - } + fn walk_match_arm( + &mut self, + arm_id: ast::MatchArmId, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + ) { + let arm = &body.match_arms[arm_id]; + self.push_scope(ScopeKind::MatchArm, None, source_map.match_arm_span(arm_id)); + let visible_from = source_map.pattern_span(arm.pattern).start(); + self.register_local_pattern( + arm.pattern, + DefinitionSite::PatternBinding(arm.pattern), + body, + source_map, + visible_from, + ); + if let Some(guard) = arm.guard { + self.walk_expr(guard, body, source_map, true); + } + self.walk_expr(arm.body, body, source_map, true); + self.pop_scope(); + } - // Recursively walk the lambda's own ExprBody - if let Some(ast::FunctionBodyDef::Expr(ref lambda_body, ref lambda_source_map)) = - func_def.body - { - self.walk_expr_body(lambda_body, lambda_source_map); + fn walk_catch_clause( + &mut self, + clause: &ast::CatchClause, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + catch_span: TextRange, + ) { + self.push_scope(ScopeKind::CatchClause, None, catch_span); + let binding_visible_from = source_map.pattern_span(clause.binding).start(); + self.register_local_pattern( + clause.binding, + DefinitionSite::PatternBinding(clause.binding), + body, + source_map, + binding_visible_from, + ); + if let Some(st_pat) = clause.stack_trace_binding { + let st_visible_from = source_map.pattern_span(st_pat).start(); + self.register_local_pattern( + st_pat, + DefinitionSite::PatternBinding(st_pat), + body, + source_map, + st_visible_from, + ); + } + for &arm_id in &clause.arms { + self.walk_catch_arm(arm_id, body, source_map); + } + self.pop_scope(); + } - // ── Capture analysis ────────────────────────────────────────── - // Identify names referenced in the lambda body that are - // defined in ancestor scopes (up to Function boundary). - let referenced_names = Self::collect_name_references(lambda_body); - let lambda_idx = scope_id.index() as usize; + fn walk_catch_arm( + &mut self, + arm_id: ast::CatchArmId, + body: &ast::ExprBody, + source_map: &ast::AstSourceMap, + ) { + let arm = &body.catch_arms[arm_id]; + self.push_scope(ScopeKind::CatchArm, None, source_map.catch_arm_span(arm_id)); + let visible_from = source_map.pattern_span(arm.pattern).start(); + self.register_local_pattern( + arm.pattern, + DefinitionSite::PatternBinding(arm.pattern), + body, + source_map, + visible_from, + ); + self.walk_expr(arm.body, body, source_map, true); + self.pop_scope(); + } - let mut captures: Vec<(Name, DefinitionSite)> = Vec::new(); - let mut seen: std::collections::HashSet = std::collections::HashSet::new(); + fn catch_clause_scope_span( + clause: &ast::CatchClause, + source_map: &ast::AstSourceMap, + ) -> TextRange { + let binding_span = source_map.pattern_span(clause.binding); + let mut start = binding_span.start(); + let mut end = binding_span.end(); + + if let Some(stack_trace_binding) = clause.stack_trace_binding { + let span = source_map.pattern_span(stack_trace_binding); + start = start.min(span.start()); + end = end.max(span.end()); + } - for name in &referenced_names { - // Skip if already recorded as a capture - if seen.contains(name) { - continue; - } - // Skip if defined locally in the lambda scope (param or let-binding) - if Self::scope_defines_name(&self.scope_bindings[lambda_idx], name) { - continue; - } + for &arm_id in &clause.arms { + let span = source_map.catch_arm_span(arm_id); + start = start.min(span.start()); + end = end.max(span.end()); + } - // Walk ancestor scopes to find the defining scope - let mut current = self.scopes[lambda_idx].parent; - while let Some(ancestor_id) = current { - let ancestor_idx = ancestor_id.index() as usize; - - // Read scope metadata before any mutation to avoid - // simultaneous borrow conflicts on self.scopes and - // self.scope_bindings. - let ancestor_kind = self.scopes[ancestor_idx].kind.clone(); - let ancestor_parent = self.scopes[ancestor_idx].parent; - - // Check if this ancestor defines the name — record the - // DefinitionSite so captures are tied to the specific - // declaration, not just the name (future-proofs for shadowing). - if let Some(def_site) = - Self::scope_definition_site(&self.scope_bindings[ancestor_idx], name) - { - captures.push((name.clone(), def_site)); - seen.insert(name.clone()); - // Mark the name as captured in the defining scope - self.scope_bindings[ancestor_idx] - .captured_names - .insert(name.clone()); - break; - } + TextRange::new(start, end) + } - // Also check if it's already a capture of an intermediate - // lambda (for nested capture chains: inner lambda captures - // from an intermediate lambda that itself captures from the - // parent). - if let Some((_, def_site)) = self.scope_bindings[ancestor_idx] - .captures - .iter() - .find(|(n, _)| n == name) - { - captures.push((name.clone(), *def_site)); - seen.insert(name.clone()); - break; - } + fn walk_lambda_expr( + &mut self, + expr_id: ast::ExprId, + func_def: &ast::FunctionDef, + source_map: &ast::AstSourceMap, + ) { + self.push_scope(ScopeKind::Lambda, None, source_map.expr_span(expr_id)); + let scope_id = self.current_scope_id(); + for (idx, param) in func_def.params.iter().enumerate() { + self.scope_bindings[scope_id.index() as usize] + .params + .push((param.name.clone(), idx)); + } + if let Some(ast::FunctionBodyDef::Expr(lambda_body, lambda_source_map)) = &func_def.body { + self.lambda_stack.push(scope_id); + self.walk_expr_body(lambda_body, lambda_source_map); + self.analyze_lambda_captures(scope_id, lambda_body, lambda_source_map); + self.lambda_stack.pop(); + } + self.pop_scope(); + } - // Stop at Function boundary — don't capture across function defs - if matches!(ancestor_kind, ScopeKind::Function) { - break; - } + fn analyze_lambda_captures( + &mut self, + lambda_scope: FileScopeId, + _lambda_body: &ast::ExprBody, + _lambda_source_map: &ast::AstSourceMap, + ) { + let lambda_idx = lambda_scope.index() as usize; + let mut captures: Vec<(Name, BindingId)> = Vec::new(); + let mut seen = std::collections::HashSet::new(); - current = ancestor_parent; - } + for reference in self + .path_root_references + .iter() + .filter(|reference| reference.owner_lambda == Some(lambda_scope)) + { + if let Some(binding_id) = + self.visible_binding_at(reference.use_scope, reference.use_offset, &reference.name) + { + if !self.scope_is_descendant_or_self(binding_id.scope, lambda_scope) + && seen.insert(binding_id) + { + captures.push((reference.name.clone(), binding_id)); + self.scope_bindings[binding_id.scope.index() as usize] + .captured_bindings + .insert(binding_id); } - - self.scope_bindings[lambda_idx].captures = captures; } - - self.pop_scope(); } - // Pass 6 — Path resolution: classify multi-segment Path root segments. - // After all binding collection passes, check if the root of each - // multi-segment Path is a locally-declared variable or parameter. - let visible_names = self.collect_visible_names(); - for (expr_id, expr) in body.exprs.iter() { - if let ast::Expr::Path(segments) = expr { - if segments.len() >= 2 { - let root = &segments[0]; - let resolution = if visible_names.contains(root) { - PathResolution::Local { name: root.clone() } - } else { - PathResolution::Unknown - }; - self.path_resolutions.push((expr_id, resolution)); - } + self.scope_bindings[lambda_idx].captures = captures; + } + + fn visible_binding_at( + &self, + scope_id: FileScopeId, + at_offset: TextSize, + name: &Name, + ) -> Option { + visible_binding_at_in_scopes( + &self.scopes, + &self.scope_bindings, + scope_id, + at_offset, + name, + ) + } + + fn scope_is_descendant_or_self(&self, scope_id: FileScopeId, ancestor_id: FileScopeId) -> bool { + let mut current = Some(scope_id); + while let Some(id) = current { + if id == ancestor_id { + return true; } + current = self.scopes[id.index() as usize].parent; } + false } - /// Collect all names visible in the current scope chain (params and - /// let-bindings), stopping at function/lambda boundaries. - /// - /// This is a conservative best-effort check: names found here are - /// definitely locals; names not found may be package names (resolved by TIR). - fn collect_visible_names(&self) -> std::collections::HashSet { - let mut names = std::collections::HashSet::new(); - for &scope_id in self.scope_stack.iter().rev() { - let idx = scope_id.index() as usize; - let bindings = &self.scope_bindings[idx]; - for (name, _) in &bindings.params { - names.insert(name.clone()); - } - for (name, _, _) in &bindings.bindings { - names.insert(name.clone()); - } - for (name, _) in &bindings.captures { - names.insert(name.clone()); - } - // Stop at function/lambda boundary — don't look through function scopes. - let scope_kind = &self.scopes[idx].kind; - if matches!(scope_kind, ScopeKind::Function | ScopeKind::Lambda) { - break; - } + fn classify_path_expr( + &mut self, + expr_id: ast::ExprId, + segments: &[Name], + use_scope: FileScopeId, + use_offset: TextSize, + ) { + if segments.len() < 2 { + return; } - names + let root = &segments[0]; + let resolution = if self + .visible_binding_at(use_scope, use_offset, root) + .is_some() + { + PathResolution::Local { name: root.clone() } + } else { + PathResolution::Unknown + }; + self.path_resolutions.push((expr_id, resolution)); + } + + fn record_path_root_reference( + &mut self, + root: &Name, + use_scope: FileScopeId, + use_offset: TextSize, + ) { + self.path_root_references.push(PathRootReference { + name: root.clone(), + use_scope, + use_offset, + owner_lambda: self.lambda_stack.last().copied(), + }); } /// Extract the binding name from a pattern, if it has one. - /// Wildcards (`_`) are not bindings and return `None`. - fn pattern_binding_name( + /// + /// The AST canonicalizes `_` to `Wildcard` at construction time + /// (`Pattern::binding`), so `_` never reaches us as a `Bind` regardless of + /// the surface form. `let`/`for` patterns and `match`/`catch` arm + /// patterns therefore use the same extraction. + fn local_binding_name( patterns: &la_arena::Arena, pat_id: ast::PatId, ) -> Option<&Name> { patterns[pat_id].binding_name() } - /// Collect all single-segment `Expr::Path` names from an `ExprBody`. - /// These represent potential variable references — both bare identifiers - /// (`x`) and the root segment of multi-segment paths (`obj` in `obj.field`). - fn collect_name_references(body: &ast::ExprBody) -> Vec { - let mut names = Vec::new(); - for (_expr_id, expr) in body.exprs.iter() { - if let ast::Expr::Path(segments) = expr { - if !segments.is_empty() { - names.push(segments[0].clone()); - } - } - } - names - } - - /// Check if a name is defined in a scope's bindings (params or let-bindings). - fn scope_defines_name(bindings: &ScopeBindings, name: &Name) -> bool { - bindings.params.iter().any(|(n, _)| n == name) - || bindings.bindings.iter().any(|(n, _, _)| n == name) - } - - /// Find the `DefinitionSite` for a name in a scope's bindings. - /// Returns the first matching definition (params checked first, then bindings). - fn scope_definition_site(bindings: &ScopeBindings, name: &Name) -> Option { - if let Some((_, idx)) = bindings.params.iter().find(|(n, _)| n == name) { - return Some(DefinitionSite::Parameter(*idx)); - } - if let Some((_, def, _)) = bindings.bindings.iter().find(|(n, _, _)| n == name) { - return Some(*def); - } - None - } - // ── Item lowering ──────────────────────────────────────────────────────── fn lower_item(&mut self, item: &ast::Item) { diff --git a/baml_language/crates/baml_compiler2_hir/src/semantic_index.rs b/baml_language/crates/baml_compiler2_hir/src/semantic_index.rs index 87f543bb38..1f2956c71a 100644 --- a/baml_language/crates/baml_compiler2_hir/src/semantic_index.rs +++ b/baml_language/crates/baml_compiler2_hir/src/semantic_index.rs @@ -38,7 +38,7 @@ use crate::{ contributions::FileSymbolContributions, diagnostic::Hir2Diagnostic, item_tree::{ItemTree, ItemTreeSourceMap}, - scope::{FileScopeId, Scope, ScopeId}, + scope::{FileScopeId, Scope, ScopeId, ScopeKind}, }; // ── DefinitionSite ─────────────────────────────────────────────────────────── @@ -54,8 +54,25 @@ pub enum DefinitionSite { PatternBinding(PatId), } +// ── BindingId ──────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BindingId { + pub scope: FileScopeId, + pub site: DefinitionSite, +} + // ── ScopeBindings ──────────────────────────────────────────────────────────── +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocalBinding { + pub name: Name, + pub site: DefinitionSite, + pub pattern: PatId, + pub name_range: TextRange, + pub visible_from: TextSize, +} + /// Per-scope local bindings — what names are introduced in this scope. /// /// Lightweight version of Ty's `PlaceTable` + `UseDefMap`. BAML's simpler @@ -64,17 +81,17 @@ pub enum DefinitionSite { #[derive(Debug, Clone, PartialEq, Eq)] pub struct ScopeBindings { /// Let-bindings in this scope, in source order. - pub bindings: Vec<(Name, DefinitionSite, TextRange)>, + pub bindings: Vec, /// Parameters (for Function/Lambda scopes). pub params: Vec<(Name, usize)>, // (name, param_index) /// Variables captured from ancestor scopes (for Lambda scopes only). /// Each entry is `(name, definition_site)` to uniquely identify the /// captured declaration, even in the presence of shadowing. /// Populated by capture analysis in `SemanticIndexBuilder::walk_expr_body`. - pub captures: Vec<(Name, DefinitionSite)>, - /// Names in this scope that are captured by a descendant lambda. + pub captures: Vec<(Name, BindingId)>, + /// Bindings in this scope that are captured by a descendant lambda. /// Used by MIR lowering to decide which locals need cell wrapping. - pub captured_names: HashSet, + pub captured_bindings: HashSet, } impl ScopeBindings { @@ -83,7 +100,7 @@ impl ScopeBindings { bindings: Vec::new(), params: Vec::new(), captures: Vec::new(), - captured_names: HashSet::new(), + captured_bindings: HashSet::new(), } } } @@ -94,6 +111,48 @@ impl Default for ScopeBindings { } } +/// Shared local-binding lookup used while building and after indexing. +/// +/// Keep this as the single source for parent-scope visibility semantics: +/// skip ancestor class scopes, scan local bindings in reverse source order, +/// check `visible_from`, then fall back to parameters. +pub(crate) fn visible_binding_at_in_scopes( + scopes: &[Scope], + scope_bindings: &[ScopeBindings], + scope_id: FileScopeId, + at_offset: TextSize, + name: &Name, +) -> Option { + let mut current = Some(scope_id); + while let Some(ancestor_id) = current { + let scope = &scopes[ancestor_id.index() as usize]; + if matches!(scope.kind, ScopeKind::Class) && ancestor_id != scope_id { + current = scope.parent; + continue; + } + + let bindings = &scope_bindings[ancestor_id.index() as usize]; + for binding in bindings.bindings.iter().rev() { + if &binding.name == name && binding.visible_from <= at_offset { + return Some(BindingId { + scope: ancestor_id, + site: binding.site, + }); + } + } + for (param_name, param_idx) in &bindings.params { + if param_name == name { + return Some(BindingId { + scope: ancestor_id, + site: DefinitionSite::Parameter(*param_idx), + }); + } + } + current = scope.parent; + } + None +} + // ── SemanticIndexExtra ─────────────────────────────────────────────────────── /// Rare/optional data for `FileSemanticIndex`. Heap-allocated only when @@ -185,6 +244,21 @@ unsafe impl salsa::Update for FileSemanticIndex<'_> { } impl FileSemanticIndex<'_> { + /// Find the `Lambda` scope whose range exactly matches `span`. + /// + /// Linear walk over the scope list (which is small in practice — + /// bounded by the number of lambda nestings in a file). + pub fn lambda_scope_for(&self, span: text_size::TextRange) -> Option { + self.scopes + .iter() + .enumerate() + .find(|(_, scope)| matches!(scope.kind, ScopeKind::Lambda) && scope.range == span) + .map(|(i, _)| { + #[allow(clippy::cast_possible_truncation)] + FileScopeId::new(i as u32) + }) + } + /// Find the innermost scope containing `offset`. /// /// Scopes are in DFS pre-order. We walk in reverse (deepest first) @@ -257,6 +331,25 @@ impl FileSemanticIndex<'_> { ancestors } + pub fn binding_visible_at(&self, binding: &LocalBinding, at_offset: TextSize) -> bool { + binding.visible_from <= at_offset + } + + pub fn visible_binding_at( + &self, + scope_id: FileScopeId, + at_offset: TextSize, + name: &Name, + ) -> Option { + visible_binding_at_in_scopes( + &self.scopes, + &self.scope_bindings, + scope_id, + at_offset, + name, + ) + } + pub fn diagnostics(&self) -> &[Hir2Diagnostic] { self.extra .as_ref() diff --git a/baml_language/crates/baml_compiler2_mir/src/lower.rs b/baml_language/crates/baml_compiler2_mir/src/lower.rs index 53f8a2b5f7..63d2fe1d3e 100644 --- a/baml_language/crates/baml_compiler2_mir/src/lower.rs +++ b/baml_language/crates/baml_compiler2_mir/src/lower.rs @@ -335,6 +335,7 @@ use baml_compiler2_hir::{ loc::{FunctionLoc, LetLoc}, package::{PackageId, package_dependencies, package_items}, scope::FileScopeId, + semantic_index::{BindingId, DefinitionSite}, }; use baml_compiler2_ppir::file_semantic_index; use baml_compiler2_tir::{ @@ -347,6 +348,7 @@ struct LoweringContext<'db> { db: &'db dyn crate::Db, builder: MirBuilder, locals: HashMap, + binding_locals: HashMap, loop_context: Option, catch_context: Option, exit_block: BlockId, @@ -417,18 +419,16 @@ struct LoweringContext<'db> { // Capture map for the current lambda body. // `Some(map)` when lowering inside a lambda body; `None` for top-level functions. - // Maps captured variable name -> index into the closure's captures array. + // Maps captured binding identity -> index into the closure's captures array. // Used by `lower_path_expr` to resolve references to captured variables as // `Place::Capture(idx)` instead of `Place::Local(_)`. - capture_indices: Option>, + capture_indices: Option>, - // Names that were added to the current lambda's capture list transitively - // (i.e. because an inner lambda needed them but they weren't in the HIR - // capture list for this lambda). Populated by `lower_lambda` when building - // an inner closure's capture operands. Collected by the *parent* - // `lower_lambda` call after the body is lowered so it can extend the outer - // MakeClosure with extra captures. - transitive_captures_needed: Vec, + // Bindings that were added to the current lambda's capture list transitively + // because an inner lambda needed them but they were not in the HIR capture + // list for this lambda. Collected by the parent `lower_lambda` call after + // the body is lowered so it can extend the outer MakeClosure with extra captures. + transitive_captures_needed: Vec, /// Stack of null-exit blocks for active `OptionalChain` scopes. /// When an `OptionalFieldAccess`/`OptionalIndex`/`OptionalCall` encounters null, @@ -710,6 +710,7 @@ impl<'db> LoweringContext<'db> { db, builder: MirBuilder::new(func_name, arity), locals: HashMap::new(), + binding_locals: HashMap::new(), loop_context: None, catch_context: None, exit_block: BlockId(0), // placeholder; overwritten in lower_function_body @@ -879,6 +880,7 @@ impl<'db> LoweringContext<'db> { db, builder: MirBuilder::new(let_name.clone(), 0), locals: HashMap::new(), + binding_locals: HashMap::new(), loop_context: None, catch_context: None, exit_block: BlockId(0), // placeholder; overwritten in lower_let_body_inner @@ -923,6 +925,169 @@ impl<'db> LoweringContext<'db> { Name::new(&name) } + fn scope_is_descendant_or_self( + index: &baml_compiler2_hir::semantic_index::FileSemanticIndex<'_>, + scope_id: FileScopeId, + ancestor_id: FileScopeId, + ) -> bool { + let mut current = Some(scope_id); + while let Some(id) = current { + if id == ancestor_id { + return true; + } + current = index.scopes[id.index() as usize].parent; + } + false + } + + fn binding_id_for_pattern_site( + &self, + pattern: AstPatId, + site: DefinitionSite, + ) -> Option { + let index = file_semantic_index(self.db, self.file); + let pattern_span = self + .source_map + .as_ref() + .map(|source_map| source_map.pattern_span(pattern)); + + for (scope_idx, bindings) in index.scope_bindings.iter().enumerate() { + let scope_id = FileScopeId::new(u32::try_from(scope_idx).expect("scope id overflow")); + if !Self::scope_is_descendant_or_self(index, scope_id, self.current_scope) { + continue; + } + for binding in &bindings.bindings { + let pattern_matches_name_range = pattern_span.is_none_or(|span| { + span == binding.name_range + || (span.start() <= binding.name_range.start() + && binding.name_range.end() <= span.end()) + }); + if binding.site == site && binding.pattern == pattern && pattern_matches_name_range + { + return Some(BindingId { + scope: scope_id, + site, + }); + } + } + } + None + } + + fn binding_id_for_statement(&self, stmt_id: AstStmtId, pattern: AstPatId) -> Option { + self.binding_id_for_pattern_site(pattern, DefinitionSite::Statement(stmt_id)) + } + + fn record_pattern_binding_local(&mut self, pattern: AstPatId, local: Local) { + if let Some(binding_id) = + self.binding_id_for_pattern_site(pattern, DefinitionSite::PatternBinding(pattern)) + { + self.binding_locals.insert(binding_id, local); + } + } + + fn pattern_binding_is_captured(&self, pattern: AstPatId) -> bool { + let Some(binding_id) = + self.binding_id_for_pattern_site(pattern, DefinitionSite::PatternBinding(pattern)) + else { + return false; + }; + let index = file_semantic_index(self.db, self.file); + index + .scope_bindings + .get(binding_id.scope.index() as usize) + .is_some_and(|bindings| bindings.captured_bindings.contains(&binding_id)) + } + + fn binding_id_for_name_at(&self, expr_id: AstExprId, name: &Name) -> Option { + let index = file_semantic_index(self.db, self.file); + let (scope_id, offset) = if let Some(source_map) = self.source_map.as_ref() { + let offset = source_map.expr_span(expr_id).start(); + ( + index.scope_at_offset(offset, self.scope_func_name.as_ref()), + offset, + ) + } else { + // The source-map-less branch is only valid for **synthesized** + // expressions emitted by the lowering itself (e.g. for-loop index + // increments, capture forwarding, init function bodies). The + // fallback uses `current_scope` and the scope's end offset, which + // is correct for synthesized refs at the end of the current scope + // but would silently pick the post-shadow binding for a + // user-written name lowered without a source map. + // + // If you find yourself adding a user-visible expression that + // hits this path: the right fix is to thread a `BindingId` + // through to the call site, not to widen this fallback. + let scope_id = self.current_scope; + let offset = index.scopes[scope_id.index() as usize].range.end(); + (scope_id, offset) + }; + index.visible_binding_at(scope_id, offset, name) + } + + fn capture_index_for_name_at(&self, expr_id: AstExprId, name: &Name) -> Option { + let binding_id = self.binding_id_for_name_at(expr_id, name)?; + self.capture_indices + .as_ref() + .and_then(|captures| captures.get(&binding_id).copied()) + } + + /// Emit `unwatch` ops for every watched local at index `[watched_depth..]` + /// of `watched_locals_stack`, in reverse declaration order. + /// + /// This is the single emitter for unwatch sequences. All scope-exit + /// paths go through it: + /// - normal block fallthrough: `lower_scoped_block` (depth = entry stack len) + /// - normal `for`-body fallthrough (depth = entry stack len) + /// - normal match/catch arm-body fallthrough (depth = arm-entry stack len) + /// - `break` / `continue` (depth = `loop_context.watched_locals_depth`) + /// - `return` / `throw` (depth = 0 — the stack is swapped at lambda + /// boundaries, so 0 means "everything in the enclosing function") + /// + /// Does NOT truncate the stack — callers that own the scope are + /// responsible for truncating via `restore_locals_after_scope`. Divergent + /// callers (break/continue/return/throw) leave the stack alone because a + /// dead block follows the divergent terminator. + fn emit_unwatch_to_depth(&mut self, watched_depth: usize) { + let watched = self.watched_locals_stack[watched_depth..].to_vec(); + for local in watched.into_iter().rev() { + self.builder.unwatch(local); + } + } + + fn restore_locals_after_scope( + &mut self, + saved_locals: HashMap, + watched_depth: usize, + ) { + self.watched_locals_stack.truncate(watched_depth); + self.locals = saved_locals; + } + + fn restore_active_locals(&mut self, saved_locals: HashMap) { + self.locals = saved_locals; + } + + fn mark_captured_locals_in_scope_tree(&mut self, root_scope: FileScopeId) { + let index = file_semantic_index(self.db, self.file); + let root = &index.scopes[root_scope.index() as usize]; + let start = root_scope.index(); + let end = root.descendants.end.index(); + + for raw_idx in start..end { + let scope_id = FileScopeId::new(raw_idx); + let Some(scope_bindings) = index.scope_bindings.get(scope_id.index() as usize) else { + continue; + }; + for binding_id in &scope_bindings.captured_bindings { + if let Some(&local) = self.binding_locals.get(binding_id) { + self.builder.local_decl_mut(local).is_captured = true; + } + } + } + } + /// Get the `baml_type::Ty` for an expression by looking up in the aggregated map /// and converting from TIR Ty. Uses `current_scope` as the `FileScopeId` key. fn expr_ty(&self, expr_id: AstExprId) -> Ty { @@ -1050,7 +1215,7 @@ impl LoweringContext<'_> { // Parameter locals _1..=_n // For `self` with no annotation, look up the TIR-inferred parameter type // which correctly resolves to the enclosing class type. - for (param_name, param_te) in &sig.params { + for (param_idx, (param_name, param_te)) in sig.params.iter().enumerate() { let param_ty = if param_name.as_str() == "self" && matches!(param_te, baml_compiler2_ast::TypeExpr::Unknown { .. }) { @@ -1091,6 +1256,13 @@ impl LoweringContext<'_> { .builder .declare_local(Some(param_name.clone()), param_ty, None, false); self.locals.insert(param_name.clone(), local); + self.binding_locals.insert( + BindingId { + scope: self.current_scope, + site: DefinitionSite::Parameter(param_idx), + }, + local, + ); } // Entry and exit blocks @@ -1117,20 +1289,9 @@ impl LoweringContext<'_> { self.builder.set_current_block(self.exit_block); self.builder.return_(); - // Mark locals that are captured by nested lambdas with `is_captured = true`. - // The HIR `ScopeBindings.captured_names` for the function scope records which - // names are captured by any descendant lambda. These locals need cell wrapping. - { - let func_scope_id = self.current_scope; - let index = file_semantic_index(self.db, self.file); - if let Some(sb) = index.scope_bindings.get(func_scope_id.index() as usize) { - for captured_name in &sb.captured_names { - if let Some(&local) = self.locals.get(captured_name) { - self.builder.local_decl_mut(local).is_captured = true; - } - } - } - } + // Mark locals captured by nested lambdas. HIR stores this by binding + // identity, including block-owned bindings. + self.mark_captured_locals_in_scope_tree(self.current_scope); // Take the builder out of self to call `build()` which consumes it let dummy = MirBuilder::new(Name::new("_dummy"), 0); @@ -1256,23 +1417,21 @@ impl LoweringContext<'_> { }; // Read HIR captures for this lambda scope. - // `captures` lists (name, DefinitionSite) pairs that the lambda reads from - // enclosing scopes. The DefinitionSite uniquely identifies the declaration - // even with shadowing. - // We build `capture_indices` (name → index in closure.captures[]) so that - // `lower_path_expr` and `lower_lvalue` can emit Place::Capture(idx). - let hir_captures: Vec = { + // `captures` lists the exact binding identities that the lambda reads + // from enclosing scopes. We build `capture_indices` so path/lvalue + // lowering can emit `Place::Capture(idx)` without collapsing shadows by name. + let hir_captures: Vec<(Name, BindingId)> = { let index = file_semantic_index(self.db, self.file); index .scope_bindings .get(lambda_scope_id.index() as usize) - .map(|sb| sb.captures.iter().map(|(name, _)| name.clone()).collect()) + .map(|sb| sb.captures.clone()) .unwrap_or_default() }; - let lambda_capture_indices: HashMap = hir_captures + let lambda_capture_indices: HashMap = hir_captures .iter() .enumerate() - .map(|(i, name)| (name.clone(), i)) + .map(|(i, (_, binding_id))| (*binding_id, i)) .collect(); // Save parent state. @@ -1283,6 +1442,7 @@ impl LoweringContext<'_> { let saved_body = std::mem::replace(&mut self.body, lambda_body); let saved_source_map = std::mem::replace(&mut self.source_map, lambda_source_map); let saved_locals = std::mem::take(&mut self.locals); + let saved_binding_locals = std::mem::take(&mut self.binding_locals); let saved_exit_block = self.exit_block; let saved_loop_context = self.loop_context.take(); let saved_catch_context = self.catch_context.take(); @@ -1325,7 +1485,7 @@ impl LoweringContext<'_> { ); // Declare parameter locals _1..=_n. - for param in &func_def.params { + for (param_idx, param) in func_def.params.iter().enumerate() { let param_ty = match ¶m.type_expr { Some(spanned_te) => { let mut diags = Vec::new(); @@ -1347,6 +1507,13 @@ impl LoweringContext<'_> { .builder .declare_local(Some(param.name.clone()), param_ty, None, false); self.locals.insert(param.name.clone(), local); + self.binding_locals.insert( + BindingId { + scope: self.current_scope, + site: DefinitionSite::Parameter(param_idx), + }, + local, + ); } // Create entry and exit blocks. @@ -1372,19 +1539,9 @@ impl LoweringContext<'_> { self.builder.set_current_block(self.exit_block); self.builder.return_(); - // Mark locals that are captured by nested lambdas with `is_captured = true`. - // This mirrors the same step in `lower_function_body` but for lambdas. - // Uses the lambda's own scope id (lambda_scope_id) to look up HIR captured_names. - { - let index = file_semantic_index(self.db, self.file); - if let Some(sb) = index.scope_bindings.get(lambda_scope_id.index() as usize) { - for captured_name in &sb.captured_names { - if let Some(&local) = self.locals.get(captured_name) { - self.builder.local_decl_mut(local).is_captured = true; - } - } - } - } + // Mark locals captured by nested lambdas. HIR stores this by binding + // identity, including block-owned bindings. + self.mark_captured_locals_in_scope_tree(lambda_scope_id); // Build the lambda MirFunction. // First, collect any nested lambdas that were encountered while lowering @@ -1416,6 +1573,7 @@ impl LoweringContext<'_> { self.body = saved_body; self.source_map = saved_source_map; self.locals = saved_locals; + self.binding_locals = saved_binding_locals; self.exit_block = saved_exit_block; self.loop_context = saved_loop_context; self.catch_context = saved_catch_context; @@ -1433,9 +1591,12 @@ impl LoweringContext<'_> { // handle propagation by pushing to `transitive_captures_needed` when a // name is not found in the current scope's locals or captures. let mut extended_hir_captures = hir_captures; - for name in &newly_needed_transitive { - if !extended_hir_captures.contains(name) { - extended_hir_captures.push(name.clone()); + for binding_id in newly_needed_transitive { + if !extended_hir_captures + .iter() + .any(|(_, existing)| *existing == binding_id) + { + extended_hir_captures.push((Name::new("_capture"), binding_id)); } } @@ -1449,17 +1610,17 @@ impl LoweringContext<'_> { // lambda — i.e. the current lambda (f) will need to capture it from ITS // parent, and g will receive it via f's capture slot. let mut capture_operands: Vec = Vec::with_capacity(extended_hir_captures.len()); - for name in &extended_hir_captures { - if let Some(&local) = self.locals.get(name) { + for (_, binding_id) in &extended_hir_captures { + if let Some(&local) = self.binding_locals.get(binding_id) { // Mark the local as captured at the capture site — this is the // definitive place where we know the exact Local being captured, - // even in the presence of shadowing (future-proofing). + // even in the presence of shadowing. self.builder.local_decl_mut(local).is_captured = true; capture_operands.push(Operand::Copy(Place::Local(local))); } else if let Some(cap_idx) = self .capture_indices .as_ref() - .and_then(|m| m.get(name)) + .and_then(|m| m.get(binding_id)) .copied() { // The variable is itself a capture in the current scope. @@ -1472,11 +1633,11 @@ impl LoweringContext<'_> { let new_idx = { let ci = self.capture_indices.get_or_insert_with(HashMap::new); let idx = ci.len(); - ci.insert(name.clone(), idx); + ci.insert(*binding_id, idx); idx }; // Signal to our parent lambda that it needs to capture this name. - self.transitive_captures_needed.push(name.clone()); + self.transitive_captures_needed.push(*binding_id); capture_operands.push(Operand::Copy(Place::Capture(new_idx))); } } @@ -1498,6 +1659,38 @@ impl LoweringContext<'_> { // ─── 3.2: Core lower_expr dispatch ─────────────────────────────────────────── impl LoweringContext<'_> { + fn lower_scoped_block( + &mut self, + stmts: &[AstStmtId], + tail_expr: Option, + dest: Place, + ) { + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); + + for &stmt_id in stmts { + self.lower_stmt(stmt_id); + if self.builder.is_current_terminated() { + break; + } + } + + if !self.builder.is_current_terminated() { + match tail_expr { + Some(tail) => self.lower_expr(tail, dest), + None => { + self.builder + .assign(dest, Rvalue::Use(Operand::Constant(Constant::Null))); + } + } + } + + if !self.builder.is_current_terminated() { + self.emit_unwatch_to_depth(watched_depth); + } + self.restore_locals_after_scope(saved_locals, watched_depth); + } + fn lower_expr(&mut self, expr_id: AstExprId, dest: Place) { let prev_span = self.builder.current_source_span; if let Some(span) = self.span_for_expr(expr_id) { @@ -1589,21 +1782,7 @@ impl LoweringContext<'_> { } AstExpr::Block { stmts, tail_expr } => { - for &stmt_id in &stmts { - self.lower_stmt(stmt_id); - if self.builder.is_current_terminated() { - break; // Remaining stmts are dead code (after return/throw/break/continue) - } - } - if !self.builder.is_current_terminated() { - match tail_expr { - Some(tail) => self.lower_expr(tail, dest), - None => { - self.builder - .assign(dest, Rvalue::Use(Operand::Constant(Constant::Null))); - } - } - } + self.lower_scoped_block(&stmts, tail_expr, dest); } AstExpr::Match { @@ -1702,11 +1881,8 @@ impl<'db> LoweringContext<'db> { let receiver_op = if receiver_segments.len() == 1 { if let Some(&recv_local) = self.locals.get(&receiver_segments[0]) { Operand::Copy(Place::Local(recv_local)) - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&receiver_segments[0])) - .copied() + } else if let Some(cap_idx) = + self.capture_index_for_name_at(expr_id, &receiver_segments[0]) { // Receiver is a captured variable — use capture slot. Operand::Copy(Place::Capture(cap_idx)) @@ -1775,11 +1951,8 @@ impl<'db> LoweringContext<'db> { let receiver_op = if receiver_segments.len() == 1 { if let Some(&recv_local) = self.locals.get(&receiver_segments[0]) { Operand::Copy(Place::Local(recv_local)) - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&receiver_segments[0])) - .copied() + } else if let Some(cap_idx) = + self.capture_index_for_name_at(expr_id, &receiver_segments[0]) { // Receiver is a captured variable — use capture slot. Operand::Copy(Place::Capture(cap_idx)) @@ -1875,12 +2048,7 @@ impl<'db> LoweringContext<'db> { if let Some(&local) = self.locals.get(&local_name) { self.builder .assign(dest, Rvalue::Use(Operand::Copy(Place::Local(local)))); - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&local_name)) - .copied() - { + } else if let Some(cap_idx) = self.capture_index_for_name_at(expr_id, &local_name) { // This variable is captured from an enclosing scope. // Emit a LoadCapture via Place::Capture. self.builder @@ -1942,12 +2110,7 @@ impl<'db> LoweringContext<'db> { self.builder.local_ty(root_local) }; (place, ty) - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&segments[0])) - .copied() - { + } else if let Some(cap_idx) = self.capture_index_for_name_at(expr_id, &segments[0]) { let place = Place::Capture(cap_idx); let ty = self .path_root_ty(expr_id) @@ -2684,6 +2847,10 @@ impl LoweringContext<'_> { // Simple local variable receiver (e.g. `self`). if let Some(&recv_local) = self.locals.get(&receiver_segments[0]) { Operand::Copy(Place::Local(recv_local)) + } else if let Some(cap_idx) = + self.capture_index_for_name_at(callee, &receiver_segments[0]) + { + Operand::Copy(Place::Capture(cap_idx)) } else { Operand::Constant(Constant::Null) } @@ -2714,9 +2881,13 @@ impl LoweringContext<'_> { None => self.lower_to_operand(callee), }; let first_seg = &segments[0]; - let receiver_local = self.locals.get(first_seg).copied(); - if let Some(receiver_local) = receiver_local { - let receiver_op = Operand::Copy(Place::Local(receiver_local)); + let receiver_op = if let Some(&receiver_local) = self.locals.get(first_seg) { + Some(Operand::Copy(Place::Local(receiver_local))) + } else { + self.capture_index_for_name_at(callee, first_seg) + .map(|cap_idx| Operand::Copy(Place::Capture(cap_idx))) + }; + if let Some(receiver_op) = receiver_op { let mut all_args = vec![receiver_op]; all_args.extend(args.iter().map(|&a| self.lower_to_operand(a))); (callee_op, all_args) @@ -3481,11 +3652,6 @@ impl LoweringContext<'_> { let local = self.builder .declare_local(Some(name.clone()), local_ty, None, is_watched); - self.locals.insert(name, local); - - if is_watched { - self.watched_locals_stack.push(local); - } if let Some(init) = initializer { self.lower_expr(init, Place::local(local)); @@ -3495,6 +3661,15 @@ impl LoweringContext<'_> { Rvalue::Use(Operand::Constant(Constant::Null)), ); } + + self.locals.insert(name, local); + if let Some(binding_id) = self.binding_id_for_statement(stmt_id, pattern) { + self.binding_locals.insert(binding_id, local); + } + + if is_watched { + self.watched_locals_stack.push(local); + } } AstStmt::While { @@ -3572,6 +3747,9 @@ impl LoweringContext<'_> { collection, body, } => { + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); + // 1. Evaluate collection into a temp local let coll_ty = self.expr_ty(collection); let coll_local = self.builder.temp(coll_ty.clone()); @@ -3606,7 +3784,6 @@ impl LoweringContext<'_> { // Register loop context so break/continue work inside for-loops let prev_loop = self.loop_context.take(); - let watched_depth = self.watched_locals_stack.len(); self.loop_context = Some(LoopContext { break_target: bb_exit, continue_target: bb_after, @@ -3674,6 +3851,9 @@ impl LoweringContext<'_> { Rvalue::Use(Operand::Copy(Place::Local(elem_local))), ); self.locals.insert(name.clone(), local); + if let Some(binding_id) = self.binding_id_for_statement(stmt_id, binding) { + self.binding_locals.insert(binding_id, local); + } } } @@ -3684,8 +3864,10 @@ impl LoweringContext<'_> { self.lower_expr(body, Place::local(body_temp)); if !self.builder.is_current_terminated() { + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(bb_after); } + self.restore_locals_after_scope(saved_locals, watched_depth); // 7. After: __idx += 1 self.builder.set_current_block(bb_after); @@ -3709,11 +3891,10 @@ impl LoweringContext<'_> { if let Some(e) = expr { self.lower_expr(e, Place::local(ret)); } - // Unwatch all watched locals before returning - let watched = self.watched_locals_stack.clone(); - for &local in watched.iter().rev() { - self.builder.unwatch(local); - } + // Unwatch all watched locals in this function (the stack is + // swapped at lambda boundaries, so depth=0 covers exactly the + // current function's watches). + self.emit_unwatch_to_depth(0); self.builder.goto(self.exit_block); // Create a dead successor block for the builder cursor // (subsequent statements in the same block-list are dead code) @@ -3726,6 +3907,10 @@ impl LoweringContext<'_> { AstStmt::Throw { value } => { let val_op = self.lower_to_operand(value); + // Unwatch all watched locals in this function before throwing, + // matching the Return path. Without this, a + // `watch let conn = …` followed by a `throw` leaks the watcher. + self.emit_unwatch_to_depth(0); self.builder.throw(val_op); let dead = self.builder.create_block(); self.builder.set_current_block(dead); @@ -3735,10 +3920,7 @@ impl LoweringContext<'_> { if let Some(ref loop_ctx) = self.loop_context { let target = loop_ctx.break_target; let depth = loop_ctx.watched_locals_depth; - let watched: Vec = self.watched_locals_stack[depth..].to_vec(); - for &local in watched.iter().rev() { - self.builder.unwatch(local); - } + self.emit_unwatch_to_depth(depth); self.builder.goto(target); } let dead = self.builder.create_block(); @@ -3749,10 +3931,7 @@ impl LoweringContext<'_> { if let Some(ref loop_ctx) = self.loop_context { let target = loop_ctx.continue_target; let depth = loop_ctx.watched_locals_depth; - let watched: Vec = self.watched_locals_stack[depth..].to_vec(); - for &local in watched.iter().rev() { - self.builder.unwatch(local); - } + self.emit_unwatch_to_depth(depth); self.builder.goto(target); } let dead = self.builder.create_block(); @@ -3845,11 +4024,7 @@ impl LoweringContext<'_> { AstExpr::Path(segments) if segments.len() == 1 => { if let Some(&local) = self.locals.get(&segments[0]) { Place::Local(local) - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&segments[0])) - .copied() + } else if let Some(cap_idx) = self.capture_index_for_name_at(expr_id, &segments[0]) { // Assignment to a captured variable in a closure body. Place::Capture(cap_idx) @@ -3863,35 +4038,32 @@ impl LoweringContext<'_> { AstExpr::Path(segments) if segments.len() >= 2 => { // Multi-segment path lvalue: `a.b` or `a.b.c`. // Chain field projections from the root local or capture. - let (mut current_place, mut current_ty) = - if let Some(&l) = self.locals.get(&segments[0]) { - let ty = self - .path_root_ty(expr_id) - .unwrap_or_else(|| self.builder.local_ty(l)); - (Place::Local(l), ty) - } else if let Some(cap_idx) = self - .capture_indices - .as_ref() - .and_then(|m| m.get(&segments[0])) - .copied() - { - let ty = self - .path_root_ty(expr_id) - .unwrap_or_else(|| Ty::BuiltinUnknown { - attr: TyAttr::default(), - }); - (Place::Capture(cap_idx), ty) - } else { - let tmp = self.builder.temp(Ty::Null { + let (mut current_place, mut current_ty) = if let Some(&l) = + self.locals.get(&segments[0]) + { + let ty = self + .path_root_ty(expr_id) + .unwrap_or_else(|| self.builder.local_ty(l)); + (Place::Local(l), ty) + } else if let Some(cap_idx) = self.capture_index_for_name_at(expr_id, &segments[0]) + { + let ty = self + .path_root_ty(expr_id) + .unwrap_or_else(|| Ty::BuiltinUnknown { attr: TyAttr::default(), }); - ( - Place::Local(tmp), - Ty::Null { - attr: TyAttr::default(), - }, - ) - }; + (Place::Capture(cap_idx), ty) + } else { + let tmp = self.builder.temp(Ty::Null { + attr: TyAttr::default(), + }); + ( + Place::Local(tmp), + Ty::Null { + attr: TyAttr::default(), + }, + ) + }; for seg in &segments[1..] { if let Ty::Class(ref tn, _) = current_ty.clone() { @@ -4454,11 +4626,20 @@ impl LoweringContext<'_> { self.builder.set_current_block(bb_body); let (pattern, body, _) = arms[arm_idx]; + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); self.bind_pattern(scrutinee, pattern); self.lower_expr(body, dest.clone()); if !self.builder.is_current_terminated() { + // A `watch let` declared inside an arm body must be torn + // down on fallthrough. Without this the watcher leaks past + // the arm. Mirrors `lower_match_chain`. + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(join); } + // Restore both the name→local map AND truncate the watched + // stack back to the arm-entry depth (mirrors `lower_scoped_block`). + self.restore_locals_after_scope(saved_locals, watched_depth); } } @@ -4524,11 +4705,19 @@ impl LoweringContext<'_> { self.builder.set_current_block(bb_wildcard_body); } let (pattern, body, _) = arms[idx]; + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); self.bind_pattern(scrutinee, pattern); self.lower_expr(body, dest); if !self.builder.is_current_terminated() { + // A `watch let` declared inside the wildcard body must be + // torn down on fallthrough; mirrors the int-arm path above. + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(join); } + // Restore name→local map AND truncate the watched stack back to + // the arm-entry depth (mirrors `lower_scoped_block`). + self.restore_locals_after_scope(saved_locals, watched_depth); } else { // No wildcard — decide what the otherwise block does. // Use `is_switch_exhaustive` (which may be inferred for TypeTag) @@ -4612,11 +4801,20 @@ impl LoweringContext<'_> { // Exhaustive last arm: skip the pattern test — it must match. if exhaustive && rest.is_empty() && arm.guard.is_none() { + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); self.bind_pattern(scrutinee, arm.pattern); self.lower_expr(arm.body, dest); if !self.builder.is_current_terminated() { + // A `watch let` declared inside an arm body must be torn + // down on fallthrough. Without this the watcher leaks past + // the arm. + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(join); } + // Restore both the name→local map AND truncate the watched stack + // back to the arm-entry depth (mirrors `lower_scoped_block`). + self.restore_locals_after_scope(saved_locals, watched_depth); return; } @@ -4626,6 +4824,8 @@ impl LoweringContext<'_> { self.lower_pattern_test(scrutinee, arm.pattern, bb_body, bb_next); self.builder.set_current_block(bb_body); + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); self.bind_pattern(scrutinee, arm.pattern); if let Some(guard) = arm.guard { let guard_op = self.lower_to_operand(guard); @@ -4635,8 +4835,11 @@ impl LoweringContext<'_> { } self.lower_expr(arm.body, dest.clone()); if !self.builder.is_current_terminated() { + // See exhaustive arm comment above. + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(join); } + self.restore_locals_after_scope(saved_locals, watched_depth); self.builder.set_current_block(bb_next); self.lower_match_chain(scrutinee, rest, dest, join, exhaustive); @@ -4840,6 +5043,7 @@ impl LoweringContext<'_> { Rvalue::Use(Operand::Copy(Place::Local(scrutinee))), ); self.locals.insert(name, local); + self.record_pattern_binding_local(pat_id, local); } } } @@ -4934,58 +5138,158 @@ impl LoweringContext<'_> { ) { use baml_compiler2_ast::CatchClauseKind; + #[derive(Clone)] + struct ClauseLocals { + binding_name: Option, + binding_local: Option, + binding_copy_local: Option, + stack_trace_name: Option, + stack_trace_payload: Option, + stack_trace_copy_local: Option, + } + + fn install_clause_locals( + ctx: &mut LoweringContext<'_>, + error_local: Local, + clause: &ClauseLocals, + ) { + if let (Some(name), Some(local)) = (&clause.binding_name, clause.binding_local) { + ctx.locals.insert(name.clone(), local); + } + if let Some(binding_copy_local) = clause.binding_copy_local { + ctx.builder.assign( + Place::local(binding_copy_local), + Rvalue::Use(Operand::Copy(Place::Local(error_local))), + ); + } + if let (Some(name), Some(local)) = + (&clause.stack_trace_name, clause.stack_trace_copy_local) + { + ctx.locals.insert(name.clone(), local); + } + if let (Some(payload), Some(copy_local)) = + (clause.stack_trace_payload, clause.stack_trace_copy_local) + && payload != copy_local + { + ctx.builder.assign( + Place::local(copy_local), + Rvalue::Use(Operand::Copy(Place::Local(payload))), + ); + } + } + + let saved_catch_outer_locals = self.locals.clone(); let bb_join = self.builder.create_block(); let bb_handler = self.builder.create_block(); // Use the user-provided binding name (e.g. `e` from `catch (e)`) so it - // shows up in bytecode instead of an anonymous `_N` temp. - let binding_name = clauses - .first() - .and_then(|c| self.body.patterns[c.binding].binding_name().cloned()); + // shows up in bytecode instead of an anonymous `_N` temp. Only do this + // for single-clause catches with a non-captured binding. + let single_clause_binding_name = clauses.first().and_then(|c| { + if clauses.len() == 1 && !self.pattern_binding_is_captured(c.binding) { + self.body.patterns[c.binding].binding_name().cloned() + } else { + None + } + }); let error_local = self.builder.declare_local( - binding_name.clone(), + single_clause_binding_name, Ty::BuiltinUnknown { attr: TyAttr::default(), }, None, false, ); - if let Some(name) = binding_name { - self.locals.insert(name, error_local); - } - // Declare stack trace local if the catch clause has a second binding. - let stack_trace_local = clauses.first().and_then(|c| { - c.stack_trace_binding.map(|st_pat| { - let st_name = self.body.patterns[st_pat].binding_name().cloned(); - let local = self.builder.declare_local( - st_name.clone(), + let stack_trace_local = clauses + .iter() + .any(|c| c.stack_trace_binding.is_some()) + .then(|| { + self.builder.declare_local( + None, Ty::BuiltinUnknown { attr: TyAttr::default(), }, None, false, - ); - if let Some(name) = st_name { - self.locals.insert(name, local); + ) + }); + + let mut clause_locals = Vec::with_capacity(clauses.len()); + for clause in clauses { + let binding_name = self.body.patterns[clause.binding].binding_name().cloned(); + let binding_is_captured = self.pattern_binding_is_captured(clause.binding); + let (binding_local, binding_copy_local) = match binding_name.clone() { + Some(name) if binding_is_captured => { + let local = self.builder.declare_local( + Some(name), + Ty::BuiltinUnknown { + attr: TyAttr::default(), + }, + None, + false, + ); + self.record_pattern_binding_local(clause.binding, local); + (Some(local), Some(local)) } - local - }) - }); + Some(_) => { + self.record_pattern_binding_local(clause.binding, error_local); + (Some(error_local), None) + } + None => (None, None), + }; + + let (stack_trace_name, stack_trace_copy_local) = if let (Some(st_pat), Some(payload)) = + (clause.stack_trace_binding, stack_trace_local) + { + let name = self.body.patterns[st_pat].binding_name().cloned(); + let is_captured = self.pattern_binding_is_captured(st_pat); + match name.clone() { + Some(name) if is_captured => { + let local = self.builder.declare_local( + Some(name.clone()), + Ty::BuiltinUnknown { + attr: TyAttr::default(), + }, + None, + false, + ); + self.record_pattern_binding_local(st_pat, local); + (Some(name), Some(local)) + } + Some(name) => { + self.record_pattern_binding_local(st_pat, payload); + (Some(name), Some(payload)) + } + None => (None, None), + } + } else { + (None, None) + }; + + clause_locals.push(ClauseLocals { + binding_name, + binding_local, + binding_copy_local, + stack_trace_name, + stack_trace_payload: stack_trace_local, + stack_trace_copy_local, + }); + } // Flatten all arms from all clauses (blocks created lazily below). - let mut arms: Vec<(baml_compiler2_ast::CatchArm, bool)> = Vec::new(); - for clause in clauses { + let mut arms: Vec<(baml_compiler2_ast::CatchArm, bool, usize)> = Vec::new(); + for (clause_idx, clause) in clauses.iter().enumerate() { for &arm_id in &clause.arms { let arm = self.body.catch_arms[arm_id].clone(); let pat = &self.body.patterns[arm.pattern]; let is_wildcard = matches!(pat.kind, AstPatternKind::Wildcard) && pat.narrow.is_none(); - arms.push((arm, is_wildcard)); + arms.push((arm, is_wildcard, clause_idx)); } } - let has_wildcard = arms.iter().any(|(_, is_wc)| *is_wc); + let has_wildcard = arms.iter().any(|(_, is_wc, _)| *is_wc); let is_catch_all_panics = clauses .iter() .any(|clause| matches!(clause.kind, CatchClauseKind::CatchAllPanics)); @@ -5023,21 +5327,27 @@ impl LoweringContext<'_> { // Switch on Rvalue::TypeTag instead of a sequential is_type chain. let switch_arms: Vec<(AstPatId, AstExprId, Option)> = arms .iter() - .map(|(arm, _)| (arm.pattern, arm.body, None)) + .map(|(arm, _, _)| (arm.pattern, arm.body, None)) .collect(); self.builder.set_current_block(bb_handler); - if self.try_lower_as_switch( - error_local, - &switch_arms, - dest.clone(), - bb_join, - SwitchOtherwise::Catch { + if clauses.len() == 1 { + install_clause_locals(self, error_local, &clause_locals[0]); + } + if clauses.len() == 1 + && self.try_lower_as_switch( error_local, - needs_throw_if_panic, - }, - None, - ) { + &switch_arms, + dest.clone(), + bb_join, + SwitchOtherwise::Catch { + error_local, + needs_throw_if_panic, + }, + None, + ) + { self.builder.set_current_block(bb_join); + self.restore_active_locals(saved_catch_outer_locals); return; } @@ -5046,10 +5356,17 @@ impl LoweringContext<'_> { // doesn't leave orphaned unterminated blocks). let arms_with_blocks: Vec<_> = arms .iter() - .map(|(arm, is_wc)| (arm.clone(), self.builder.create_block(), *is_wc)) + .map(|(arm, is_wc, clause_idx)| { + ( + arm.clone(), + self.builder.create_block(), + *is_wc, + *clause_idx, + ) + }) .collect(); - for &(ref arm, body_block, is_wildcard) in &arms_with_blocks { + for &(ref arm, body_block, is_wildcard, _) in &arms_with_blocks { if is_wildcard && needs_throw_if_panic { let bb_wildcard = self.builder.create_block(); self.builder @@ -5068,16 +5385,27 @@ impl LoweringContext<'_> { } // Lower each arm body. - for &(ref arm, body_block, _) in &arms_with_blocks { + for &(ref arm, body_block, _, clause_idx) in &arms_with_blocks { self.builder.set_current_block(body_block); + let saved_locals = self.locals.clone(); + let watched_depth = self.watched_locals_stack.len(); + let clause = clause_locals[clause_idx].clone(); + install_clause_locals(self, error_local, &clause); self.bind_pattern(error_local, arm.pattern); self.lower_expr(arm.body, dest.clone()); if !self.builder.is_current_terminated() { + // A `watch let` declared inside a catch-arm body must be + // torn down on fallthrough. + self.emit_unwatch_to_depth(watched_depth); self.builder.goto(bb_join); } + // Restore name→local map AND truncate the watched stack back to + // the arm-entry depth (mirrors `lower_scoped_block`). + self.restore_locals_after_scope(saved_locals, watched_depth); } self.builder.set_current_block(bb_join); + self.restore_active_locals(saved_catch_outer_locals); } } diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 17d6fdf98d..6b64d3b426 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -78,6 +78,42 @@ struct CallbackThrowProvenance { callback_concrete_throws: Option, } +struct ScopedLocalsSnapshot { + locals: FxHashMap, + declared_types: FxHashMap, + let_binding_patterns: FxHashMap, + scoped_local_declarations_len: usize, + scoped_local_assignments_len: usize, +} + +struct ScopedLocalDeclaration { + name: Name, + /// The pattern of this declaration. Used by `restore_scoped_locals_inner` + /// to identify "inner" bindings (those declared in the closing scope) so + /// assignments to inner bindings can be filtered out — Slack rule 3 vs + /// rule 2. The pattern (rather than name) is needed to distinguish + /// inner-shadow assignments from outer-binding assignments. + pattern: PatId, + previous_local: Option, + previous_declared_type: Option, + previous_let_binding_pattern: Option, +} + +/// One entry in `scoped_local_assignments`: a per-name assignment recorded +/// during type inference. `pattern` carries the binding identity at the +/// assignment site: +/// - `Some(PatId)` means the assignment targets a let-binding's pattern +/// — used to distinguish inner-shadow assignments (drop on scope exit) +/// from outer-binding assignments (propagate). +/// - `None` means the name has no let-binding pattern in scope (e.g. a +/// function parameter assignment). These are always treated as +/// outer-scope and propagate on scope exit. +#[derive(Clone)] +struct ScopedAssignment { + name: Name, + pattern: Option, +} + struct BuilderThrowsAnalysis<'a, 'db> { builder: &'a TypeInferenceBuilder<'db>, } @@ -193,6 +229,19 @@ pub struct TypeInferenceBuilder<'db> { /// establishment can keep declaration-side binding types in sync with the /// flow-sensitive local type seen by MIR lowering. let_binding_patterns: FxHashMap, + /// Per-declaration restore points for active name-keyed lookup maps. + /// + /// A lexical scope exit must remove declarations introduced inside that + /// scope, but it must restore a shadowed name to the state immediately + /// before the shadowing declaration rather than to scope entry. That keeps + /// earlier outer assignments in the same scope visible after the block. + scoped_local_declarations: Vec, + /// Assignments whose active local type was updated by assignment or + /// container establishment. Tracked by binding identity (`PatId`) so + /// scope-restore can filter inner-shadow assignments (which must NOT + /// propagate — rule 3) from outer-binding assignments (which MUST + /// propagate — rule 2). + scoped_local_assignments: Vec, /// Member resolutions: for field-access expressions that resolved to a /// class field, enum variant, method, or free function — records the /// structural path so MIR can emit the correct `QualifiedName` and LSP @@ -263,6 +312,146 @@ pub struct TypeInferenceBuilder<'db> { } impl<'db> TypeInferenceBuilder<'db> { + fn snapshot_scoped_locals(&self) -> ScopedLocalsSnapshot { + ScopedLocalsSnapshot { + locals: self.locals.clone(), + declared_types: self.declared_types.clone(), + let_binding_patterns: self.let_binding_patterns.clone(), + scoped_local_declarations_len: self.scoped_local_declarations.len(), + scoped_local_assignments_len: self.scoped_local_assignments.len(), + } + } + + fn restore_scoped_locals(&mut self, snapshot: ScopedLocalsSnapshot) { + self.restore_scoped_locals_inner(snapshot); + } + + fn restore_scoped_locals_inner(&mut self, snapshot: ScopedLocalsSnapshot) { + // Pull the new assignments and declarations introduced since the + // snapshot. We filter assignments by binding identity below, so the + // names alone are not enough. + let new_assignments: Vec = self + .scoped_local_assignments + .split_off(snapshot.scoped_local_assignments_len); + let scoped_declarations = self + .scoped_local_declarations + .split_off(snapshot.scoped_local_declarations_len); + + // The PatIds of bindings declared inside the closing scope. An + // assignment whose pattern is in this set targets an inner shadow + // and must NOT propagate to the outer scope (Slack rule 3). + let inner_pat_ids: FxHashSet = scoped_declarations + .iter() + .map(|declaration| declaration.pattern) + .collect(); + + // Filter assignments: keep those that target a binding declared in an + // outer scope (or have no pattern, meaning a parameter assignment — + // always propagated). + let kept_assignments: Vec = new_assignments + .into_iter() + .filter(|assignment| match assignment.pattern { + Some(pat) => !inner_pat_ids.contains(&pat), + None => true, + }) + .collect(); + let assigned_names: FxHashSet = kept_assignments + .iter() + .map(|assignment| assignment.name.clone()) + .collect(); + + // Roll back inner declarations: each declaration's previous_* fields + // capture the state of `locals`/`declared_types`/`let_binding_patterns` + // immediately before the declaration. Walking declarations in reverse + // restores the outer snapshot — except where a kept (outer) assignment + // updated the same name, which we preserve in the locals loop below. + for declaration in scoped_declarations.into_iter().rev() { + Self::restore_map_entry( + &mut self.locals, + declaration.name.clone(), + declaration.previous_local, + ); + Self::restore_map_entry( + &mut self.declared_types, + declaration.name.clone(), + declaration.previous_declared_type, + ); + Self::restore_map_entry( + &mut self.let_binding_patterns, + declaration.name, + declaration.previous_let_binding_pattern, + ); + } + + let local_names = self + .locals + .keys() + .chain(snapshot.locals.keys()) + .cloned() + .collect::>(); + for name in local_names { + if assigned_names.contains(&name) { + continue; + } + Self::restore_map_entry( + &mut self.locals, + name.clone(), + snapshot.locals.get(&name).cloned(), + ); + } + + self.declared_types = snapshot.declared_types; + self.let_binding_patterns = snapshot.let_binding_patterns; + // Re-extend the outer scope's assignment record with the kept + // (outer-targeting) assignments so a further enclosing scope's + // restore can also see them. + self.scoped_local_assignments.extend(kept_assignments); + } + + fn restore_map_entry(map: &mut FxHashMap, name: Name, previous: Option) { + if let Some(previous) = previous { + map.insert(name, previous); + } else { + map.remove(&name); + } + } + + fn declare_scoped_local( + &mut self, + name: Name, + pattern: PatId, + ty: Ty, + declared_ty: Option, + ) { + self.scoped_local_declarations.push(ScopedLocalDeclaration { + previous_local: self.locals.get(&name).cloned(), + previous_declared_type: self.declared_types.get(&name).cloned(), + previous_let_binding_pattern: self.let_binding_patterns.get(&name).copied(), + name: name.clone(), + pattern, + }); + + self.let_binding_patterns.insert(name.clone(), pattern); + self.locals.insert(name.clone(), ty); + if let Some(declared_ty) = declared_ty { + self.declared_types.insert(name, declared_ty); + } else { + self.declared_types.remove(&name); + } + } + + fn assign_local(&mut self, name: Name, ty: Ty) { + // Resolve the binding identity at the assignment site. If the name has + // a let-pattern in `let_binding_patterns`, the assignment targets that + // binding (which may be an outer or inner one). If not, the name maps + // to a parameter — record a None pattern so scope-restore always + // propagates the assignment outward. + let pattern = self.let_binding_patterns.get(&name).copied(); + self.locals.insert(name.clone(), ty); + self.scoped_local_assignments + .push(ScopedAssignment { name, pattern }); + } + pub fn new( context: InferContext<'db>, res_ctx: &'db PackageResolutionContext<'db>, @@ -279,6 +468,8 @@ impl<'db> TypeInferenceBuilder<'db> { expressions: FxHashMap::default(), bindings: FxHashMap::default(), let_binding_patterns: FxHashMap::default(), + scoped_local_declarations: Vec::new(), + scoped_local_assignments: Vec::new(), resolutions: FxHashMap::default(), res_ctx, package_items, @@ -352,6 +543,10 @@ impl<'db> TypeInferenceBuilder<'db> { /// Also records the declared type (parameters always have annotations). /// Uses `entry().or_insert()` so repeated calls (e.g. from narrowing /// save/restore) don't overwrite the original declared type. + /// + /// Function and lambda parameters do not have AST `PatId`s, so they + /// cannot flow through `declare_scoped_local`. Their assignments are + /// tracked separately via `ScopedAssignment { pattern: None }`. pub fn add_local(&mut self, name: Name, ty: Ty) { self.declared_types .entry(name.clone()) @@ -359,6 +554,31 @@ impl<'db> TypeInferenceBuilder<'db> { self.locals.insert(name, ty); } + /// Apply a transient type narrowing for `name` — used inside match arms + /// to refine the scrutinee's type for the arm body. This is NOT a + /// binding declaration: the surrounding `snapshot_scoped_locals` / + /// `restore_scoped_locals` pair owns the rollback. Tracked + /// assignments inside the arm body still propagate per Slack rule 2. + /// + /// Exists so all `self.locals` writes are named. + fn narrow_local(&mut self, name: Name, ty: Ty) { + self.locals.insert(name, ty); + } + + /// Seed a captured-name marker as `Ty::Unknown` to suppress false + /// "unresolved name" diagnostics inside a lambda body. This is NOT a + /// binding; the actual capture's type is resolved by the parent scope. + /// + /// Exists so all `self.locals` writes are named. + fn seed_capture_unknown(&mut self, name: Name) { + self.locals.insert( + name, + Ty::Unknown { + attr: TyAttr::default(), + }, + ); + } + fn sync_let_binding_type(&mut self, name: &Name, ty: Ty) { if let Some(pattern_id) = self.let_binding_patterns.get(name).copied() { self.bindings.insert(pattern_id, ty); @@ -1066,6 +1286,7 @@ impl<'db> TypeInferenceBuilder<'db> { ) } Expr::Block { stmts, tail_expr } => { + let snapshot = self.snapshot_scoped_locals(); let mut diverged_at: Option<(usize, StmtId)> = None; for (i, stmt_id) in stmts.iter().enumerate() { if self.check_stmt_with_early_return_narrowing(*stmt_id, body) { @@ -1073,7 +1294,7 @@ impl<'db> TypeInferenceBuilder<'db> { break; } } - if let Some((div_idx, div_stmt)) = diverged_at { + let ty = if let Some((div_idx, div_stmt)) = diverged_at { let remaining = stmts.len() - div_idx - 1 + usize::from(tail_expr.is_some()); if remaining > 0 { self.context.report_warning_at_stmt( @@ -1093,7 +1314,9 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_or(Ty::Void { attr: TyAttr::default(), }) - } + }; + self.restore_scoped_locals(snapshot); + ty } Expr::MemberAccess { base, member } => { // `MemberAccess` now only comes from `FIELD_ACCESS_EXPR` (complex base @@ -1488,6 +1711,7 @@ impl<'db> TypeInferenceBuilder<'db> { match expr { // Block: check the tail expression against expected type Expr::Block { stmts, tail_expr } => { + let snapshot = self.snapshot_scoped_locals(); let mut diverged_at: Option<(usize, StmtId)> = None; for (i, stmt_id) in stmts.iter().enumerate() { if self.check_stmt_with_early_return_narrowing(*stmt_id, body) { @@ -1534,6 +1758,7 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), } }; + self.restore_scoped_locals(snapshot); self.record_expr_type(expr_id, ty.clone()); ty } @@ -2013,12 +2238,8 @@ impl<'db> TypeInferenceBuilder<'db> { if let Some(ty) = init_ty { self.bindings.insert(*pattern, ty.clone()); if let Some(name) = body.patterns[*pattern].binding_name() { - self.let_binding_patterns.insert(name.clone(), *pattern); - self.locals.insert(name.clone(), ty); // Record declared type only for annotated let-bindings. - if let Some(decl_ty) = ann_ty_for_decl { - self.declared_types.insert(name.clone(), decl_ty); - } + self.declare_scoped_local(name.clone(), *pattern, ty, ann_ty_for_decl); } } diverges @@ -2041,10 +2262,29 @@ impl<'db> TypeInferenceBuilder<'db> { Stmt::While { condition, body: while_body, + after, .. } => { self.infer_expr(*condition, body); + // Snapshot scoped locals before the body and restore after, + // mirroring `Stmt::For`. Without this, a `let x = ...` + // inside the while body (or any narrowing of an outer name) + // leaks past the loop, violating Slack rule 1. + // `restore_scoped_locals` keeps outer-binding mutations from + // the body — Slack rule 2 — by filtering assignments + // through binding identity. + let snapshot = self.snapshot_scoped_locals(); self.infer_expr(*while_body, body); + self.restore_scoped_locals(snapshot); + // Type-check the C-style for `after` step, if present. It + // runs at the same lexical level as the body but in the + // surrounding scope (HIR P1.2.b puts it inside the wrapping + // While scope but outside the body's block scope), so we + // check it AFTER restoring the snapshot — body-declared lets + // are not in scope here. + if let Some(after_stmt) = after { + self.check_stmt(*after_stmt, body); + } false } // Design note: Stmt::For is kept as a first-class construct (not desugared @@ -2074,13 +2314,15 @@ impl<'db> TypeInferenceBuilder<'db> { // 3. Bind the loop variable to the element type let name = body.patterns[*binding].binding_name().cloned(); + let snapshot = self.snapshot_scoped_locals(); self.bindings.insert(*binding, elem_ty.clone()); if let Some(name) = name { - self.locals.insert(name, elem_ty); + self.declare_scoped_local(name, *binding, elem_ty, None); } // 4. Check the body self.infer_expr(*for_body, body); + self.restore_scoped_locals(snapshot); false } Stmt::Assign { target, value } => { @@ -2118,7 +2360,7 @@ impl<'db> TypeInferenceBuilder<'db> { // Update the local to the assigned value's type (invalidates narrowing) if let Expr::Path(segments) = &body.exprs[*target] { if segments.len() == 1 { - self.locals.insert(segments[0].clone(), value_ty); + self.assign_local(segments[0].clone(), value_ty); } } } else { @@ -2311,16 +2553,30 @@ impl<'db> TypeInferenceBuilder<'db> { .report_simple(TirTypeError::UnreachableArm, arm.body); } - let mut saved = Vec::new(); - + // Route arm-pattern bindings through the standard + // snapshot/declare_scoped_local/restore flow. A prior + // ad-hoc `saved: Vec<(Name, Option)>` plumbing wrote directly + // to `self.locals` without registering the binding's pattern, so + // an arm-arm pattern like `Foo(xs)` with the same name as an + // outer `let xs = []` would share the outer's pattern slot — and + // when the arm body widens `xs` (e.g. `xs.push("s")`), it would + // widen the OUTER binding's evolving type. Slack rule 3 requires + // shadowing to keep identities separate; using the binding's + // PatId achieves that via `restore_scoped_locals_inner`'s + // inner_pat_ids filter (P2.1). + let snapshot = self.snapshot_scoped_locals(); + + // Narrow the scrutinee for the duration of this arm. This is a + // type narrowing, not a let-binding — the snapshot's locals map + // captures the pre-narrow type and `restore_scoped_locals` rolls + // it back unless the arm body assigned to it (in which case the + // assignment correctly propagates per Slack rule 2). if let Some(name) = &scrutinee_name { - saved.push((name.clone(), self.locals.get(name).cloned())); - self.locals.insert(name.clone(), tp.narrowed_ty.clone()); + self.narrow_local(name.clone(), tp.narrowed_ty.clone()); } if let Some((bind_name, bind_ty)) = &tp.binding { - saved.push((bind_name.clone(), self.locals.get(bind_name).cloned())); - self.locals.insert(bind_name.clone(), bind_ty.clone()); + self.declare_scoped_local(bind_name.clone(), pattern_id, bind_ty.clone(), None); } if let Some(guard_expr) = arm.guard { @@ -2330,13 +2586,7 @@ impl<'db> TypeInferenceBuilder<'db> { let arm_ty = self.infer_expr(arm.body, body); arm_types.push(arm_ty); - for (name, previous) in saved { - if let Some(prev_ty) = previous { - self.locals.insert(name, prev_ty); - } else { - self.locals.remove(&name); - } - } + self.restore_scoped_locals(snapshot); if arm.guard.is_none() { if tp.covers_all { @@ -2409,6 +2659,15 @@ impl<'db> TypeInferenceBuilder<'db> { .insert(clause.binding, clause_binding_ty.clone()); // Type the optional stack trace binding as baml.errors.StackTrace. + // + // The stack-trace binding's lifetime is the catch-clause body. + // We snapshot scoped locals before introducing it and restore + // after the clause's arms finish, so the binding does not leak + // into the rest of the function. + let st_snapshot = clause + .stack_trace_binding + .is_some() + .then(|| self.snapshot_scoped_locals()); if let Some(st_binding) = clause.stack_trace_binding { let db = self.context.db(); let baml_name = baml_base::Name::new("baml"); @@ -2432,9 +2691,14 @@ impl<'db> TypeInferenceBuilder<'db> { attr: TyAttr::default(), }); self.bindings.insert(st_binding, st_ty.clone()); - // Also insert into locals so name resolution finds it. + // Register the stack-trace name through declare_scoped_local + // so name resolution finds it AND so the binding is unwound + // by the matching restore_scoped_locals at the end of the + // clause. A prior raw `self.locals.insert` had no paired + // snapshot/restore at all and leaked the binding into the + // rest of the function. if let Some(name) = body.patterns[st_binding].binding_name() { - self.locals.insert(name.clone(), st_ty); + self.declare_scoped_local(name.clone(), st_binding, st_ty, None); } } @@ -2489,37 +2753,48 @@ impl<'db> TypeInferenceBuilder<'db> { } }; - let mut saved = Vec::new(); + // Route catch-arm bindings through + // snapshot/declare_scoped_local/restore so the arm pattern's + // PatId is recorded. A prior ad-hoc `saved: Vec<(Name, + // Option)>` would write to `self.locals` without + // registering the pattern, so an arm pattern with the same + // name as an outer binding shared the outer's pattern slot. + let arm_snapshot = self.snapshot_scoped_locals(); + if let Some(name) = &binding_name { - saved.push((name.clone(), self.locals.get(name).cloned())); - self.locals - .insert(name.clone(), catch_binding_ty(clause_binding_ty.clone())); + self.declare_scoped_local( + name.clone(), + clause.binding, + catch_binding_ty(clause_binding_ty.clone()), + None, + ); } if let Some((arm_bind_name, arm_bind_pat_ty)) = tp.binding { - saved.push(( - arm_bind_name.clone(), - self.locals.get(&arm_bind_name).cloned(), - )); - self.locals - .insert(arm_bind_name, catch_binding_ty(arm_bind_pat_ty)); + self.declare_scoped_local( + arm_bind_name, + arm.pattern, + catch_binding_ty(arm_bind_pat_ty), + None, + ); } let arm_ty = self.infer_expr(arm.body, body); result_members.push(arm_ty); - for (name, previous) in saved { - if let Some(prev_ty) = previous { - self.locals.insert(name, prev_ty); - } else { - self.locals.remove(&name); - } - } + self.restore_scoped_locals(arm_snapshot); for handled in &throw_matches.definitely_handled { residual.remove(handled); } } + // Restore the snapshot taken before the clause's stack-trace + // binding was introduced. This unwinds the stack-trace name from + // `locals` so it does not leak past the clause. + if let Some(snapshot) = st_snapshot { + self.restore_scoped_locals(snapshot); + } + if matches!( clause.kind, baml_compiler2_ast::CatchClauseKind::CatchAll @@ -5034,7 +5309,7 @@ impl<'db> TypeInferenceBuilder<'db> { } else { Ty::List(Box::new(widened_arg), container_attr) }; - self.locals.insert(local_name.clone(), new_ty.clone()); + self.assign_local(local_name.clone(), new_ty.clone()); self.sync_let_binding_type(&local_name, new_ty.clone()); new_ty } else if !self.is_subtype(&widened_arg, elem_ty) { @@ -5122,7 +5397,8 @@ impl<'db> TypeInferenceBuilder<'db> { } else { Ty::List(Box::new(widened_val.clone()), container_attr) }; - self.locals.insert(local_name, new_ty); + self.assign_local(local_name.clone(), new_ty.clone()); + self.sync_let_binding_type(&local_name, new_ty); } else if !self.is_subtype(&widened_val, elem_ty) { self.context.report( TirTypeError::TypeMismatch { @@ -5163,7 +5439,8 @@ impl<'db> TypeInferenceBuilder<'db> { container_attr, ) }; - self.locals.insert(local_name, new_ty); + self.assign_local(local_name.clone(), new_ty.clone()); + self.sync_let_binding_type(&local_name, new_ty); } else { if !self.is_subtype(&widened_key, key_ty) { self.context.report( @@ -5784,6 +6061,9 @@ impl<'db> TypeInferenceBuilder<'db> { // Save current state (including expressions to prevent ExprId collisions) let saved_locals = self.locals.clone(); let saved_declared = self.declared_types.clone(); + let saved_let_binding_patterns = std::mem::take(&mut self.let_binding_patterns); + let saved_scoped_local_declarations = std::mem::take(&mut self.scoped_local_declarations); + let saved_scoped_local_assignments = std::mem::take(&mut self.scoped_local_assignments); let saved_return_ty = self.declared_return_ty.clone(); let saved_generic_params = self.generic_params.clone(); let saved_expressions = std::mem::take(&mut self.expressions); @@ -5800,10 +6080,23 @@ impl<'db> TypeInferenceBuilder<'db> { new_generic_params.extend(func_def.generic_params.iter().cloned()); self.generic_params = new_generic_params; - // Seed lambda params (captures remain accessible via parent locals) + // Seed lambda params (captures remain accessible via parent locals). + // + // Directly overwrite `declared_types` and `locals` rather than going + // through `add_local`: that helper uses `entry().or_insert_with()` for + // `declared_types`, which would preserve a stale outer entry when a + // lambda param shadows an annotated outer let. The lambda param's + // declared type must replace any outer declaration so subsequent + // assignments inside the body type-check against the param's type + // (not the shadowed outer's). Also clear any stale + // `let_binding_patterns` entry the parent scope might have had under + // the same name; lambda params shadow outer let-patterns and the + // pattern's binding identity is irrelevant inside the lambda body. for (name_opt, ty) in param_tys { if let Some(name) = name_opt { - self.add_local(name.clone(), ty.clone()); + self.declared_types.insert(name.clone(), ty.clone()); + self.locals.insert(name.clone(), ty.clone()); + self.let_binding_patterns.remove(name); } } @@ -5817,36 +6110,24 @@ impl<'db> TypeInferenceBuilder<'db> { // // Also captures the lambda's `FileScopeId` for use as a position-independent // key in `nested_lambda_types` (avoids TextRange in Salsa-cached output). - let lambda_file_scope_id; - { + let lambda_file_scope_id = { let db = self.context.db(); let file = self.context.scope().file(db); let index = baml_compiler2_ppir::file_semantic_index(db, file); - let lambda_span = func_def.span; - let mut found_fsi = None; - for (i, scope) in index.scopes.iter().enumerate() { - if matches!(scope.kind, baml_compiler2_hir::scope::ScopeKind::Lambda) - && scope.range == lambda_span + // Captures are seeded only if the lambda scope is located (it + // always should be, but be defensive). + let found_fsi = index.lambda_scope_for(func_def.span); + if let Some(fsi) = found_fsi { + for (capture_name, _def_site) in + &index.scope_bindings[fsi.index() as usize].captures { - #[allow(clippy::cast_possible_truncation)] - { - found_fsi = Some(FileScopeId::new(i as u32)); + if !self.locals.contains_key(capture_name) { + self.seed_capture_unknown(capture_name.clone()); } - for (capture_name, _def_site) in &index.scope_bindings[i].captures { - if !self.locals.contains_key(capture_name) { - self.locals.insert( - capture_name.clone(), - Ty::Unknown { - attr: TyAttr::default(), - }, - ); - } - } - break; } } - lambda_file_scope_id = found_fsi; - } + found_fsi + }; // Set return type context for return statement checking inside lambda if let Some(ret) = expected_ret { @@ -5891,6 +6172,9 @@ impl<'db> TypeInferenceBuilder<'db> { self.lambda_effective_throws = saved_lambda_effective_throws; self.locals = saved_locals; self.declared_types = saved_declared; + self.let_binding_patterns = saved_let_binding_patterns; + self.scoped_local_declarations = saved_scoped_local_declarations; + self.scoped_local_assignments = saved_scoped_local_assignments; self.declared_return_ty = saved_return_ty; self.generic_params = saved_generic_params; diff --git a/baml_language/crates/baml_compiler2_tir/src/inference.rs b/baml_language/crates/baml_compiler2_tir/src/inference.rs index 6b79aef9ff..de14742ae9 100644 --- a/baml_language/crates/baml_compiler2_tir/src/inference.rs +++ b/baml_language/crates/baml_compiler2_tir/src/inference.rs @@ -16,16 +16,14 @@ use std::{ }; use baml_base::Name; -use baml_compiler2_ast::{ - AstSourceMap, Expr as AstExpr, ExprBody, ExprId, FunctionDef, PatId, Stmt as AstStmt, -}; +use baml_compiler2_ast::{AstSourceMap, Expr as AstExpr, ExprBody, ExprId, FunctionDef, PatId}; use baml_compiler2_hir::{ body::{FunctionBody, LetBody}, contributions::Definition, loc::{ClassLoc, EnumLoc, FunctionLoc, LetLoc, TypeAliasLoc}, package::{PackageId, PackageItems}, scope::{FileScopeId, ScopeId, ScopeKind}, - semantic_index::DefinitionSite, + semantic_index::{BindingId, DefinitionSite}, }; use rustc_hash::{FxHashMap, FxHashSet}; use text_size::TextRange; @@ -36,6 +34,25 @@ use crate::{ ty::{Ty, TyAttr}, }; +fn inference_owner_scope( + index: &baml_compiler2_hir::semantic_index::FileSemanticIndex<'_>, + mut scope_id: FileScopeId, +) -> FileScopeId { + loop { + let scope = &index.scopes[scope_id.index() as usize]; + if matches!( + scope.kind, + ScopeKind::Function | ScopeKind::Let | ScopeKind::Lambda + ) { + return scope_id; + } + let Some(parent) = scope.parent else { + return scope_id; + }; + scope_id = parent; + } +} + // ── Member Resolution ───────────────────────────────────────────────────── /// Records what a field-access expression resolved to during type inference. @@ -503,7 +520,8 @@ pub fn infer_scope_types<'db>( } ty }; - builder.add_local(param_name.clone(), param_ty); + builder.add_local(param_name.clone(), param_ty.clone()); + builder.param_types.push((param_name.clone(), param_ty)); } // Check root expression against declared return type @@ -550,7 +568,7 @@ pub fn infer_scope_types<'db>( // can resolve references to captures without reporting "unresolved name" // diagnostics. The loop below will override these with proper types. let captures = &index.scope_bindings[file_scope.index() as usize].captures; - for (capture_name, _def_site) in captures { + for (capture_name, _binding_id) in captures { builder.add_local( capture_name.clone(), Ty::Unknown { @@ -572,134 +590,73 @@ pub fn infer_scope_types<'db>( // (not just the enclosing Function/Let) are also resolved correctly. { let captures = &index.scope_bindings[file_scope.index() as usize].captures; + let mut inferred_owner_scopes = Vec::new(); for ancestor_fsi in index.ancestor_scopes(file_scope) { let anc_bindings = &index.scope_bindings[ancestor_fsi.index() as usize]; - let anc_scope = &index.scopes[ancestor_fsi.index() as usize]; - let anc_scope_id = index.scope_ids[ancestor_fsi.index() as usize]; + let inference_fsi = inference_owner_scope(index, ancestor_fsi); + let inference_scope_id = index.scope_ids[inference_fsi.index() as usize]; + let capture_declared_in_ancestor = + |_capture_name: &Name, binding_id: &BindingId| -> bool { + // Under the (scope, site) constraint, a same-name + // distinct binding cannot co-exist: parameter + // indices are unique within their scope and + // DefinitionSite::Statement/PatternBinding carry + // the scope-unique AST id directly. Capture + // resolution is identity-keyed; `capture_name` is + // not load-bearing here. + binding_id.scope == ancestor_fsi + && match binding_id.site { + DefinitionSite::Parameter(idx) => { + anc_bindings.params.iter().any(|(_, i)| *i == idx) + } + DefinitionSite::Statement(_) + | DefinitionSite::PatternBinding(_) => anc_bindings + .bindings + .iter() + .any(|binding| binding.site == binding_id.site), + } + }; // Only call infer_scope_types if this ancestor has any of // the captures we still need (avoids unnecessary Salsa calls). // For efficiency, check if any capture is declared in this scope. - let has_relevant_capture = - captures.iter().any(|(name, def_site)| match def_site { - DefinitionSite::Parameter(idx) => anc_bindings - .params - .iter() - .any(|(n, i)| n == name && i == idx), - DefinitionSite::Statement(_) | DefinitionSite::PatternBinding(_) => { - anc_bindings - .bindings - .iter() - .any(|(n, def, _)| n == name && def == def_site) - } - }); + let has_relevant_capture = captures + .iter() + .any(|(name, binding_id)| capture_declared_in_ancestor(name, binding_id)); if !has_relevant_capture { continue; } - let anc_inference = infer_scope_types(db, anc_scope_id); - for (capture_name, def_site) in captures { - // Check if this ancestor declares this capture. - let is_declared_here = match def_site { - DefinitionSite::Parameter(idx) => anc_bindings - .params - .iter() - .any(|(n, i)| n == capture_name && i == idx), - DefinitionSite::Statement(_) | DefinitionSite::PatternBinding(_) => { - anc_bindings - .bindings - .iter() - .any(|(n, def, _)| n == capture_name && def == def_site) - } - }; + let anc_inference = if let Some(idx) = inferred_owner_scopes + .iter() + .position(|(scope_id, _)| scope_id == &inference_scope_id) + { + inferred_owner_scopes[idx].1 + } else { + let inference = infer_scope_types(db, inference_scope_id); + inferred_owner_scopes.push((inference_scope_id, inference)); + inference + }; + for (capture_name, binding_id) in captures { + let def_site = binding_id.site; + let is_declared_here = + capture_declared_in_ancestor(capture_name, binding_id); if !is_declared_here { continue; } let actual_ty = match def_site { DefinitionSite::Parameter(idx) => { - anc_inference.param_type(*idx).cloned() - } - DefinitionSite::PatternBinding(pat_id) => { - anc_inference.binding_type(*pat_id).cloned() + anc_inference.param_type(idx).cloned() } - DefinitionSite::Statement(stmt_id) => { - // Get the ancestor's body to look up the Pat for this stmt. - // We must use the SAME body that the stmt_id was allocated in. - let body_opt: Option<&baml_compiler2_ast::ExprBody> = - match &anc_scope.kind { - ScopeKind::Function => { - // Find the function body in item_tree. - item_tree - .functions - .values() - .find(|fd| { - fd.span == anc_scope.range - && anc_scope.name.as_ref() == Some(&fd.name) - }) - .and_then(|fd| { - if let Some( - baml_compiler2_ast::FunctionBodyDef::Expr( - ref b, - _, - ), - ) = fd.body - { - // SAFETY: body is stored in item_tree which - // lives for the duration of this query. - Some(b) - } else { - None - } - }) - } - ScopeKind::Let => None, // handled below - _ => None, // Lambda bodies not accessible here - }; - if let Some(body) = body_opt { - let raw: u32 = (*stmt_id).into_raw().into_u32(); - if (raw as usize) < body.stmts.len() { - if let AstStmt::Let { pattern, .. } = &body.stmts[*stmt_id] - { - anc_inference.binding_type(*pattern).cloned() - } else { - None - } - } else { - None - } - } else if matches!(anc_scope.kind, ScopeKind::Let) { - // Let scope: look up the let body. - item_tree - .lets - .iter() - .find(|(_, ld)| { - ld.span == anc_scope.range - && anc_scope.name.as_ref() == Some(&ld.name) - }) - .and_then(|(local_id, _)| { - let let_loc = LetLoc::new(db, file, *local_id); - let body = - baml_compiler2_hir::body::let_body(db, let_loc); - if let LetBody::Expr(let_body) = body.as_ref() { - let raw: u32 = (*stmt_id).into_raw().into_u32(); - if (raw as usize) < let_body.stmts.len() { - if let AstStmt::Let { pattern, .. } = - &let_body.stmts[*stmt_id] - { - anc_inference - .binding_type(*pattern) - .cloned() - } else { - None - } - } else { - None - } - } else { - None - } - }) - } else { - None - } + DefinitionSite::Statement(_) | DefinitionSite::PatternBinding(_) => { + // `binding.site == def_site` uniquely + // identifies the binding under shadowing — + // a name tiebreaker would be redundant. + anc_bindings + .bindings + .iter() + .find(|binding| binding.site == def_site) + .and_then(|binding| { + anc_inference.binding_type(binding.pattern).cloned() + }) } }; if let Some(ty) = actual_ty { diff --git a/baml_language/crates/baml_compiler2_tir/src/resolve.rs b/baml_language/crates/baml_compiler2_tir/src/resolve.rs index 86487bd17b..7df1bb6c43 100644 --- a/baml_language/crates/baml_compiler2_tir/src/resolve.rs +++ b/baml_language/crates/baml_compiler2_tir/src/resolve.rs @@ -77,15 +77,12 @@ pub fn resolve_name_at_in_scope<'db>( let bindings = &index.scope_bindings[ancestor_id.index() as usize]; // Check let-bindings in this scope (reverse order for shadowing) - for (binding_name, def_site, binding_range) in bindings.bindings.iter().rev() { - if binding_name == name { - // Only visible if the binding precedes the use site - if binding_range.start() <= at_offset { - return ResolvedName::Local { - name: name.clone(), - definition_site: Some(*def_site), - }; - } + for binding in bindings.bindings.iter().rev() { + if &binding.name == name && index.binding_visible_at(binding, at_offset) { + return ResolvedName::Local { + name: name.clone(), + definition_site: Some(binding.site), + }; } } diff --git a/baml_language/crates/baml_lsp2_actions/src/completions.rs b/baml_language/crates/baml_lsp2_actions/src/completions.rs index 336280fe69..2ddeef7bf5 100644 --- a/baml_language/crates/baml_lsp2_actions/src/completions.rs +++ b/baml_language/crates/baml_lsp2_actions/src/completions.rs @@ -30,6 +30,8 @@ //! - `resolve_class_fields(class_loc)` — fields for field-access completions. //! - `file_item_tree(file)[enum_loc.id]` — variants for field-access on enums. +use std::collections::HashSet; + use baml_base::{Name, SourceFile, attr::TyAttr}; use baml_compiler_syntax::{SyntaxKind, SyntaxNode}; use baml_compiler2_hir::{ @@ -1036,10 +1038,12 @@ fn extract_pat_from_stmt( stmt_id: baml_compiler2_ast::StmtId, ) -> Option { let stmt = &expr_body.stmts[stmt_id]; - if let baml_compiler2_ast::Stmt::Let { pattern, .. } = stmt { - Some(*pattern) - } else { - None + match stmt { + baml_compiler2_ast::Stmt::Let { pattern, .. } + | baml_compiler2_ast::Stmt::For { + binding: pattern, .. + } => Some(*pattern), + _ => None, } } @@ -1093,20 +1097,20 @@ fn completions_for_value_position( let index = baml_compiler2_hir::file_semantic_index(db, file); let scope_id = index.scope_at_offset(offset, None); + let mut emitted_locals: HashSet = HashSet::new(); let mut sort_prefix = 0usize; for ancestor_id in index.ancestor_scopes(scope_id) { let bindings: &ScopeBindings = &index.scope_bindings[ancestor_id.index() as usize]; // Let bindings (reverse source order so most-recent is first). - for (name, _site, binding_range) in bindings.bindings.iter().rev() { + for binding in bindings.bindings.iter().rev() { // Only show bindings that are visible at the cursor position. - if binding_range.start() <= offset { + if index.binding_visible_at(binding, offset) + && emitted_locals.insert(binding.name.clone()) + { items.push( - Completion::new(name.as_str(), CompletionKind::Variable).with_sort(format!( - "{:03}_{}", - sort_prefix, - name.as_str() - )), + Completion::new(binding.name.as_str(), CompletionKind::Variable) + .with_sort(format!("{:03}_{}", sort_prefix, binding.name.as_str())), ); sort_prefix += 1; } @@ -1114,6 +1118,9 @@ fn completions_for_value_position( // Parameters. for (name, _idx) in &bindings.params { + if !emitted_locals.insert(name.clone()) { + continue; + } items.push( Completion::new(name.as_str(), CompletionKind::Variable) .with_detail("parameter") diff --git a/baml_language/crates/baml_lsp2_actions/src/completions_tests.rs b/baml_language/crates/baml_lsp2_actions/src/completions_tests.rs index fb0b1f3413..4e05364432 100644 --- a/baml_language/crates/baml_lsp2_actions/src/completions_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/completions_tests.rs @@ -454,4 +454,45 @@ function Test() -> string { "Detail should contain '-> bool', got: {detail_str}" ); } + + #[test] + fn test_value_completion_hides_shadowed_same_scope_local() { + let test = CursorTest::new( + r#" +function Test() -> int { + let x = 1 + let x = 2 + x<[CURSOR] +} +"#, + ); + + let completions = completions_at(&test.db, test.cursor.file, test.cursor.offset); + let x_count = completions.iter().filter(|c| c.label == "x").count(); + + assert_eq!( + x_count, 1, + "Should only complete the innermost visible 'x', got: {completions:?}" + ); + } + + #[test] + fn test_value_completion_hides_shadowed_parameter() { + let test = CursorTest::new( + r#" +function Test(x: int) -> int { + let x = 2 + x<[CURSOR] +} +"#, + ); + + let completions = completions_at(&test.db, test.cursor.file, test.cursor.offset); + let x_count = completions.iter().filter(|c| c.label == "x").count(); + + assert_eq!( + x_count, 1, + "Should only complete the local that shadows parameter 'x', got: {completions:?}" + ); + } } diff --git a/baml_language/crates/baml_lsp2_actions/src/describe.rs b/baml_language/crates/baml_lsp2_actions/src/describe.rs index 480cb3f213..053128b87b 100644 --- a/baml_language/crates/baml_lsp2_actions/src/describe.rs +++ b/baml_language/crates/baml_lsp2_actions/src/describe.rs @@ -10,7 +10,10 @@ use baml_base::SourceFile; use baml_compiler_syntax::SyntaxKind; -use baml_compiler2_hir::contributions::DefinitionKind; +use baml_compiler2_hir::{ + contributions::DefinitionKind, + scope::{FileScopeId, ScopeKind}, +}; use serde::Serialize; use text_size::TextRange; @@ -353,15 +356,22 @@ fn describe_locals(db: &dyn Db, files: &[SourceFile], name: &str) -> Vec Vec { - let body = baml_compiler2_hir::body::function_body(db, func_loc); - if let baml_compiler2_hir::body::FunctionBody::Expr(expr_body) = - body.as_ref() - { - if let baml_compiler2_ast::Stmt::Let { pattern, .. } = - &expr_body.stmts[*stmt_id] - { - inference - .binding_type(*pattern) - .map(crate::utils::display_ty) - .unwrap_or_else(|| "unknown".to_string()) - } else { - "unknown".to_string() - } - } else { - "unknown".to_string() - } + pattern_from_owner_body(db, func_loc, index, owner_scope, stmt_id) + .and_then(|pattern| inference.binding_type(pattern)) + .map(crate::utils::display_ty) + .unwrap_or_else(|| "unknown".to_string()) } baml_compiler2_hir::semantic_index::DefinitionSite::PatternBinding( pat_id, ) => inference - .binding_type(*pat_id) + .binding_type(pat_id) .map(crate::utils::display_ty) .unwrap_or_else(|| "unknown".to_string()), baml_compiler2_hir::semantic_index::DefinitionSite::Parameter(_) => { @@ -408,7 +405,7 @@ fn describe_locals(db: &dyn Db, files: &[SourceFile], name: &str) -> Vec Vec Vec, + mut scope_id: FileScopeId, +) -> FileScopeId { + loop { + let scope = &index.scopes[scope_id.index() as usize]; + if matches!( + scope.kind, + ScopeKind::Function | ScopeKind::Let | ScopeKind::Lambda + ) { + return scope_id; + } + + let Some(parent) = scope.parent else { + return scope_id; + }; + scope_id = parent; + } +} + +/// Extract the binding pattern for a statement from the body that owns it. +/// +/// `StmtId` is arena-local to an `ExprBody`. For ordinary function/block scopes, +/// that owner is the enclosing function body. For lambda scopes, including block +/// descendants inside lambdas, the owner is the matched lambda body. +fn pattern_from_owner_body( + db: &dyn Db, + func_loc: baml_compiler2_hir::loc::FunctionLoc<'_>, + index: &baml_compiler2_hir::semantic_index::FileSemanticIndex<'_>, + owner_scope: FileScopeId, + stmt_id: baml_compiler2_ast::StmtId, +) -> Option { + let body = baml_compiler2_hir::body::function_body(db, func_loc); + let baml_compiler2_hir::body::FunctionBody::Expr(top_body) = body.as_ref() else { + return None; + }; + + match index.scopes[owner_scope.index() as usize].kind { + ScopeKind::Lambda => { + let source_map = baml_compiler2_hir::body::function_body_source_map(db, func_loc)?; + let mut lambda_ranges = Vec::new(); + + for ancestor_id in index.ancestor_scopes(owner_scope) { + let scope = &index.scopes[ancestor_id.index() as usize]; + match scope.kind { + ScopeKind::Lambda => lambda_ranges.push(scope.range), + ScopeKind::Function => break, + _ => {} + } + } + + lambda_ranges.reverse(); + let owner_body = descend_into_lambdas(top_body, &source_map, &lambda_ranges)?; + extract_pat_from_stmt(owner_body, stmt_id) + } + _ => extract_pat_from_stmt(top_body, stmt_id), + } +} + +/// Extract the binding pattern from a let/for statement in a specific body. +fn extract_pat_from_stmt( + expr_body: &baml_compiler2_ast::ExprBody, + stmt_id: baml_compiler2_ast::StmtId, +) -> Option { + let stmt = expr_body + .stmts + .iter() + .find_map(|(id, stmt)| (id == stmt_id).then_some(stmt))?; + + match stmt { + baml_compiler2_ast::Stmt::Let { pattern, .. } + | baml_compiler2_ast::Stmt::For { + binding: pattern, .. + } => Some(*pattern), + _ => None, + } +} + +/// Descend through nested lambda bodies using scope ranges as stable anchors. +fn descend_into_lambdas<'a>( + body: &'a baml_compiler2_ast::ExprBody, + source_map: &baml_compiler2_ast::AstSourceMap, + lambda_ranges: &[TextRange], +) -> Option<&'a baml_compiler2_ast::ExprBody> { + if lambda_ranges.is_empty() { + return Some(body); + } + + let target_range = lambda_ranges[0]; + for (expr_id, expr) in body.exprs.iter() { + if let baml_compiler2_ast::Expr::Lambda(func_def) = expr { + let expr_span = source_map.expr_span(expr_id); + if expr_span == target_range { + if let Some(baml_compiler2_ast::FunctionBodyDef::Expr( + ref nested_body, + ref nested_source_map, + )) = func_def.body + { + return descend_into_lambdas( + nested_body, + nested_source_map, + &lambda_ranges[1..], + ); + } + } + } + } + + None +} + /// Build a `DepRef` pointing to a function. fn make_function_dep( db: &dyn crate::Db, diff --git a/baml_language/crates/baml_lsp2_actions/src/describe_tests.rs b/baml_language/crates/baml_lsp2_actions/src/describe_tests.rs index 74caeb68fa..6bee4e65ba 100644 --- a/baml_language/crates/baml_lsp2_actions/src/describe_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/describe_tests.rs @@ -103,3 +103,28 @@ fn describe_is_case_sensitive() { let descs = project.describe("point"); assert!(descs.is_empty()); } + +#[test] +fn describe_lambda_local_binding_uses_lambda_body() { + let mut builder = ProjectTest::builder(); + builder.source( + "lambda.baml", + r#" +function LambdaLocalDescribe() -> string { + let f = () -> string { + let ignored = 1 + let target = "lambda" + target + } + f() +} +"#, + ); + let project = builder.build(); + + let descs = project.describe("target"); + + assert_eq!(descs.len(), 1); + assert_eq!(descs[0].shape, "let target: string"); + assert_eq!(descs[0].resolved_type.as_deref(), Some("string")); +} diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info.rs b/baml_language/crates/baml_lsp2_actions/src/type_info.rs index 7783c7865a..4cc1379a18 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info.rs @@ -176,6 +176,13 @@ pub fn type_at(db: &dyn Db, file: SourceFile, offset: TextSize) -> Option Option Option<(DefinitionSite, TextSize)> { + let index = baml_compiler2_hir::file_semantic_index(db, file); + let scope_id = index.scope_at_offset(offset, None); + + for ancestor_id in index.ancestor_scopes(scope_id) { + let bindings = &index.scope_bindings[ancestor_id.index() as usize]; + for binding in bindings.bindings.iter().rev() { + if &binding.name == name + && (binding.name_range.contains(offset) || binding.name_range.end() == offset) + { + return Some((binding.site, binding.name_range.end())); + } + } + } + + None +} + // ── type_info_for_definition ────────────────────────────────────────────────── /// Build `TypeInfo` for a top-level item definition. @@ -515,8 +545,13 @@ fn local_type_info( .binding_type(pat_id) .map(utils::display_ty) .unwrap_or_else(|| { - // Try child scopes if the binding is in a nested block. - find_binding_ty_in_scopes(db, index, pat_id) + // Try the use-site's ancestor scope chain — restricts the + // lookup to inferences for bodies that share the + // use-site's pattern arena. Iterating *every* scope in + // the file would, under PatId collisions across nested + // ExprBodies (e.g. two lambdas with the same arena + // index), surface the wrong type for hover/inlay hints. + find_binding_ty_in_scopes(db, index, scope_id, pat_id) .unwrap_or_else(|| "unknown".to_string()) }); @@ -538,7 +573,7 @@ fn local_type_info( /// Extract the `PatId` for the binding introduced by `stmt_id`. /// -/// For `Stmt::Let { pattern, .. }` statements, returns the pattern ID. +/// For local declaration statements, returns the pattern ID. /// Returns `None` for other statement kinds. fn body_stmt_to_pat_id( body: &baml_compiler2_hir::body::FunctionBody, @@ -551,23 +586,33 @@ fn body_stmt_to_pat_id( let stmt = &expr_body.stmts[stmt_id]; match stmt { - baml_compiler2_ast::Stmt::Let { pattern, .. } => Some(*pattern), + baml_compiler2_ast::Stmt::Let { pattern, .. } + | baml_compiler2_ast::Stmt::For { + binding: pattern, .. + } => Some(*pattern), _ => None, } } -/// Search all scopes in the file for the binding type of `pat_id`. +/// Search the use-site's ancestor-scope chain for the binding type of +/// `pat_id`. /// -/// Used as a fallback when the let binding is in a nested block scope (not -/// directly in the enclosing function scope). Iterates all scope IDs in the -/// file index. +/// `PatId`s are arena-local to a function/lambda body, so iterating *all* +/// scopes in the file can surface a wrong-arena hit if two bodies happen to +/// allocate the same `PatId` index. Walking ancestors only — Function or +/// Lambda scopes that enclose `from_scope` — restricts the lookup to +/// inferences whose binding maps were populated from the use-site's own +/// arena. Mirrors the structure already used by +/// `completions.rs::find_binding_ty_for_local`. fn find_binding_ty_in_scopes( db: &dyn Db, index: &baml_compiler2_hir::semantic_index::FileSemanticIndex<'_>, + from_scope: baml_compiler2_hir::scope::FileScopeId, pat_id: baml_compiler2_ast::PatId, ) -> Option { - for scope_id in &index.scope_ids { - let inference = baml_compiler2_tir::inference::infer_scope_types(db, *scope_id); + for ancestor_id in index.ancestor_scopes(from_scope) { + let scope_id = index.scope_ids[ancestor_id.index() as usize]; + let inference = baml_compiler2_tir::inference::infer_scope_types(db, scope_id); if let Some(ty) = inference.binding_type(pat_id) { return Some(utils::display_ty(ty)); } diff --git a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs index c7b018c900..efd96b9e64 100644 --- a/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/type_info_tests.rs @@ -63,3 +63,22 @@ fn function_hover_shows_explicit_throws_surface() { "expected explicit throws surface in hover, got:\n{markdown}" ); } + +#[test] +fn local_var_hover_for_for_loop_binding_uses_iterable_item_type() { + let test = CursorTest::new( + r#"function sum() -> int { + let total = 0 + for (let <[CURSOR]x in [1, 2]) { + total += x + } + return total +}"#, + ); + + let markdown = type_at(&test.db, test.cursor.file, test.cursor.offset) + .expect("hover info") + .to_hover_markdown(); + + assert_eq!(markdown, "```baml\nx: int\n```"); +} diff --git a/baml_language/crates/baml_lsp2_actions/src/usages.rs b/baml_language/crates/baml_lsp2_actions/src/usages.rs index a630370cee..c9173a434a 100644 --- a/baml_language/crates/baml_lsp2_actions/src/usages.rs +++ b/baml_language/crates/baml_lsp2_actions/src/usages.rs @@ -21,7 +21,9 @@ use baml_base::{Name, SourceFile}; use baml_compiler_syntax::SyntaxKind; use baml_compiler2_ast::{Expr, ExprBody}; -use baml_compiler2_hir::{body::FunctionBody, loc::FunctionLoc, scope::ScopeKind}; +use baml_compiler2_hir::{ + body::FunctionBody, loc::FunctionLoc, scope::ScopeKind, semantic_index::BindingId, +}; use baml_compiler2_tir::resolve::{ResolvedName, resolve_name_at}; use rowan::NodeOrToken; use text_size::{TextRange, TextSize}; @@ -54,6 +56,10 @@ pub fn usages_at(db: &dyn Db, file: SourceFile, offset: TextSize) -> Vec Vec { - // Local variable — only search in the enclosing function body. - find_local_usages(db, file, offset, &name_text, &resolved) - } + } => Vec::new(), ResolvedName::Local { definition_site: None, .. @@ -160,7 +163,7 @@ fn find_local_usages( file: SourceFile, at_offset: TextSize, name_text: &str, - target_resolved: &ResolvedName<'_>, + target_binding: BindingId, ) -> Vec { let index = baml_compiler2_hir::file_semantic_index(db, file); let item_tree = baml_compiler2_hir::file_item_tree(db, file); @@ -213,7 +216,7 @@ fn find_local_usages( file, expr_body, &name, - target_resolved, + target_binding, &source_map, &mut results, ); @@ -228,10 +231,12 @@ fn collect_local_path_usages( file: SourceFile, expr_body: &ExprBody, name: &Name, - target_resolved: &ResolvedName<'_>, + target_binding: BindingId, source_map: &baml_compiler2_ast::AstSourceMap, results: &mut Vec, ) { + let index = baml_compiler2_hir::file_semantic_index(db, file); + for (expr_id, expr) in expr_body.exprs.iter() { let Expr::Path(segments) = expr else { continue; @@ -248,32 +253,44 @@ fn collect_local_path_usages( continue; } - // Confirm that this usage resolves to the same local. + // Confirm that this usage resolves to the exact same visible binding. let use_offset = range.start(); - let resolved_here = resolve_name_at(db, file, use_offset, name); + let Some(use_scope) = index.expression_scope(expr_id) else { + continue; + }; - if same_local_definition(&resolved_here, target_resolved) { + if index.visible_binding_at(use_scope, use_offset, name) == Some(target_binding) { results.push(Location { file, range }); } } } -/// Returns `true` when two `ResolvedName::Local` values refer to the same -/// definition site. -fn same_local_definition(a: &ResolvedName<'_>, b: &ResolvedName<'_>) -> bool { - match (a, b) { - ( - ResolvedName::Local { - definition_site: Some(site_a), - .. - }, - ResolvedName::Local { - definition_site: Some(site_b), - .. - }, - ) => site_a == site_b, - _ => false, +fn local_binding_id_at( + db: &dyn Db, + file: SourceFile, + offset: TextSize, + name: &Name, +) -> Option { + let index = baml_compiler2_hir::file_semantic_index(db, file); + let scope_id = index.scope_at_offset(offset, None); + + // Declaration tokens are intentionally not visible until after their + // initializer/statement, so identify them by their recorded name range. + for ancestor_id in index.ancestor_scopes(scope_id) { + let bindings = &index.scope_bindings[ancestor_id.index() as usize]; + for binding in bindings.bindings.iter().rev() { + if &binding.name == name + && (binding.name_range.contains(offset) || binding.name_range.start() == offset) + { + return Some(BindingId { + scope: ancestor_id, + site: binding.site, + }); + } + } } + + index.visible_binding_at(scope_id, offset, name) } // ── field definition usages ─────────────────────────────────────────────────── diff --git a/baml_language/crates/baml_lsp2_actions/src/usages_at_tests.rs b/baml_language/crates/baml_lsp2_actions/src/usages_at_tests.rs index 0a05db1a53..0699581c01 100644 --- a/baml_language/crates/baml_lsp2_actions/src/usages_at_tests.rs +++ b/baml_language/crates/baml_lsp2_actions/src/usages_at_tests.rs @@ -222,6 +222,34 @@ function Test() -> string { ); } + #[test] + fn test_find_refs_local_variable_ignores_shadowed_binding() { + let test = CursorTest::new( + r#" +function Test() -> string { + let <[CURSOR]x = "outer" + let y = x + { + let x = "inner" + x + } + x +} +"#, + ); + + let usages = test.find_all_usages(); + assert_eq!( + usages.len(), + 2, + "Should only find usages of the outer 'x', found: {:?}", + usages + .iter() + .map(|l| test.format_location_with_name(l)) + .collect::>() + ); + } + #[test] fn test_find_refs_multi_file() { let mut builder = CursorTest::builder(); diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/expr/extra_dot.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/expr/extra_dot.baml index 41a0ac248f..12b9bc2a15 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/expr/extra_dot.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/expr/extra_dot.baml @@ -76,16 +76,3 @@ function ModernFunction(x: int) -> int { // │ // │ Note: Error code: E0010 // ────╯ -// Error: Duplicate binding `result` in `ModernFunction` -// ╭─[ expr_extra_dot.baml:34:7 ] -// │ -// 21 │ .let result = receipt.calculate(); -// │ ───┬── -// │ ╰──── first defined as binding here -// │ -// 34 │ let result = 100; -// │ ───┬── -// │ ╰──── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/complex_headers_test.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/complex_headers_test.baml index 369df7e026..30bf9ec0b5 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/complex_headers_test.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/complex_headers_test.baml @@ -66,36 +66,6 @@ function ForLoopWithHeaders() -> int { //---- //- diagnostics -// Error: Duplicate binding `hello` in `ComplexHeaderTest` -// ╭─[ headers_complex_headers_test.baml:20:17 ] -// │ -// 11 │ let hello = "Hello"; -// │ ──┬── -// │ ╰──── first defined as binding here -// │ -// 20 │ let hello = "Hello"; -// │ ──┬── -// │ ╰──── duplicate binding definition -// │ -// 23 │ let hello = "Hello"; -// │ ──┬── -// │ ╰──── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ -// Error: Duplicate binding `result` in `ForLoopWithHeaders` -// ╭─[ headers_complex_headers_test.baml:57:13 ] -// │ -// 49 │ let result = 0; -// │ ───┬── -// │ ╰──── first defined as binding here -// │ -// 57 │ let result = result + processed; -// │ ───┬── -// │ ╰──── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ // Error: type mismatch: expected string, got int // ╭─[ headers_complex_headers_test.baml:43:5 ] // │ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/nested_if_statements.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/nested_if_statements.baml index 4faf933aa9..29261c63a2 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/nested_if_statements.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/headers/nested_if_statements.baml @@ -60,40 +60,6 @@ function NestedIfStatements() -> string { //---- //- diagnostics -// Error: Duplicate binding `a` in `NestedIfStatements` -// ╭─[ headers_nested_if_statements.baml:20:17 ] -// │ -// 9 │ let a = if (true) { -// │ ┬ -// │ ╰── first defined as binding here -// │ -// 20 │ let a = if (true) { -// │ ┬ -// │ ╰── duplicate binding definition -// │ -// 35 │ let a = if (true) { -// │ ┬ -// │ ╰── duplicate binding definition -// │ -// 46 │ let a = if (true) { -// │ ┬ -// │ ╰── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ -// Error: Duplicate binding `z` in `NestedIfStatements` -// ╭─[ headers_nested_if_statements.baml:33:13 ] -// │ -// 7 │ let z = if (true) { -// │ ┬ -// │ ╰── first defined as binding here -// │ -// 33 │ let z = if (true) { -// │ ┬ -// │ ╰── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ // Error: type mismatch: expected string, got int // ╭─[ headers_nested_if_statements.baml:58:5 ] // │ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/c_for.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/c_for.baml index b1a965d082..882be1d229 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/c_for.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/c_for.baml @@ -119,6 +119,24 @@ function BreakContinue() -> int { // │ // │ Note: Error code: E0003 // ────╯ +// Error: unresolved name: i +// ╭─[ loops_c_for.baml:21:18 ] +// │ +// 21 │ for (; i <= 10; i += 1) { +// │ ┬ +// │ ╰── unresolved name: i +// │ +// │ Note: Error code: E0003 +// ────╯ +// Error: unresolved name: x +// ╭─[ loops_c_for.baml:42:32 ] +// │ +// 42 │ for (let i = 0; i <= 10; i += x) { +// │ ┬ +// │ ╰── unresolved name: x +// │ +// │ Note: Error code: E0003 +// ────╯ // Warning: unreachable code: 1 statement(s) after diverging statement // ╭─[ loops_c_for.baml:51:12 ] // │ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/header_requires_let_negative.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/header_requires_let_negative.baml index 5ffea8d74f..565e6d66b2 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/header_requires_let_negative.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/loops/header_requires_let_negative.baml @@ -102,6 +102,15 @@ function MissingLetAfter() -> int { // │ Note: Error code: E0010 // ───╯ // Error: unresolved name: i +// ╭─[ loops_header_requires_let_negative.baml:4:14 ] +// │ +// 4 │ for (i = 0; i < 3; i += 1) { +// │ ┬ +// │ ╰── unresolved name: i +// │ +// │ Note: Error code: E0003 +// ───╯ +// Error: unresolved name: i // ╭─[ loops_header_requires_let_negative.baml:4:21 ] // │ // 4 │ for (i = 0; i < 3; i += 1) { @@ -119,3 +128,12 @@ function MissingLetAfter() -> int { // │ // │ Note: Error code: E0003 // ───╯ +// Error: unresolved name: i +// ╭─[ loops_header_requires_let_negative.baml:14:16 ] +// │ +// 14 │ for (; s < 2; i += 1) { +// │ ┬ +// │ ╰── unresolved name: i +// │ +// │ Note: Error code: E0003 +// ────╯ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/misc/return.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/misc/return.baml index d4c029fbbe..d15326fe03 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/misc/return.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/misc/return.baml @@ -46,19 +46,6 @@ function BadValueReturn(x: int) -> string { //---- //- diagnostics -// Error: Duplicate binding `b` in `WithStack` -// ╭─[ misc_return.baml:23:11 ] -// │ -// 15 │ let b = 1; -// │ ┬ -// │ ╰── first defined as binding here -// │ -// 23 │ let b = 3; -// │ ┬ -// │ ╰── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ // Error: type mismatch: expected string, got 1 // ╭─[ misc_return.baml:36:11 ] // │ diff --git a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/parens.baml b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/parens.baml index a962180da3..a32b1ee7ad 100644 --- a/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/parens.baml +++ b/baml_language/crates/baml_lsp2_actions_tests/test_files/syntax/parens.baml @@ -641,23 +641,6 @@ function foo() -> int { // │ // │ Note: Error code: E0010 // ────╯ -// Error: Duplicate binding `a` in `foo` -// ╭─[ parens.baml:59:23 ] -// │ -// 55 │ ╭─▶ for let a in as {} -// ┆ ┆ -// 57 │ ├─▶ for (let a in as {} -// │ │ -// │ ╰──────────────────────────── first defined as binding here -// │ -// 59 │ ╭───▶ for let a in as) {} -// ┆ ┆ -// 61 │ ├───▶ for (let a in as) {} -// │ │ -// │ ╰─────────────────────────────── duplicate binding definition -// │ -// │ Note: Error code: E0012 -// ────╯ // Error: unresolved name: i // ╭─[ parens.baml:39:10 ] // │ @@ -739,6 +722,15 @@ function foo() -> int { // │ // │ Note: Error code: E0006 // ────╯ +// Error: unresolved name: a +// ╭─[ parens.baml:59:12 ] +// │ +// 59 │ for let a in as) {} +// │ ┬ +// │ ╰── unresolved name: a +// │ +// │ Note: Error code: E0003 +// ────╯ // Error: unresolved name: as // ╭─[ parens.baml:59:17 ] // │ diff --git a/baml_language/crates/baml_tests/projects/closure_loop_variable/demo.baml b/baml_language/crates/baml_tests/projects/closure_loop_variable/demo.baml index 47887d4835..9ace196a60 100644 --- a/baml_language/crates/baml_tests/projects/closure_loop_variable/demo.baml +++ b/baml_language/crates/baml_tests/projects/closure_loop_variable/demo.baml @@ -10,4 +10,4 @@ function sum_array(arr: int[]) -> int { cb(); } sum -} \ No newline at end of file +} diff --git a/baml_language/crates/baml_tests/projects/lexical_scoping/lexical_scoping.baml b/baml_language/crates/baml_tests/projects/lexical_scoping/lexical_scoping.baml new file mode 100644 index 0000000000..7ee6f23d1a --- /dev/null +++ b/baml_language/crates/baml_tests/projects/lexical_scoping/lexical_scoping.baml @@ -0,0 +1,102 @@ +function branch_locals(b: bool) -> int { + if (b) { + let a = 1; + a + } else { + let a = 2; + a + } +} + +function same_scope_shadow() -> int { + let x = 1; + let x = 2; + x +} + +function initializer_uses_previous() -> int { + let x = 1; + let x = x + 1; + x +} + +function shadow_param(x: int) -> int { + let x = x + 1; + x +} + +function outer_restored() -> int { + let x = 1; + { + let x = 2; + }; + x +} + +function declared_type_restored() -> string { + let x: string = "outer"; + { + let x: int = 1; + }; + x +} + +function for_loop_restores_outer() -> int { + let x = 1; + for (let x in [2, 3]) { + x; + }; + x +} + +function watch_block_cleanup() -> int { + watch let x = 1; + { + watch let x = 2; + x; + }; + x +} + +function nested_outer_restored() -> int { + let x = 1; + { + let x = 2; + { + let x = 3; + x; + }; + x; + }; + x +} + +function capture_before_after_shadow() -> int { + let x = 1; + let g = () -> int { x }; + let x = 2; + let f = () -> int { x }; + g() * 10 + f() +} + +function nested_lambda_capture() -> int { + let x = 7; + let outer = () -> int { + let inner = () -> int { x }; + inner() + }; + outer() +} + +// A `let` declared inside a `while` body must NOT leak to the enclosing +// scope. The inner `x = 99` shadows the outer `x` for the duration of the +// body only — after the loop, `x` resolves to the outermost binding. +function while_body_restores_outer() -> int { + let x = 1; + let once = true; + while (once) { + let x = 99; + once = false; + }; + x +} diff --git a/baml_language/crates/baml_tests/projects/lexical_scoping_errors/lexical_scoping_errors.baml b/baml_language/crates/baml_tests/projects/lexical_scoping_errors/lexical_scoping_errors.baml new file mode 100644 index 0000000000..5930a98760 --- /dev/null +++ b/baml_language/crates/baml_tests/projects/lexical_scoping_errors/lexical_scoping_errors.baml @@ -0,0 +1,20 @@ +function block_does_not_leak(b: bool) -> int { + if (b) { + let a = 1; + }; + a +} + +function standalone_block_does_not_leak() -> int { + { + let x = 1; + }; + x +} + +function for_binding_does_not_leak() -> int { + for (let x in [1, 2]) { + x; + }; + x +} diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap index 87cd127bad..4cb8ee01da 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____04_5_mir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 585 --- === MIR2 === @@ -1735,39 +1736,40 @@ fn baml.llm.Stream.final(self: baml.llm.Stream) -> void { let _3: bool let _4: () -> bool throws baml.errors.LlmClient let _5: future - let _6: string? - let _7: () -> string? throws baml.errors.Io - let _8: future - let _9: bool + let _6: string // next + let _7: string? + let _8: () -> string? throws baml.errors.Io + let _9: future let _10: bool - let _11: string // next - let _12: null - let _13: (string) -> null throws baml.errors.LlmClient - let _14: string - let _15: future - let _16: bool - let _17: () -> bool throws baml.errors.LlmClient - let _18: future - let _19: null - let _20: () -> null throws never - let _21: future - let _22: null - let _23: string? - let _24: () -> string? throws baml.errors.LlmClient - let _25: future - let _26: bool - let _27: void - let _28: bool - let _29: string // reason - let _30: (string) -> null throws baml.errors.LlmClient - let _31: string - let _32: future - let _33: string // content - let _34: () -> string throws baml.errors.LlmClient - let _35: future - let _36: string - let _37: baml.llm.StreamCache - let _38: future + let _11: bool + let _12: string // next + let _13: null + let _14: (string) -> null throws baml.errors.LlmClient + let _15: string + let _16: future + let _17: bool + let _18: () -> bool throws baml.errors.LlmClient + let _19: future + let _20: null + let _21: () -> null throws never + let _22: future + let _23: null + let _24: string? + let _25: () -> string? throws baml.errors.LlmClient + let _26: future + let _27: bool + let _28: void + let _29: bool + let _30: string // reason + let _31: (string) -> null throws baml.errors.LlmClient + let _32: string + let _33: future + let _34: string // content + let _35: () -> string throws baml.errors.LlmClient + let _36: future + let _37: string + let _38: baml.llm.StreamCache + let _39: future bb0: { goto -> bb1; @@ -1788,108 +1790,109 @@ fn baml.llm.Stream.final(self: baml.llm.Stream) -> void { } bb4: { - _7 = copy _1.2; - _8 = dispatch_future const fn baml.http.SseStream.next(copy _7) -> bb5; + _8 = copy _1.2; + _9 = dispatch_future const fn baml.http.SseStream.next(copy _8) -> bb5; } bb5: { - _6 = await _8 -> [bb6]; + _7 = await _9 -> [bb6]; } bb6: { - _9 = is_type(copy _6, Null { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _9 -> [bb16, bb7]; + _10 = is_type(copy _7, Null { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _10 -> [bb16, bb7]; } bb7: { - _10 = is_type(copy _6, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _10 -> [bb8, bb9]; + _11 = is_type(copy _7, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _11 -> [bb8, bb9]; } bb8: { - _11 = copy _6; + _12 = copy _7; + _6 = copy _12; goto -> bb9; } bb9: { - _13 = copy _1.1; - _14 = copy _11; - _15 = dispatch_future const fn baml.llm.StreamAccumulator.add_events(copy _13, copy _14) -> bb10; + _14 = copy _1.1; + _15 = copy _6; + _16 = dispatch_future const fn baml.llm.StreamAccumulator.add_events(copy _14, copy _15) -> bb10; } bb10: { - _12 = await _15 -> [bb11]; + _13 = await _16 -> [bb11]; } bb11: { - _17 = copy _1.1; - _18 = dispatch_future const fn baml.llm.StreamAccumulator.is_done(copy _17) -> bb12; + _18 = copy _1.1; + _19 = dispatch_future const fn baml.llm.StreamAccumulator.is_done(copy _18) -> bb12; } bb12: { - _16 = await _18 -> [bb13]; + _17 = await _19 -> [bb13]; } bb13: { - branch copy _16 -> [bb14, bb1]; + branch copy _17 -> [bb14, bb1]; } bb14: { - _20 = copy _1.2; - _21 = dispatch_future const fn baml.http.SseStream.close(copy _20) -> bb15; + _21 = copy _1.2; + _22 = dispatch_future const fn baml.http.SseStream.close(copy _21) -> bb15; } bb15: { - _19 = await _21 -> [bb16]; + _20 = await _22 -> [bb16]; } bb16: { - _24 = copy _1.1; - _25 = dispatch_future const fn baml.llm.StreamAccumulator.finish_reason(copy _24) -> bb17; + _25 = copy _1.1; + _26 = dispatch_future const fn baml.llm.StreamAccumulator.finish_reason(copy _25) -> bb17; } bb17: { - _23 = await _25 -> [bb18]; + _24 = await _26 -> [bb18]; } bb18: { - _26 = copy _23 == const null; - branch copy _26 -> [bb27, bb19]; + _27 = copy _24 == const null; + branch copy _27 -> [bb27, bb19]; } bb19: { - _28 = is_type(copy _23, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _28 -> [bb20, bb22]; + _29 = is_type(copy _24, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _29 -> [bb20, bb22]; } bb20: { - _29 = copy _23; - _30 = copy _1.0; - _31 = copy _29; - _32 = dispatch_future const fn baml.llm.PrimitiveClient.validate_finish_reason(copy _30, copy _31) -> bb21; + _30 = copy _24; + _31 = copy _1.0; + _32 = copy _30; + _33 = dispatch_future const fn baml.llm.PrimitiveClient.validate_finish_reason(copy _31, copy _32) -> bb21; } bb21: { - _22 = await _32 -> [bb22]; + _23 = await _33 -> [bb22]; } bb22: { - _34 = copy _1.1; - _35 = dispatch_future const fn baml.llm.StreamAccumulator.content(copy _34) -> bb23; + _35 = copy _1.1; + _36 = dispatch_future const fn baml.llm.StreamAccumulator.content(copy _35) -> bb23; } bb23: { - _33 = await _35 -> [bb24]; + _34 = await _36 -> [bb24]; } bb24: { - _36 = copy _33; - _37 = copy _1.3; - _38 = dispatch_future const fn baml.llm.__sap_parse_final(copy _36, copy _37) -> bb25; + _37 = copy _34; + _38 = copy _1.3; + _39 = dispatch_future const fn baml.llm.__sap_parse_final(copy _37, copy _38) -> bb25; } bb25: { - _0 = await _38 -> [bb26]; + _0 = await _39 -> [bb26]; } bb26: { @@ -1897,8 +1900,8 @@ fn baml.llm.Stream.final(self: baml.llm.Stream) -> void { } bb27: { - _27 = LlmClientError { const "Streaming finished without finish_reason" }; - throw copy _27; + _28 = LlmClientError { const "Streaming finished without finish_reason" }; + throw copy _28; } } @@ -1928,32 +1931,33 @@ fn baml.llm.Stream.next(self: baml.llm.Stream) -> void | baml.stream.StreamFinis // Locals: let _0: void | baml.stream.StreamFinished // _0 // return let _1: baml.llm.Stream // self // param - let _2: string? - let _3: () -> string? throws baml.errors.Io - let _4: future - let _5: bool + let _2: string // next + let _3: string? + let _4: () -> string? throws baml.errors.Io + let _5: future let _6: bool - let _7: string // next - let _8: null - let _9: (string) -> null throws baml.errors.LlmClient - let _10: string - let _11: future - let _12: bool - let _13: () -> bool throws baml.errors.LlmClient - let _14: future - let _15: null - let _16: () -> null throws never - let _17: future - let _18: string // content - let _19: () -> string throws baml.errors.LlmClient - let _20: future - let _21: void | baml.stream.StreamNoYield // parsed - let _22: string - let _23: baml.llm.StreamCache - let _24: future - let _25: bool + let _7: bool + let _8: string // next + let _9: null + let _10: (string) -> null throws baml.errors.LlmClient + let _11: string + let _12: future + let _13: bool + let _14: () -> bool throws baml.errors.LlmClient + let _15: future + let _16: null + let _17: () -> null throws never + let _18: future + let _19: string // content + let _20: () -> string throws baml.errors.LlmClient + let _21: future + let _22: void | baml.stream.StreamNoYield // parsed + let _23: string + let _24: baml.llm.StreamCache + let _25: future let _26: bool - let _27: void // parsed + let _27: bool + let _28: void // parsed bb0: { goto -> bb1; @@ -1968,94 +1972,95 @@ fn baml.llm.Stream.next(self: baml.llm.Stream) -> void | baml.stream.StreamFinis } bb3: { - _3 = copy _1.2; - _4 = dispatch_future const fn baml.http.SseStream.next(copy _3) -> bb4; + _4 = copy _1.2; + _5 = dispatch_future const fn baml.http.SseStream.next(copy _4) -> bb4; } bb4: { - _2 = await _4 -> [bb5]; + _3 = await _5 -> [bb5]; } bb5: { - _5 = is_type(copy _2, Null { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _5 -> [bb23, bb6]; + _6 = is_type(copy _3, Null { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _6 -> [bb23, bb6]; } bb6: { - _6 = is_type(copy _2, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _6 -> [bb7, bb8]; + _7 = is_type(copy _3, String { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _7 -> [bb7, bb8]; } bb7: { - _7 = copy _2; + _8 = copy _3; + _2 = copy _8; goto -> bb8; } bb8: { - _9 = copy _1.1; - _10 = copy _7; - _11 = dispatch_future const fn baml.llm.StreamAccumulator.add_events(copy _9, copy _10) -> bb9; + _10 = copy _1.1; + _11 = copy _2; + _12 = dispatch_future const fn baml.llm.StreamAccumulator.add_events(copy _10, copy _11) -> bb9; } bb9: { - _8 = await _11 -> [bb10]; + _9 = await _12 -> [bb10]; } bb10: { - _13 = copy _1.1; - _14 = dispatch_future const fn baml.llm.StreamAccumulator.is_done(copy _13) -> bb11; + _14 = copy _1.1; + _15 = dispatch_future const fn baml.llm.StreamAccumulator.is_done(copy _14) -> bb11; } bb11: { - _12 = await _14 -> [bb12]; + _13 = await _15 -> [bb12]; } bb12: { - branch copy _12 -> [bb20, bb13]; + branch copy _13 -> [bb20, bb13]; } bb13: { - _19 = copy _1.1; - _20 = dispatch_future const fn baml.llm.StreamAccumulator.content(copy _19) -> bb14; + _20 = copy _1.1; + _21 = dispatch_future const fn baml.llm.StreamAccumulator.content(copy _20) -> bb14; } bb14: { - _18 = await _20 -> [bb15]; + _19 = await _21 -> [bb15]; } bb15: { - _22 = copy _18; - _23 = copy _1.3; - _24 = dispatch_future const fn baml.llm.__sap_parse_partial(copy _22, copy _23) -> bb16; + _23 = copy _19; + _24 = copy _1.3; + _25 = dispatch_future const fn baml.llm.__sap_parse_partial(copy _23, copy _24) -> bb16; } bb16: { - _21 = await _24 -> [bb17]; + _22 = await _25 -> [bb17]; } bb17: { - _25 = is_type(copy _21, Class(TypeName { name: "StreamNoYield", module_path: ["baml", "stream"], display_name: "baml.stream.StreamNoYield" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); - branch copy _25 -> [bb1, bb18]; + _26 = is_type(copy _22, Class(TypeName { name: "StreamNoYield", module_path: ["baml", "stream"], display_name: "baml.stream.StreamNoYield" }, TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] })); + branch copy _26 -> [bb1, bb18]; } bb18: { - _26 = is_type(copy _21, Void { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); - branch copy _26 -> [bb19, bb1]; + _27 = is_type(copy _22, Void { attr: TyAttr { sap_parse_without_null: Unset, sap_pending_never: Unset, sap_in_progress_never: Unset, asserts: [] } }); + branch copy _27 -> [bb19, bb1]; } bb19: { - _27 = copy _21; - _0 = copy _27; + _28 = copy _22; + _0 = copy _28; goto -> bb24; } bb20: { - _16 = copy _1.2; - _17 = dispatch_future const fn baml.http.SseStream.close(copy _16) -> bb21; + _17 = copy _1.2; + _18 = dispatch_future const fn baml.http.SseStream.close(copy _17) -> bb21; } bb21: { - _15 = await _17 -> [bb22]; + _16 = await _18 -> [bb22]; } bb22: { diff --git a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap index 7e06d4e6b2..920b9fa483 100644 --- a/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/__baml_std__/baml_tests____baml_std____06_codegen.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 690 --- function baml.Array.at(self: null, index: int) -> void? { } @@ -1224,17 +1225,17 @@ function baml.llm.Stream.final(self: null) -> void { load_field ._sse dispatch_future baml.http.SseStream.next await - store_var _6 - load_var _6 + store_var _7 + load_var _7 is_type null pop_jump_if_false L1 jump L3 L1: - load_var _6 + load_var _7 is_type string pop_jump_if_false L2 - load_var _6 + load_var _7 store_var next L2: @@ -1260,20 +1261,20 @@ function baml.llm.Stream.final(self: null) -> void { load_field ._acc dispatch_future baml.llm.StreamAccumulator.finish_reason await - store_var _23 - load_var _23 + store_var _24 + load_var _24 load_const null cmp_op == pop_jump_if_false L4 jump L6 L4: - load_var _23 + load_var _24 is_type string pop_jump_if_false L5 load_var self load_field ._client - load_var _23 + load_var _24 dispatch_future baml.llm.PrimitiveClient.validate_finish_reason await pop 1 @@ -1312,17 +1313,17 @@ function baml.llm.Stream.next(self: null) -> void | baml.stream.StreamFinished { load_field ._sse dispatch_future baml.http.SseStream.next await - store_var _2 - load_var _2 + store_var _3 + load_var _3 is_type null pop_jump_if_false L3 jump L8 L3: - load_var _2 + load_var _3 is_type string pop_jump_if_false L4 - load_var _2 + load_var _3 store_var next L4: diff --git a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap index bebdd7cb1e..42a4d312a2 100644 --- a/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/__testing_std__/baml_tests____testing_std____04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 956 --- === MIR2 === @@ -611,7 +610,7 @@ fn .() -> null { let _2: testing.RunReport // result let _3: unknown // e let _4: void - let _5: void + let _5: () -> void throws unknown let _6: "pass" | "fail" | "error" let _7: testing.RunReport[] let _8: testing.RunReport diff --git a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_5_mir.snap index a57f230212..928dc711b7 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__04_5_mir.snap @@ -71,7 +71,7 @@ fn user.CatchThenCatchAll(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param - let _2: unknown // e + let _2: unknown let _3: bool bb0: { diff --git a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__06_codegen.snap index 8d568e115d..c57512c26f 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_all_keyword/baml_tests__catch_all_keyword__06_codegen.snap @@ -47,13 +47,13 @@ function user.CatchThenCatchAll(x: int) -> string { load_var x call user.MayFailIntOrString jump L2 - load_var e + load_var _2 is_type string pop_jump_if_false L0 jump L1 L0: - load_var e + load_var _2 throw_if_panic load_const "caught rest" jump L2 diff --git a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_5_mir.snap index 8e79f21448..adef240b32 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_5_mir.snap @@ -54,7 +54,7 @@ fn user.CatchWithFallback(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param - let _2: unknown // e + let _2: unknown let _3: bool let _4: bool @@ -125,7 +125,7 @@ fn user.ChainedCatchClauses(x: int) -> string { // Locals: let _0: string // _0 // return let _1: int // x // param - let _2: unknown // e + let _2: unknown let _3: bool let _4: bool diff --git a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__06_codegen.snap index b18a87ec21..7e1ba83cfc 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__06_codegen.snap @@ -34,19 +34,19 @@ function user.CatchWithFallback(x: int) -> string { load_var x call user.MayFailIntOrString jump L4 - load_var e + load_var _2 is_type string pop_jump_if_false L0 jump L3 L0: - load_var e + load_var _2 is_type int pop_jump_if_false L1 jump L2 L1: - load_var e + load_var _2 throw L2: @@ -84,19 +84,19 @@ function user.ChainedCatchClauses(x: int) -> string { load_var x call user.MayFailIntOrString jump L4 - load_var e + load_var _2 is_type string pop_jump_if_false L0 jump L3 L0: - load_var e + load_var _2 is_type int pop_jump_if_false L1 jump L2 L1: - load_var e + load_var _2 throw L2: diff --git a/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__06_codegen.snap index 0ec154b5bd..077e013477 100644 --- a/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/catch_throw_regressions/baml_tests__catch_throw_regressions__06_codegen.snap @@ -20,31 +20,58 @@ function user.AlwaysThrowsStatus(n: int) -> Status { function user.CatchAllStatusVariants(n: int) -> string { load_var n call user.ThrowsAllStatusVariants - jump L5 - load_var e - discriminant - jump_table [L4, L3, L2, L1], default L0 + jump L8 + load_var _2 + load_const user.Status.HttpError + alloc_variant user.Status + cmp_op == + pop_jump_if_false L0 + jump L7 L0: - load_var e + load_var _2 + load_const user.Status.IndexError + alloc_variant user.Status + cmp_op == + pop_jump_if_false L1 + jump L6 + + L1: + load_var _2 + load_const user.Status.AuthError + alloc_variant user.Status + cmp_op == + pop_jump_if_false L2 + jump L5 + + L2: + load_var _2 + load_const user.Status.SomeOtherError + alloc_variant user.Status + cmp_op == + pop_jump_if_false L3 + jump L4 + + L3: + load_var _2 throw - L1: Status.SomeOtherError + L4: load_const "other!" - jump L5 + jump L8 - L2: Status.AuthError + L5: load_const "auth!" - jump L5 + jump L8 - L3: Status.IndexError + L6: load_const "index!" - jump L5 + jump L8 - L4: Status.HttpError + L7: load_const "http!" - L5: + L8: return } @@ -52,19 +79,19 @@ function user.CatchAlwaysThrows(n: int) -> string { load_var n call user.WrapperAlwaysThrows jump L4 - load_var e + load_var _2 is_type string pop_jump_if_false L0 jump L3 L0: - load_var e + load_var _2 is_type int pop_jump_if_false L1 jump L2 L1: - load_var e + load_var _2 throw L2: diff --git a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__03_hir.snap b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__03_hir.snap index 07b64a631c..d6ae8ae0ba 100644 --- a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__03_hir.snap +++ b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__03_hir.snap @@ -7,4 +7,4 @@ function user.sum_array(arr: int[]) -> int [expr] { } --- captures --- -lambda () in sum_array: captures [sum, i] +lambda () in ?: captures [sum, i] diff --git a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_5_mir.snap index acc379ee2d..7a28cb2e31 100644 --- a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_5_mir.snap @@ -92,7 +92,7 @@ fn user.sum_array(arr: int[]) -> int { fn .() -> null { // Locals: let _0: null // _0 // return - let _1: void + let _1: int bb0: { _1 = copy capture[1]; diff --git a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_tir.snap b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_tir.snap index 53770ed8b5..ea1247d743 100644 --- a/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_tir.snap +++ b/baml_language/crates/baml_tests/snapshots/closure_loop_variable/baml_tests__closure_loop_variable__04_tir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 7166 --- === TIR2 === function user.sum_array(arr: int[]) -> int throws never { diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap index 90b4e5051f..bf98e6aa31 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__05_diagnostics.snap @@ -429,172 +429,6 @@ assertion_line: 11870 │ Note: Error code: E0003 ─────╯ - [hir] Error: Duplicate binding `i` in `LoopStmts` - ╭─[ loop_stmts.baml:76:14 ] - │ - 71 │ for (let i = 0; i < 10; i += 1) { - │ ┬ - │ ╰── first defined as binding here - │ - 76 │ for (let i = 0; i < 100 && flag; i += 2) { - │ ┬ - │ ╰── duplicate binding definition - │ - 88 │ ╭─▶ }; - ┆ ┆ - 95 │ ├─▶ }; - │ │ - │ ╰──────────── duplicate binding definition - │ - 105 │ i // 8 var name trailing - │ ┬ - │ ╰── duplicate binding definition - │ - 134 │ i /* 8 var name trailing */ - │ ┬ - │ ╰── duplicate binding definition - │ - 367 │ ╭───▶ ; - ┆ ┆ - 378 │ ├───▶ }; - │ │ - │ ╰────────────── duplicate binding definition - │ - 393 │ ╭─────────▶ }; - ┆ ┆ - 414 │ ├─────────▶ }; - │ │ - │ ╰──────────────────── duplicate binding definition - │ - 429 │ ╭─────▶ }; - ┆ ┆ - 436 │ ├─────▶ }; - │ │ ▲ - │ ╰──────────────── duplicate binding definition - │ │ - │ ╭───────────────╯ - ┆ ┆ - 452 │ ├───────▶ }; - │ │ - │ ╰────────────────── duplicate binding definition - │ - 459 │ for ( let i = 0 ; i < 10 ; i += 1 ) { - │ ┬ - │ ╰── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `item` in `LoopStmts` - ╭─[ loop_stmts.baml:151:6 ] - │ - 78 │ ╭───▶ }; - ┆ ┆ - 83 │ ├───▶ }; - │ │ - │ ╰────────────── first defined as binding here - │ - 151 │ ╭───────▶ ; - ┆ ┆ - 174 │ ├───────▶ } // 18 body close trailing - │ │ - │ ╰─────────────────────────────────────────── duplicate binding definition - 175 │ ╭─────────▶ ; - ┆ ┆ - 196 │ ├─────────▶ } /* 18 body close trailing */ - │ │ - │ ╰──────────────────────────────────────────────── duplicate binding definition - │ - 217 │ ╭─────▶ }; - ┆ ┆ - 222 │ ├─────▶ }; - │ │ - │ ╰──────────────── duplicate binding definition - │ - 461 │ ╭─▶ }; - ┆ ┆ - 465 │ ├─▶ }; - │ │ - │ ╰──────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `extremely_long_loop_counter_variable` in `LoopStmts` - ╭─[ loop_stmts.baml:273:5 ] - │ - 215 │ for (let extremely_long_loop_counter_variable = a * a + result * result; extremely_long_loop_counter_variable < 1000000 && flag && a > 0; extremely_long_loop_counter_variable += a + result) { - │ ──────────────────┬───────────────── - │ ╰─────────────────── first defined as binding here - │ - 273 │ extremely_long_loop_counter_variable // 8 name trailing - │ ──────────────────┬───────────────── - │ ╰─────────────────── duplicate binding definition - │ - 302 │ extremely_long_loop_counter_variable /* 8 name trailing */ - │ ──────────────────┬───────────────── - │ ╰─────────────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `extremely_long_item_variable_name` in `LoopStmts` - ╭─[ loop_stmts.baml:319:6 ] - │ - 222 │ ╭─▶ }; - ┆ ┆ - 227 │ ├─▶ }; - │ │ - │ ╰──────────── first defined as binding here - │ - 319 │ ╭───▶ ; - ┆ ┆ - 342 │ ├───▶ } // 18 close trailing - │ │ - │ ╰────────────────────────────────── duplicate binding definition - 343 │ ╭─────▶ ; - ┆ ┆ - 366 │ ├─────▶ } /* 18 close trailing */ - │ │ - │ ╰─────────────────────────────────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `j` in `LoopStmts` - ╭─[ loop_stmts.baml:372:31 ] - │ - 91 │ ╭─▶ for (let i in [1, 2, 3]) { - ┆ ┆ - 94 │ ├─▶ }; - │ │ - │ ╰──────────────── first defined as binding here - │ - 372 │ ╭───▶ for (let i in [1, 2, 3]) { - ┆ ┆ - 377 │ ├───▶ }; - │ │ - │ ╰────────────────── duplicate binding definition - │ - 399 │ ╭─────▶ } - ┆ ┆ - 413 │ ├─────▶ }; - │ │ - │ ╰──────────────────── duplicate binding definition - │ - 433 │ for (let j = i * i + result * result; j < 1000000 && flag && i > 0 && result > 0 && j != 42; j += i + result + 1) { - │ ┬ - │ ╰── duplicate binding definition - │ - 439 │ ╭───────▶ for (let i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) { - ┆ ┆ - 451 │ ├───────▶ }; - │ │ - │ ╰────────────────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - [type] Warning: unreachable code: 1 statement(s) after diverging statement ╭─[ loop_stmts.baml:402:23 ] │ @@ -617,42 +451,6 @@ assertion_line: 11870 │ Note: Error code: E0001 ─────╯ - [hir] Error: Duplicate binding `very_long_variable_name_for_testing_wrapping` in `LoopStmts` - ╭─[ loop_stmts.baml:515:5 ] - │ - 473 │ let very_long_variable_name_for_testing_wrapping: map = {"key": 42}; - │ ──────────────────────┬───────────────────── - │ ╰─────────────────────── first defined as binding here - │ - 515 │ very_long_variable_name_for_testing_wrapping // 4 name trailing - │ ──────────────────────┬───────────────────── - │ ╰─────────────────────── duplicate binding definition - │ - 532 │ very_long_variable_name_for_testing_wrapping /* 4 name trailing */ - │ ──────────────────────┬───────────────────── - │ ╰─────────────────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `another_extremely_long_variable_name` in `LoopStmts` - ╭─[ loop_stmts.baml:549:5 ] - │ - 476 │ let another_extremely_long_variable_name = a * a + result * result + a * result + a + result + a * a * a + result * result * result + a * result * a + result * a * result; - │ ──────────────────┬───────────────── - │ ╰─────────────────── first defined as binding here - │ - 549 │ another_extremely_long_variable_name // 4 name trailing - │ ──────────────────┬───────────────── - │ ╰─────────────────── duplicate binding definition - │ - 562 │ another_extremely_long_variable_name /* 4 name trailing */ - │ ──────────────────┬───────────────── - │ ╰─────────────────── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - [type] Error: unreachable arm ╭─[ match_exprs.baml:157:18 ] │ @@ -723,42 +521,6 @@ assertion_line: 11870 │ Note: Error code: E0004 ─────╯ - [hir] Error: Duplicate binding `x` in `OtherExprs` - ╭─[ other_exprs.baml:313:13 ] - │ - 308 │ let x = 1; - │ ┬ - │ ╰── first defined as binding here - │ - 313 │ let x = a + b; - │ ┬ - │ ╰── duplicate binding definition - │ - 320 │ let x = { - │ ┬ - │ ╰── duplicate binding definition - │ - 332 │ let x = if (a > b) { - │ ┬ - │ ╰── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - - [hir] Error: Duplicate binding `y` in `OtherExprs` - ╭─[ other_exprs.baml:321:17 ] - │ - 314 │ let y = x * 2; - │ ┬ - │ ╰── first defined as binding here - │ - 321 │ let y = { - │ ┬ - │ ╰── duplicate binding definition - │ - │ Note: Error code: E0012 -─────╯ - [validation] Error: Name `ManyParams` defined 2 times as: function, template_string ╭─[ function_decls.baml:23:10 ] │ diff --git a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__06_codegen.snap index 0f179b159a..55fbe65809 100644 --- a/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__06_codegen.snap +++ b/baml_language/crates/baml_tests/snapshots/format_checks/baml_tests__format_checks__06_codegen.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 12006 --- function $init() -> null { call $init_let_0 @@ -3729,68 +3730,66 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L26 L26: - load_var r + load_var x jump_table [L27, L27, L27, L27, L27, L27], default L27 L27: 5 - load_var r + load_var x is_type int pop_jump_if_false L28 - load_var r + load_var x load_const 0 cmp_op > pop_jump_if_false L28 jump L29 L28: - load_var r + load_var x is_type int pop_jump_if_false L29 - load_var r + load_var x load_const 0 cmp_op < pop_jump_if_false L29 L29: - load_var r + load_var x is_type int pop_jump_if_false L35 - load_var r - store_var n - load_var n + load_var x load_const 0 cmp_op > jump_if_false L30 pop 1 - load_var n + load_var x load_const 100 cmp_op < L30: jump_if_false L31 pop 1 - load_var n + load_var x load_const 42 cmp_op != L31: jump_if_false L32 pop 1 - load_var n + load_var x load_const 13 cmp_op != L32: jump_if_false L33 pop 1 - load_var n + load_var x load_const 7 cmp_op != L33: jump_if_false L34 pop 1 - load_var n + load_var x load_const 99 cmp_op != @@ -3798,37 +3797,37 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L35 L35: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L36 L36: - load_var r + load_var x load_const 1 cmp_op == pop_jump_if_false L37 L37: - load_var r + load_var x jump_table [L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38, L38], default L38 L38: 24 - load_var r + load_var s discriminant load_const MatchStatus.Active cmp_op == pop_jump_if_false L39 L39: - load_var r + load_var s discriminant load_const MatchStatus.Inactive cmp_op == pop_jump_if_false L40 L40: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L41 @@ -3849,38 +3848,36 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, store_var result L44: - load_var r + load_var x is_type int pop_jump_if_false L49 - load_var r - store_var n - load_var n + load_var x load_const 0 cmp_op > jump_if_false L45 pop 1 - load_var n + load_var x load_const 100 cmp_op < L45: jump_if_false L46 pop 1 - load_var n + load_var x load_const 42 cmp_op != L46: jump_if_false L47 pop 1 - load_var n + load_var x load_const 13 cmp_op != L47: jump_if_false L48 pop 1 - load_var n + load_var x load_const 7 cmp_op != @@ -3888,38 +3885,36 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L49 L49: - load_var r + load_var x is_type int pop_jump_if_false L54 - load_var r - store_var n - load_var n + load_var x load_const 0 cmp_op > jump_if_false L50 pop 1 - load_var n + load_var x load_const 100 cmp_op < L50: jump_if_false L51 pop 1 - load_var n + load_var x load_const 42 cmp_op != L51: jump_if_false L52 pop 1 - load_var n + load_var x load_const 13 cmp_op != L52: jump_if_false L53 pop 1 - load_var n + load_var x load_const 7 cmp_op != @@ -3927,7 +3922,7 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L54 L54: - load_var r + load_var x load_const 1 bin_op + load_const 0 @@ -3935,25 +3930,25 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L55 L55: - load_var r - load_var r + load_var x + load_var x bin_op * load_var b load_var b bin_op * bin_op + - load_var r + load_var x load_var b bin_op * bin_op + - load_var r + load_var x bin_op + load_var b bin_op + - load_var r - load_var r + load_var x + load_var x bin_op * - load_var r + load_var x bin_op * bin_op + load_var b @@ -3962,14 +3957,14 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, load_var b bin_op * bin_op + - load_var r + load_var x load_var b bin_op * - load_var r + load_var x bin_op * bin_op + load_var b - load_var r + load_var x bin_op * load_var b bin_op * @@ -3979,7 +3974,7 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L56 L56: - load_var r + load_var x load_const 1 bin_op + load_const 0 @@ -3987,13 +3982,13 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L57 L57: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L58 L58: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L59 @@ -4006,14 +4001,14 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, pop_jump_if_false L60 L60: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L61 jump L65 L61: - load_var r + load_var x load_const 1 cmp_op == pop_jump_if_false L62 @@ -4043,21 +4038,21 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, load_const true cmp_op == pop_jump_if_false L67 - load_var r + load_var s discriminant load_const MatchStatus.Active cmp_op == pop_jump_if_false L66 L66: - load_var r + load_var s discriminant load_const MatchStatus.Inactive cmp_op == pop_jump_if_false L67 L67: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L68 @@ -4066,21 +4061,21 @@ function user.MatchExprs(x: int, s: MatchStatus, r: MatchSuccess | MatchFailure, jump L70 load_var b pop_jump_if_false L70 - load_var r + load_var s discriminant load_const MatchStatus.Active cmp_op == pop_jump_if_false L69 L69: - load_var r + load_var s discriminant load_const MatchStatus.Inactive cmp_op == pop_jump_if_false L70 L70: - load_var r + load_var x load_const 0 cmp_op == pop_jump_if_false L71 diff --git a/baml_language/crates/baml_tests/snapshots/generic_field_chain/baml_tests__generic_field_chain__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/generic_field_chain/baml_tests__generic_field_chain__04_5_mir.snap index d3612b0a80..8b68fa1e53 100644 --- a/baml_language/crates/baml_tests/snapshots/generic_field_chain/baml_tests__generic_field_chain__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/generic_field_chain/baml_tests__generic_field_chain__04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 14409 --- === MIR2 === fn user.test_generic_capture(box: Box) -> string { @@ -25,9 +24,11 @@ fn user.test_generic_capture(box: Box) -> string { fn .() -> null { // Locals: let _0: null // _0 // return + let _1: string bb0: { - _0 = const null; + _1 = const "name"; + _0 = copy capture[0].0[_1]; goto -> bb1; } diff --git a/baml_language/crates/baml_tests/snapshots/lambda_field_access/baml_tests__lambda_field_access__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/lambda_field_access/baml_tests__lambda_field_access__04_5_mir.snap index bb16fc9c52..78c94d5e9a 100644 --- a/baml_language/crates/baml_tests/snapshots/lambda_field_access/baml_tests__lambda_field_access__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/lambda_field_access/baml_tests__lambda_field_access__04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 17172 --- === MIR2 === fn user.test_captured_field_access(obj: Outer) -> string { @@ -27,7 +26,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].1; goto -> bb1; } @@ -60,7 +59,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].0.0; goto -> bb1; } @@ -92,9 +91,11 @@ fn .(x: int) -> null { // Locals: let _0: null // _0 // return let _1: int // x // param + let _2: int bb0: { - _0 = copy _1 + const null; + _2 = copy capture[0].0.0; + _0 = copy _1 + copy _2; goto -> bb1; } @@ -145,7 +146,7 @@ fn ., 1)>() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].1; goto -> bb1; } @@ -186,7 +187,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].0.0; goto -> bb1; } @@ -228,7 +229,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].1; goto -> bb1; } @@ -243,7 +244,7 @@ fn .() -> null { let _0: null // _0 // return bb0: { - _0 = const null; + _0 = copy capture[0].0.0; goto -> bb1; } diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__01_lexer__lexical_scoping.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__01_lexer__lexical_scoping.snap new file mode 100644 index 0000000000..0e4c7a6607 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__01_lexer__lexical_scoping.snap @@ -0,0 +1,417 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Function "function" +Word "branch_locals" +LParen "(" +Word "b" +Colon ":" +Word "bool" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +If "if" +LParen "(" +Word "b" +RParen ")" +LBrace "{" +Let "let" +Word "a" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +Word "a" +RBrace "}" +Else "else" +LBrace "{" +Let "let" +Word "a" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +Word "a" +RBrace "}" +RBrace "}" +Function "function" +Word "same_scope_shadow" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "initializer_uses_previous" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +Let "let" +Word "x" +Equals "=" +Word "x" +Plus "+" +IntegerLiteral "1" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "shadow_param" +LParen "(" +Word "x" +Colon ":" +Word "int" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +Word "x" +Plus "+" +IntegerLiteral "1" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "outer_restored" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "declared_type_restored" +LParen "(" +RParen ")" +Arrow "->" +Word "string" +LBrace "{" +Let "let" +Word "x" +Colon ":" +Word "string" +Equals "=" +Quote "\"" +Word "outer" +Quote "\"" +Semicolon ";" +LBrace "{" +Let "let" +Word "x" +Colon ":" +Word "int" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "for_loop_restores_outer" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +For "for" +LParen "(" +Let "let" +Word "x" +In "in" +LBracket "[" +IntegerLiteral "2" +Comma "," +IntegerLiteral "3" +RBracket "]" +RParen ")" +LBrace "{" +Word "x" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "watch_block_cleanup" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Watch "watch" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +LBrace "{" +Watch "watch" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +Word "x" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "nested_outer_restored" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "3" +Semicolon ";" +Word "x" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "capture_before_after_shadow" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +Let "let" +Word "g" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +RBrace "}" +Semicolon ";" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "2" +Semicolon ";" +Let "let" +Word "f" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +RBrace "}" +Semicolon ";" +Word "g" +LParen "(" +RParen ")" +Star "*" +IntegerLiteral "10" +Plus "+" +Word "f" +LParen "(" +RParen ")" +RBrace "}" +Function "function" +Word "nested_lambda_capture" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "7" +Semicolon ";" +Let "let" +Word "outer" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "inner" +Equals "=" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Word "x" +RBrace "}" +Semicolon ";" +Word "inner" +LParen "(" +RParen ")" +RBrace "}" +Semicolon ";" +Word "outer" +LParen "(" +RParen ")" +RBrace "}" +Slash "/" +Slash "/" +Word "A" +Error "`" +Let "let" +Error "`" +Word "declared" +Word "inside" +Word "a" +Error "`" +While "while" +Error "`" +Word "body" +Word "must" +Word "NOT" +Word "leak" +Word "to" +Word "the" +Word "enclosing" +Slash "/" +Slash "/" +Word "scope" +Dot "." +Word "The" +Word "inner" +Error "`" +Word "x" +Equals "=" +IntegerLiteral "99" +Error "`" +Word "shadows" +Word "the" +Word "outer" +Error "`" +Word "x" +Error "`" +For "for" +Word "the" +Word "duration" +Word "of" +Word "the" +Slash "/" +Slash "/" +Word "body" +Word "only" +Error "—" +Word "after" +Word "the" +Word "loop" +Comma "," +Error "`" +Word "x" +Error "`" +Word "resolves" +Word "to" +Word "the" +Word "outermost" +Word "binding" +Dot "." +Function "function" +Word "while_body_restores_outer" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +Let "let" +Word "once" +Equals "=" +Word "true" +Semicolon ";" +While "while" +LParen "(" +Word "once" +RParen ")" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "99" +Semicolon ";" +Word "once" +Equals "=" +Word "false" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__02_parser__lexical_scoping.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__02_parser__lexical_scoping.snap new file mode 100644 index 0000000000..c0d9304d4c --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__02_parser__lexical_scoping.snap @@ -0,0 +1,504 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "branch_locals" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "b" + COLON ":" + TYPE_EXPR "bool" + WORD "bool" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR "(b)" + L_PAREN "(" + WORD "b" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let a = 1;" + KW_LET "let" + WORD "a" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + WORD "a" + R_BRACE "}" + KW_ELSE "else" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let a = 2;" + KW_LET "let" + WORD "a" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + WORD "a" + R_BRACE "}" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "same_scope_shadow" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + LET_STMT "let x = 2;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "initializer_uses_previous" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + LET_STMT + KW_LET "let" + WORD "x" + EQUALS "=" + BINARY_EXPR "x + 1" + WORD "x" + PLUS "+" + INTEGER_LITERAL "1" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "shadow_param" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "x" + EQUALS "=" + BINARY_EXPR "x + 1" + WORD "x" + PLUS "+" + INTEGER_LITERAL "1" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "outer_restored" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 2;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "declared_type_restored" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "string" + WORD "string" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "x" + COLON ":" + TYPE_EXPR "string" + WORD "string" + EQUALS "=" + STRING_LITERAL "outer" + QUOTE """ + WORD "outer" + QUOTE """ + SEMICOLON ";" + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "x" + COLON ":" + TYPE_EXPR "int" + WORD "int" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "for_loop_restores_outer" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let x" + KW_LET "let" + WORD "x" + KW_IN "in" + ARRAY_LITERAL "[2, 3]" + L_BRACKET "[" + INTEGER_LITERAL "2" + COMMA "," + INTEGER_LITERAL "3" + R_BRACKET "]" + R_PAREN ")" + BLOCK_EXPR "{ + x; + }" + L_BRACE "{" + WORD "x" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "watch_block_cleanup" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + WATCH_LET "watch let x = 1;" + KW_WATCH "watch" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + BLOCK_EXPR + L_BRACE "{" + WATCH_LET "watch let x = 2;" + KW_WATCH "watch" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + WORD "x" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "nested_outer_restored" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 2;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 3;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "3" + SEMICOLON ";" + WORD "x" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "capture_before_after_shadow" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + LET_STMT + KW_LET "let" + WORD "g" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ x }" + L_BRACE "{" + WORD "x" + R_BRACE "}" + SEMICOLON ";" + LET_STMT "let x = 2;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "2" + SEMICOLON ";" + LET_STMT + KW_LET "let" + WORD "f" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ x }" + L_BRACE "{" + WORD "x" + R_BRACE "}" + SEMICOLON ";" + BINARY_EXPR + BINARY_EXPR + CALL_EXPR + WORD "g" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + STAR "*" + INTEGER_LITERAL "10" + PLUS "+" + CALL_EXPR + WORD "f" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "nested_lambda_capture" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 7;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "7" + SEMICOLON ";" + LET_STMT + KW_LET "let" + WORD "outer" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR + L_BRACE "{" + LET_STMT + KW_LET "let" + WORD "inner" + EQUALS "=" + LAMBDA_EXPR + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + BLOCK_EXPR "{ x }" + L_BRACE "{" + WORD "x" + R_BRACE "}" + SEMICOLON ";" + CALL_EXPR + WORD "inner" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + SEMICOLON ";" + CALL_EXPR + WORD "outer" + CALL_ARGS "()" + L_PAREN "(" + R_PAREN ")" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "while_body_restores_outer" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + LET_STMT "let once = true;" + KW_LET "let" + WORD "once" + EQUALS "=" + WORD "true" + SEMICOLON ";" + WHILE_STMT + KW_WHILE "while" + PAREN_EXPR "(once)" + L_PAREN "(" + WORD "once" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 99;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "99" + SEMICOLON ";" + BINARY_EXPR "once = false" + WORD "once" + EQUALS "=" + WORD "false" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__03_hir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__03_hir.snap new file mode 100644 index 0000000000..280360fb90 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__03_hir.snap @@ -0,0 +1,45 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== HIR2 === +function user.branch_locals(b: bool) -> int [expr] { + { } if (b) { let a = 1 } a else { let a = 2 } a +} +function user.capture_before_after_shadow() -> int [expr] { + { let x = 1; let g = () -> int { { } x }; let x = 2; let f = () -> int { { } x } } g() Mul 10 Add f() +} +function user.declared_type_restored() -> string [expr] { + { let x: string = "outer"; { let x: int = 1 } } x +} +function user.for_loop_restores_outer() -> int [expr] { + { let x = 1; for x in [2, 3] { x } } x +} +function user.initializer_uses_previous() -> int [expr] { + { let x = 1; let x = x Add 1 } x +} +function user.nested_lambda_capture() -> int [expr] { + { let x = 7; let outer = () -> int { { let inner = () -> int { { } x } } inner() } } outer() +} +function user.nested_outer_restored() -> int [expr] { + { let x = 1; { let x = 2; { let x = 3; x }; x } } x +} +function user.outer_restored() -> int [expr] { + { let x = 1; { let x = 2 } } x +} +function user.same_scope_shadow() -> int [expr] { + { let x = 1; let x = 2 } x +} +function user.shadow_param(x: int) -> int [expr] { + { let x = x Add 1 } x +} +function user.watch_block_cleanup() -> int [expr] { + { let x = 1; { let x = 2; x } } x +} +function user.while_body_restores_outer() -> int [expr] { + { let x = 1; let once = true; while once { let x = 99; once = false } } x +} + +--- captures --- +lambda () in capture_before_after_shadow: captures [x] +lambda () in capture_before_after_shadow: captures [x] +lambda () in ?: captures [x] diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_5_mir.snap new file mode 100644 index 0000000000..8c162caa85 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_5_mir.snap @@ -0,0 +1,355 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== MIR2 === +fn user.branch_locals(b: bool) -> int { + // Locals: + let _0: int // _0 // return + let _1: bool // b // param + let _2: int // a + let _3: int // a + + bb0: { + branch copy _1 -> [bb2, bb1]; + } + + bb1: { + _3 = const 2_i64; + _0 = copy _3; + goto -> bb3; + } + + bb2: { + _2 = const 1_i64; + _0 = copy _2; + goto -> bb3; + } + + bb3: { + return; + } +} + +fn user.capture_before_after_shadow() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x [captured] + let _2: () -> int throws never // g + let _3: int // x [captured] + let _4: () -> int throws never // f + let _5: int + let _6: int + let _7: () -> int throws never + let _8: int + let _9: () -> int throws never + + bb0: { + _1 = const 1_i64; + _2 = make_closure lambda[0](copy _1); + _3 = const 2_i64; + _4 = make_closure lambda[1](copy _3); + _7 = copy _2; + _6 = call copy _7() -> [bb1]; + } + + bb1: { + _5 = copy _6 * const 10_i64; + _9 = copy _4; + _8 = call copy _9() -> [bb2]; + } + + bb2: { + _0 = copy _5 + copy _8; + goto -> bb3; + } + + bb3: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = copy capture[0]; + goto -> bb1; + } + + bb1: { + return; + } +} + +// lambda[1] +fn .() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = copy capture[0]; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.declared_type_restored() -> string { + // Locals: + let _0: string // _0 // return + let _1: "outer" // x + + bb0: { + _1 = const "outer"; + _0 = copy _1; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.for_loop_restores_outer() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + let _2: int[] + let _3: int // __for_idx + let _4: int + let _5: bool + let _6: int + let _7: int // x + + bb0: { + _1 = const 1_i64; + _2 = [const 2_i64, const 3_i64]; + _3 = const 0_i64; + goto -> bb1; + } + + bb1: { + _4 = len(_2); + _5 = copy _3 < copy _4; + branch copy _5 -> [bb4, bb2]; + } + + bb2: { + _0 = copy _1; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _6 = copy _2[_3]; + fresh_cell(_7); + _7 = copy _6; + goto -> bb5; + } + + bb5: { + _3 = copy _3 + const 1_i64; + goto -> bb1; + } +} + +fn user.initializer_uses_previous() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + let _2: int // x + let _3: int + + bb0: { + _1 = const 1_i64; + _3 = copy _1; + _2 = copy _3 + const 1_i64; + _0 = copy _2; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.nested_lambda_capture() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x [captured] + let _2: () -> int throws never // outer + let _3: () -> int throws never + + bb0: { + _1 = const 7_i64; + _2 = make_closure lambda[0](copy _1); + _3 = copy _2; + _0 = call copy _3() -> [bb1]; + } + + bb1: { + return; + } +} + +// lambda[0] +fn .() -> null { + // Locals: + let _0: null // _0 // return + let _1: () -> int throws never // inner + let _2: () -> int throws never + + bb0: { + _1 = make_closure lambda[0](copy capture[0]); + _2 = copy _1; + _0 = call copy _2() -> [bb1]; + } + + bb1: { + return; + } +} + +// lambda[0] +fn ., 1)>() -> null { + // Locals: + let _0: null // _0 // return + + bb0: { + _0 = copy capture[0]; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.nested_outer_restored() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + let _2: int // x + let _3: int // x + + bb0: { + _1 = const 1_i64; + _2 = const 2_i64; + _3 = const 3_i64; + _0 = copy _1; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.outer_restored() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + + bb0: { + _1 = const 1_i64; + _0 = copy _1; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.same_scope_shadow() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + + bb0: { + _1 = const 2_i64; + _0 = copy _1; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.shadow_param(x: int) -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x // param + let _2: int // x + + bb0: { + _2 = copy _1 + const 1_i64; + _0 = copy _2; + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.watch_block_cleanup() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + let _2: int // x + + bb0: { + _1 = const 1_i64; + _2 = const 2_i64; + unwatch(_2); + _0 = copy _1; + unwatch(_1); + goto -> bb1; + } + + bb1: { + return; + } +} + +fn user.while_body_restores_outer() -> int { + // Locals: + let _0: int // _0 // return + let _1: int // x + let _2: bool // once + let _3: bool + + bb0: { + _1 = const 1_i64; + _2 = const true; + goto -> bb1; + } + + bb1: { + _3 = copy _2; + branch copy _3 -> [bb4, bb2]; + } + + bb2: { + _0 = copy _1; + goto -> bb3; + } + + bb3: { + return; + } + + bb4: { + _2 = const false; + goto -> bb1; + } +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_tir.snap new file mode 100644 index 0000000000..cba6df88d1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__04_tir.snap @@ -0,0 +1,143 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== TIR2 === +function user.branch_locals(b: bool) -> int throws never { + { : int + if (b : bool) : int + { : int + let a = 1 : 1 -> int + a : int + } + else + { : int + let a = 2 : 2 -> int + a : int + } + } +} +function user.same_scope_shadow() -> int throws never { + { : int + let x = 1 : 1 -> int + let x = 2 : 2 -> int + x : int + } +} +function user.initializer_uses_previous() -> int throws never { + { : int + let x = 1 : 1 -> int + let x = x + 1 : int + x : int + } +} +function user.shadow_param(x: int) -> int throws never { + { : int + let x = x + 1 : int + x : int + } +} +function user.outer_restored() -> int throws never { + { : int + let x = 1 : 1 -> int + { : void + let x = 2 : 2 -> int + } + x : int + } +} +function user.declared_type_restored() -> string throws never { + { : "outer" + let x = "outer" : "outer" + { : void + let x = 1 : 1 + } + x : "outer" + } +} +function user.for_loop_restores_outer() -> int throws never { + { : int + let x = 1 : 1 -> int + for x in [2, 3] + { : void + x : int + } + x : int + } +} +function user.watch_block_cleanup() -> int throws never { + { : int + let x = 1 : 1 -> int + { : void + let x = 2 : 2 -> int + x : int + } + x : int + } +} +function user.nested_outer_restored() -> int throws never { + { : int + let x = 1 : 1 -> int + { : void + let x = 2 : 2 -> int + { : void + let x = 3 : 3 -> int + x : int + } + x : int + } + x : int + } +} +function user.capture_before_after_shadow() -> int throws never { + { : int + let x = 1 : 1 -> int + let g = : () -> int throws never + () -> int { ... } : () -> int throws never + { + x + } + let x = 2 : 2 -> int + let f = : () -> int throws never + () -> int { ... } : () -> int throws never + { + x + } + g() * 10 + f() : int + } +} +lambda user.capture_before_after_shadow { +} +lambda user.capture_before_after_shadow { +} +function user.nested_lambda_capture() -> int throws never { + { : int + let x = 7 : 7 -> int + let outer = : () -> int throws never + () -> int { ... } : () -> int throws never + { + let inner = ... + () -> int { ... } + { + x + } + inner() + } + outer() : int + } +} +lambda user.nested_lambda_capture { +} +lambda user.nested_lambda_capture { +} +function user.while_body_restores_outer() -> int throws never { + { : int + let x = 1 : 1 -> int + let once = true : true -> bool + while once + { : void + let x = 99 : 99 -> int + once = false : false + } + x : int + } +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__05_diagnostics.snap new file mode 100644 index 0000000000..40cf564409 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__05_diagnostics.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== COMPILER2 DIAGNOSTICS === +No errors found. diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__06_codegen.snap new file mode 100644 index 0000000000..f80f4edce1 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__06_codegen.snap @@ -0,0 +1,160 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +function user.branch_locals(b: bool) -> int { + load_var b + pop_jump_if_false L0 + jump L1 + + L0: + load_const 2 + jump L2 + + L1: + load_const 1 + + L2: + return +} + +function user.capture_before_after_shadow() -> int { + load_var ?1 + make_cell + store_var ?1 + load_var ?2 + make_cell + store_var ?2 + load_const 1 + store_deref ?1 + load_const 2 + store_deref ?2 + load_var x + make_closure ., 1 + call_indirect + store_var _6 + load_var x + make_closure ., 1 + call_indirect + store_var _8 + load_var _6 + load_const 10 + bin_op * + load_var _8 + bin_op + + return +} + +function user.declared_type_restored() -> string { + load_const "outer" + return +} + +function user.for_loop_restores_outer() -> int { + load_const 2 + load_const 3 + alloc_array 2 + store_var _2 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _2 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const 1 + return + + L2: + load_var _2 + load_var __for_idx + load_array_element + store_var x + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + +function user.initializer_uses_previous() -> int { + load_const 1 + load_const 1 + bin_op + + return +} + +function user.nested_lambda_capture() -> int { + load_var ?1 + make_cell + store_var ?1 + load_const 7 + store_deref ?1 + load_var x + make_closure ., 1 + call_indirect + return +} + +function user.nested_outer_restored() -> int { + load_const 2 + store_var x + load_const 3 + store_var x + load_const 1 + return +} + +function user.outer_restored() -> int { + load_const 1 + return +} + +function user.same_scope_shadow() -> int { + load_const 2 + return +} + +function user.shadow_param(x: int) -> int { + load_var x + load_const 1 + bin_op + + return +} + +function user.watch_block_cleanup() -> int { + load_const 1 + store_var x + load_const "x" + load_const null + watch x + load_const 2 + store_var x + load_const "x" + load_const null + watch x + unwatch x + load_var x + unwatch x + return +} + +function user.while_body_restores_outer() -> int { + load_const true + + L0: + pop_jump_if_false L1 + jump L2 + + L1: + load_const 1 + return + + L2: + load_const false + jump L0 +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__10_formatter__lexical_scoping.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__10_formatter__lexical_scoping.snap new file mode 100644 index 0000000000..d2a68a6233 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping/baml_tests__lexical_scoping__10_formatter__lexical_scoping.snap @@ -0,0 +1,113 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +function branch_locals(b: bool) -> int { + if (b) { + let a = 1; + a + } else { + let a = 2; + a + } +} + +function same_scope_shadow() -> int { + let x = 1; + let x = 2; + x +} + +function initializer_uses_previous() -> int { + let x = 1; + let x = x + 1; + x +} + +function shadow_param(x: int) -> int { + let x = x + 1; + x +} + +function outer_restored() -> int { + let x = 1; + { + let x = 2; + }; + x +} + +function declared_type_restored() -> string { + let x: string = "outer"; + { + let x: int = 1; + }; + x +} + +function for_loop_restores_outer() -> int { + let x = 1; + for (let x in [2, 3]) { + x; + } + ; + x +} + +function watch_block_cleanup() -> int { + watch let x = 1; + { + watch let x = 2; + x; + }; + x +} + +function nested_outer_restored() -> int { + let x = 1; + { + let x = 2; + { + let x = 3; + x; + }; + x; + }; + x +} + +function capture_before_after_shadow() -> int { + let x = 1; + let g = () -> int { + x + }; + let x = 2; + let f = () -> int { + x + }; + g() * 10 + f() +} + +function nested_lambda_capture() -> int { + let x = 7; + let outer = () -> int { + let inner = () -> int { + x + }; + inner() + }; + outer() +} + +// A `let` declared inside a `while` body must NOT leak to the enclosing +// scope. The inner `x = 99` shadows the outer `x` for the duration of the +// body only — after the loop, `x` resolves to the outermost binding. +function while_body_restores_outer() -> int { + let x = 1; + let once = true; + while (once) { + let x = 99; + once = false; + } + ; + x +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__01_lexer__lexical_scoping_errors.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__01_lexer__lexical_scoping_errors.snap new file mode 100644 index 0000000000..8955a72f34 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__01_lexer__lexical_scoping_errors.snap @@ -0,0 +1,69 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +Function "function" +Word "block_does_not_leak" +LParen "(" +Word "b" +Colon ":" +Word "bool" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +If "if" +LParen "(" +Word "b" +RParen ")" +LBrace "{" +Let "let" +Word "a" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "a" +RBrace "}" +Function "function" +Word "standalone_block_does_not_leak" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +LBrace "{" +Let "let" +Word "x" +Equals "=" +IntegerLiteral "1" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" +Function "function" +Word "for_binding_does_not_leak" +LParen "(" +RParen ")" +Arrow "->" +Word "int" +LBrace "{" +For "for" +LParen "(" +Let "let" +Word "x" +In "in" +LBracket "[" +IntegerLiteral "1" +Comma "," +IntegerLiteral "2" +RBracket "]" +RParen ")" +LBrace "{" +Word "x" +Semicolon ";" +RBrace "}" +Semicolon ";" +Word "x" +RBrace "}" diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__02_parser__lexical_scoping_errors.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__02_parser__lexical_scoping_errors.snap new file mode 100644 index 0000000000..35b10c7d5f --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__02_parser__lexical_scoping_errors.snap @@ -0,0 +1,103 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== SYNTAX TREE === +SOURCE_FILE + FUNCTION_DEF + KW_FUNCTION "function" + WORD "block_does_not_leak" + PARAMETER_LIST + L_PAREN "(" + PARAMETER + WORD "b" + COLON ":" + TYPE_EXPR "bool" + WORD "bool" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + IF_EXPR + KW_IF "if" + PAREN_EXPR "(b)" + L_PAREN "(" + WORD "b" + R_PAREN ")" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let a = 1;" + KW_LET "let" + WORD "a" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "a" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "standalone_block_does_not_leak" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + BLOCK_EXPR + L_BRACE "{" + LET_STMT "let x = 1;" + KW_LET "let" + WORD "x" + EQUALS "=" + INTEGER_LITERAL "1" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + FUNCTION_DEF + KW_FUNCTION "function" + WORD "for_binding_does_not_leak" + PARAMETER_LIST "()" + L_PAREN "(" + R_PAREN ")" + ARROW "->" + TYPE_EXPR "int" + WORD "int" + EXPR_FUNCTION_BODY + BLOCK_EXPR + L_BRACE "{" + FOR_EXPR + KW_FOR "for" + L_PAREN "(" + LET_STMT "let x" + KW_LET "let" + WORD "x" + KW_IN "in" + ARRAY_LITERAL "[1, 2]" + L_BRACKET "[" + INTEGER_LITERAL "1" + COMMA "," + INTEGER_LITERAL "2" + R_BRACKET "]" + R_PAREN ")" + BLOCK_EXPR "{ + x; + }" + L_BRACE "{" + WORD "x" + SEMICOLON ";" + R_BRACE "}" + SEMICOLON ";" + WORD "x" + R_BRACE "}" + +=== ERRORS === +None diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__03_hir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__03_hir.snap new file mode 100644 index 0000000000..e40908fe10 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__03_hir.snap @@ -0,0 +1,13 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== HIR2 === +function user.block_does_not_leak(b: bool) -> int [expr] { + { if (b) { let a = 1 } } a +} +function user.for_binding_does_not_leak() -> int [expr] { + { for x in [1, 2] { x } } x +} +function user.standalone_block_does_not_leak() -> int [expr] { + { { let x = 1 } } x +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_5_mir.snap new file mode 100644 index 0000000000..0778f58486 --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_5_mir.snap @@ -0,0 +1,5 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== MIR2 === +Skipped: project has diagnostic errors diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_tir.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_tir.snap new file mode 100644 index 0000000000..b9e52d773c --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__04_tir.snap @@ -0,0 +1,34 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +assertion_line: 18112 +--- +=== TIR2 === +function user.block_does_not_leak(b: bool) -> int throws never { + { : unknown + if (b : bool) : void + { : void + let a = 1 : 1 -> int + } + a : unknown + } + !! 80..81: unresolved name: a +} +function user.standalone_block_does_not_leak() -> int throws never { + { : unknown + { : void + let x = 1 : 1 -> int + } + x : unknown + } + !! 162..163: unresolved name: x +} +function user.for_binding_does_not_leak() -> int throws never { + { : unknown + for x in [1, 2] + { : void + x : int + } + x : unknown + } + !! 253..254: unresolved name: x +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__05_diagnostics.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__05_diagnostics.snap new file mode 100644 index 0000000000..2ebc36679a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__05_diagnostics.snap @@ -0,0 +1,33 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +=== COMPILER2 DIAGNOSTICS === + [type] Error: unresolved name: a + ╭─[ lexical_scoping_errors.baml:5:3 ] + │ + 5 │ a + │ ┬ + │ ╰── unresolved name: a + │ + │ Note: Error code: E0003 +───╯ + + [type] Error: unresolved name: x + ╭─[ lexical_scoping_errors.baml:12:3 ] + │ + 12 │ x + │ ┬ + │ ╰── unresolved name: x + │ + │ Note: Error code: E0003 +────╯ + + [type] Error: unresolved name: x + ╭─[ lexical_scoping_errors.baml:19:3 ] + │ + 19 │ x + │ ┬ + │ ╰── unresolved name: x + │ + │ Note: Error code: E0003 +────╯ diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__06_codegen.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__06_codegen.snap new file mode 100644 index 0000000000..3cdfd35c2a --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__06_codegen.snap @@ -0,0 +1,48 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +function user.block_does_not_leak(b: bool) -> int { + load_var b + pop_jump_if_false L0 + + L0: + load_const null + return +} + +function user.for_binding_does_not_leak() -> int { + load_const 1 + load_const 2 + alloc_array 2 + store_var _1 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _1 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const null + return + + L2: + load_var _1 + load_var __for_idx + load_array_element + store_var x + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 +} + +function user.standalone_block_does_not_leak() -> int { + load_const null + return +} diff --git a/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__10_formatter__lexical_scoping_errors.snap b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__10_formatter__lexical_scoping_errors.snap new file mode 100644 index 0000000000..da7be6a1ca --- /dev/null +++ b/baml_language/crates/baml_tests/snapshots/lexical_scoping_errors/baml_tests__lexical_scoping_errors__10_formatter__lexical_scoping_errors.snap @@ -0,0 +1,24 @@ +--- +source: crates/baml_tests/src/generated_tests.rs +--- +function block_does_not_leak(b: bool) -> int { + if (b) { + let a = 1; + }; + a +} + +function standalone_block_does_not_leak() -> int { + { + let x = 1; + }; + x +} + +function for_binding_does_not_leak() -> int { + for (let x in [1, 2]) { + x; + } + ; + x +} diff --git a/baml_language/crates/baml_tests/snapshots/parser_statements/baml_tests__parser_statements__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/parser_statements/baml_tests__parser_statements__04_5_mir.snap index 821b523dd1..972e80f2eb 100644 --- a/baml_language/crates/baml_tests/snapshots/parser_statements/baml_tests__parser_statements__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/parser_statements/baml_tests__parser_statements__04_5_mir.snap @@ -1,6 +1,5 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 28271 --- === MIR2 === fn user.add_assign() -> int { @@ -380,7 +379,7 @@ fn user.continue_in_for() -> int { let _1: int // i let _2: bool let _3: int - let _4: void + let _4: int bb0: { _1 = const 0_i64; diff --git a/baml_language/crates/baml_tests/snapshots/testset_dynamic/baml_tests__testset_dynamic__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/testset_dynamic/baml_tests__testset_dynamic__04_5_mir.snap index 176c531684..58644e6cd3 100644 --- a/baml_language/crates/baml_tests/snapshots/testset_dynamic/baml_tests__testset_dynamic__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/testset_dynamic/baml_tests__testset_dynamic__04_5_mir.snap @@ -1,6 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 37862 +assertion_line: 38844 --- === MIR2 === fn user.$init_test_24(registry: testing.TestCollector) -> null { @@ -87,7 +87,7 @@ fn .(testset: testing.TestCollector) -> null { fn ., 1)>() -> null { // Locals: let _0: null // _0 // return - let _1: void + let _1: string bb0: { _1 = copy capture[0]; diff --git a/baml_language/crates/baml_tests/snapshots/testset_vibes_nested/baml_tests__testset_vibes_nested__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/testset_vibes_nested/baml_tests__testset_vibes_nested__04_5_mir.snap index 17bcb458f9..da4bb034b8 100644 --- a/baml_language/crates/baml_tests/snapshots/testset_vibes_nested/baml_tests__testset_vibes_nested__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/testset_vibes_nested/baml_tests__testset_vibes_nested__04_5_mir.snap @@ -1,5 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs +assertion_line: 39494 --- === MIR2 === fn user.$init_test_24(registry: testing.TestCollector) -> null { @@ -91,7 +92,7 @@ fn ., 1)>(testset: testing.TestCollector) -> n let _1: testing.TestCollector // testset // param let _2: baml.http.Response // req let _3: string - let _4: void + let _4: string let _5: future let _6: string // data let _7: future @@ -252,7 +253,7 @@ fn ., 3)>, 4)>(testset: testing.TestCo let _0: null // _0 // return let _1: testing.TestCollector // testset // param let _2: string[] // tests - let _3: void + let _3: string let _4: string[] let _5: int // __for_idx_3 let _6: int diff --git a/baml_language/crates/baml_tests/snapshots/testset_with_setup/baml_tests__testset_with_setup__04_5_mir.snap b/baml_language/crates/baml_tests/snapshots/testset_with_setup/baml_tests__testset_with_setup__04_5_mir.snap index 24df5d9897..19a5436613 100644 --- a/baml_language/crates/baml_tests/snapshots/testset_with_setup/baml_tests__testset_with_setup__04_5_mir.snap +++ b/baml_language/crates/baml_tests/snapshots/testset_with_setup/baml_tests__testset_with_setup__04_5_mir.snap @@ -1,6 +1,6 @@ --- source: crates/baml_tests/src/generated_tests.rs -assertion_line: 38837 +assertion_line: 39819 --- === MIR2 === fn user.$init_test_24(registry: testing.TestCollector) -> null { @@ -64,8 +64,8 @@ fn ., 1)>() -> null { // Locals: let _0: null // _0 // return let _1: null - let _2: void - let _3: void + let _2: string + let _3: int bb0: { _2 = copy capture[0]; @@ -86,7 +86,7 @@ fn ., 1)>() -> null { fn ., 2)>() -> null { // Locals: let _0: null // _0 // return - let _1: void + let _1: string bb0: { _1 = copy capture[0]; diff --git a/baml_language/crates/baml_tests/src/compiler2_hir.rs b/baml_language/crates/baml_tests/src/compiler2_hir.rs index c539b8d923..eb9a257bf9 100644 --- a/baml_language/crates/baml_tests/src/compiler2_hir.rs +++ b/baml_language/crates/baml_tests/src/compiler2_hir.rs @@ -540,32 +540,224 @@ mod tests { assert!(sites.iter().all(|s| s.kind == DefinitionKind::Variant)); } - /// Duplicate let-bindings in the same function produce a DuplicateDefinition diagnostic. + /// Same-scope let shadowing is legal and does not produce duplicate diagnostics. #[test] - fn duplicate_let_binding_produces_diagnostic() { - use baml_compiler2_hir::{contributions::DefinitionKind, diagnostic::Hir2Diagnostic}; + fn same_scope_let_shadowing_has_no_duplicate_diagnostic() { + use baml_compiler2_hir::diagnostic::Hir2Diagnostic; let mut db = make_db(); let file = db.add_file( - "dup_let.baml", + "shadow_let.baml", "function foo() -> int {\n let x = 1;\n let x = 2;\n return x;\n}", ); let index = file_semantic_index(&db, file); let diags = index.diagnostics(); - let dups: Vec<_> = diags + assert!(!diags.iter().any( + |d| matches!(d, Hir2Diagnostic::DuplicateDefinition { name, .. } if name == &Name::new("x")) + )); + } + + #[test] + fn shadowing_initializer_resolves_previous_binding() { + use baml_compiler2_hir::scope::ScopeKind; + use text_size::TextSize; + + let mut db = make_db(); + let file = db.add_file( + "initializer_shadow.baml", + "function foo() -> int {\n let x = 1;\n let x = x + 1;\n x\n}", + ); + + let index = file_semantic_index(&db, file); + let function_scope = index + .scopes .iter() - .filter(|d| matches!(d, Hir2Diagnostic::DuplicateDefinition { name, .. } if name == &Name::new("x"))) - .collect(); - assert_eq!(dups.len(), 1); + .enumerate() + .find_map(|(idx, scope)| { + matches!(scope.kind, ScopeKind::Function) + .then_some(baml_compiler2_hir::scope::FileScopeId::new(idx as u32)) + }) + .expect("function scope"); + let x_bindings = index.scope_bindings[function_scope.index() as usize] + .bindings + .iter() + .filter(|binding| binding.name == Name::new("x")) + .collect::>(); + assert_eq!(x_bindings.len(), 2); + + let text = file.text(&db); + let init_x_offset = TextSize::from(text.find("x + 1").expect("initializer x") as u32); + let use_scope = index.scope_at_offset(init_x_offset, Some(&Name::new("foo"))); + let resolved = index + .visible_binding_at(use_scope, init_x_offset, &Name::new("x")) + .expect("initializer x should resolve"); + + assert_eq!(resolved.site, x_bindings[0].site); + } - let Hir2Diagnostic::DuplicateDefinition { scope, sites, .. } = dups[0] else { - panic!("expected DuplicateDefinition diagnostic"); + #[test] + fn lambda_does_not_capture_its_own_nested_block_binding() { + use baml_compiler2_hir::scope::ScopeKind; + + let mut db = make_db(); + let file = db.add_file( + "lambda_local_block.baml", + "function foo() -> int {\n let f = () -> int {\n { let x = 1; x }\n };\n f()\n}", + ); + + let index = file_semantic_index(&db, file); + let lambda_scope = index + .scopes + .iter() + .enumerate() + .find_map(|(idx, scope)| { + matches!(scope.kind, ScopeKind::Lambda) + .then_some(baml_compiler2_hir::scope::FileScopeId::new(idx as u32)) + }) + .expect("lambda scope"); + + assert!( + index.scope_bindings[lambda_scope.index() as usize] + .captures + .is_empty(), + "lambda-local block binding should not be recorded as a capture" + ); + } + + /// A `let` inside a while body must not share the enclosing function's + /// scope. The body is an `Expr::Block` which pushes its own scope, so + /// the inner `let x = 99` must register in that scope, not the function + /// scope. + /// + /// This test pins the desired invariant. Without `Stmt::While` walking the + /// body inside its own block scope, find-references / rename / capture + /// analysis would walk ancestors from the wrong starting scope. + #[test] + fn while_body_let_lives_in_inner_scope() { + use baml_compiler2_hir::scope::ScopeKind; + use text_size::TextSize; + + let mut db = make_db(); + let file = db.add_file( + "while_scope.baml", + "function foo() -> int {\n let x = 1;\n let once = true;\n while (once) {\n let x = 99;\n once = false;\n };\n x\n}", + ); + + let index = file_semantic_index(&db, file); + + // Locate the function scope. + let function_scope = index + .scopes + .iter() + .enumerate() + .find_map(|(idx, scope)| { + matches!(scope.kind, ScopeKind::Function) + .then_some(baml_compiler2_hir::scope::FileScopeId::new(idx as u32)) + }) + .expect("function scope"); + + // Find the offset of the inner `x = 99` token. + let text = file.text(&db); + let inner_x_decl = text.find("x = 99").expect("inner x decl"); + let inner_x_offset = TextSize::from(inner_x_decl as u32); + + let scope_at_inner = index.scope_at_offset(inner_x_offset, Some(&Name::new("foo"))); + assert_ne!( + scope_at_inner, function_scope, + "inner `let x = 99` inside while body must resolve to a non-function scope; got function scope" + ); + + // The inner binding must register in some descendant of the function + // scope, not the function scope itself. + let inner_bindings = index.scope_bindings[scope_at_inner.index() as usize] + .bindings + .iter() + .filter(|b| b.name == Name::new("x")) + .count(); + assert!( + inner_bindings >= 1, + "inner scope must contain the inner `x` binding" + ); + + // The outer `x = 1` binding must remain in the function scope. + let outer_x_in_function = index.scope_bindings[function_scope.index() as usize] + .bindings + .iter() + .filter(|b| b.name == Name::new("x")) + .count(); + assert_eq!( + outer_x_in_function, 1, + "outer `let x = 1` must stay in function scope; got {} `x` bindings", + outer_x_in_function + ); + } + + /// Verify the HIR scope tree contains Block, MatchArm, CatchClause, and + /// CatchArm scope kinds nested under a Function. A future regression + /// that drops one of these kinds (e.g. a refactor that name-keys some + /// lookup and "doesn't need" the explicit scope) will fail this test + /// loudly. + #[test] + fn scope_tree_includes_block_match_catch_kinds() { + use baml_compiler2_hir::scope::ScopeKind; + + let mut db = make_db(); + let file = db.add_file( + "scope_kinds.baml", + r#"function f(x: int) -> int { + let _local = 1 + { + let _block_local = 2 + } + let _matched = match (x) { + n => n + _ => 0 + } + try { + 1 + } catch (e) { + _ => 2 + } +}"#, + ); + + let index = file_semantic_index(&db, file); + + // Each kind must appear at least once. + let kinds: Vec<&ScopeKind> = index.scopes.iter().map(|s| &s.kind).collect(); + let has_kind = |kind: ScopeKind| { + kinds + .iter() + .any(|k| std::mem::discriminant(*k) == std::mem::discriminant(&kind)) }; - assert_eq!(scope.as_ref().unwrap(), &Name::new("foo")); - assert_eq!(sites.len(), 2); - assert!(sites.iter().all(|s| s.kind == DefinitionKind::Binding)); + + assert!( + has_kind(ScopeKind::Function), + "scope tree missing Function scope; kinds = {:?}", + kinds + ); + assert!( + has_kind(ScopeKind::Block), + "scope tree missing Block scope; kinds = {:?}", + kinds + ); + assert!( + has_kind(ScopeKind::MatchArm), + "scope tree missing MatchArm scope; kinds = {:?}", + kinds + ); + assert!( + has_kind(ScopeKind::CatchClause), + "scope tree missing CatchClause scope; kinds = {:?}", + kinds + ); + assert!( + has_kind(ScopeKind::CatchArm), + "scope tree missing CatchArm scope; kinds = {:?}", + kinds + ); } /// A field and a method with the same name in a class produce a cross-kind diagnostic. diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/inference.rs b/baml_language/crates/baml_tests/src/compiler2_tir/inference.rs index b5be3e33e3..e4cb3d0149 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/inference.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/inference.rs @@ -2,7 +2,12 @@ use baml_base::Name; use baml_compiler2_hir::{package::PackageId, scope::ScopeKind}; -use baml_compiler2_tir::{inference::infer_scope_types, package_interface::package_interface}; +use baml_compiler2_tir::{ + inference::infer_scope_types, + package_interface::package_interface, + resolve::{ResolvedName, resolve_name_at_in_scope}, +}; +use text_size::TextSize; use super::support::{expr_type_in_function, make_db, render_tir}; @@ -54,6 +59,45 @@ fn let_binding_widens() { "); } +#[test] +fn resolver_initializer_shadowing_uses_previous_binding() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + "function f() -> int { let x = 1; let x = x + 1; x }", + ); + + let index = baml_compiler2_ppir::file_semantic_index(&db, file); + let function_scope = index + .scopes + .iter() + .enumerate() + .find_map(|(idx, scope)| { + (matches!(scope.kind, ScopeKind::Function) + && scope.name.as_ref().is_some_and(|name| name.as_str() == "f")) + .then_some(baml_compiler2_hir::scope::FileScopeId::new(idx as u32)) + }) + .expect("function scope"); + let x_bindings = index.scope_bindings[function_scope.index() as usize] + .bindings + .iter() + .filter(|binding| binding.name == Name::new("x")) + .collect::>(); + assert_eq!(x_bindings.len(), 2); + + let offset = TextSize::from(file.text(&db).find("x + 1").expect("initializer x") as u32); + let resolved = + resolve_name_at_in_scope(&db, file, offset, &Name::new("x"), Some(&Name::new("f"))); + + assert_eq!( + resolved, + ResolvedName::Local { + name: Name::new("x"), + definition_site: Some(x_bindings[0].site), + } + ); +} + #[test] fn class_field_access() { let mut db = make_db(); @@ -364,6 +408,89 @@ fn function_type_throws_package_interface_exports_effect_params() { ); } +#[test] +fn lambda_scope_retypes_capture_from_function_parameter() { + let mut db = make_db(); + let file = db.add_file( + "capture_param.baml", + "function main(x: int) -> int { let f = () -> int { x }; return f(); }", + ); + + let index = baml_compiler2_ppir::file_semantic_index(&db, file); + let lambda_scope_id = index + .scope_ids + .iter() + .copied() + .find(|scope_id| { + let scope = &index.scopes[scope_id.file_scope_id(&db).index() as usize]; + matches!(scope.kind, ScopeKind::Lambda) + }) + .expect("lambda scope"); + let lambda_inference = infer_scope_types(&db, lambda_scope_id); + + let item_tree = baml_compiler2_ppir::file_item_tree(&db, file); + let (main_id, _) = item_tree + .functions + .iter() + .find(|(_, func)| func.name.as_str() == "main") + .expect("main function"); + let main_loc = baml_compiler2_hir::loc::FunctionLoc::new(&db, file, main_id); + let main_body = baml_compiler2_ppir::function_body(&db, main_loc); + let baml_compiler2_hir::body::FunctionBody::Expr(main_expr_body) = main_body.as_ref() else { + panic!("main expression body"); + }; + let lambda_body = main_expr_body + .exprs + .iter() + .find_map(|(_, expr)| { + if let baml_compiler2_ast::Expr::Lambda(func_def) = expr + && let Some(baml_compiler2_ast::FunctionBodyDef::Expr(lambda_body, _)) = + &func_def.body + { + Some(lambda_body) + } else { + None + } + }) + .expect("lambda body"); + let root_expr = lambda_body.root_expr.expect("lambda root expr"); + + assert_eq!( + lambda_inference + .expression_type(root_expr) + .map(ToString::to_string), + Some("int".to_string()) + ); +} + +#[test] +fn lambda_parameter_shadowing_uses_parameter_declared_type() { + let mut db = make_db(); + let file = db.add_file( + "lambda_param_shadow.baml", + r#" +function main() -> int { + let x: string = ""; + let f = (x: int) -> int { + x = 1; + x + }; + f(0) +} +"#, + ); + + let output = render_tir(&db, file); + assert!( + !output.contains("type mismatch: expected string, got int"), + "lambda parameter assignment should use the parameter annotation, got:\n{output}" + ); + assert!( + output.contains("(x: int) -> int"), + "expected lambda parameter to keep its int type, got:\n{output}" + ); +} + #[test] fn returning_callback_forwarder_matches_omitted_function_type_return_annotation() { let mut db = make_db(); diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs index f30076d56d..d68ad81050 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/mod.rs @@ -1060,14 +1060,7 @@ pub(crate) mod support { } // Collect expression types for this scope — skip if none - let mut has_expr_types = false; - for (_expr_id, owner_scope) in &index.expr_scopes { - if owner_scope.index() as usize == i { - has_expr_types = true; - break; - } - } - if !has_expr_types { + if inference.iter_expressions().next().is_none() { continue; } diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs b/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs index b47ff08663..1021f315d0 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/phase6.rs @@ -1492,3 +1492,86 @@ function f(callback: ((x: int) -> int)?) -> int? { } "#); } + +#[test] +fn index_assignment_establishment_updates_let_binding_type() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#" +function f() -> int { + let xs = [] + xs[0] = 1 + return xs[0] +} +"#, + ); + let output = render_tir(&db, file); + + assert!( + output.contains("let xs = [] : never[] -> int[] (evolving)"), + "expected indexed assignment to sync the let binding type, got:\n{output}" + ); + assert!( + !output.contains("type mismatch"), + "did not expect indexed assignment establishment to produce a mismatch, got:\n{output}" + ); +} + +#[test] +fn lambda_body_container_establishment_does_not_leak_to_parent_scope() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#" +function f() -> int { + let xs = [] + let _f = () -> int { + let xs = [] + xs.push("inner") + 0 + } + xs.push(1) + return xs[0] +} +"#, + ); + let output = render_tir(&db, file); + + assert!( + output.contains("let xs = [] : never[] -> int[] (evolving)"), + "expected parent xs binding to be established by parent push, got:\n{output}" + ); + assert!( + !output.contains("type mismatch"), + "lambda-local container establishment should not affect parent xs, got:\n{output}" + ); +} + +#[test] +fn for_body_container_assignment_establishes_outer_type() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#" +function f() -> int { + let xs = [] + for (let n in []) { + xs.push("not guaranteed") + } + xs.push(1) + return xs[0] +} +"#, + ); + let output = render_tir(&db, file); + + assert!( + output.contains("let xs = [] : never[] -> string[] (evolving)"), + "expected xs to be established by the first push in the loop body, got:\n{output}" + ); + assert!( + output.contains("type mismatch: expected string, got int"), + "post-loop push should be checked against the loop-established element type, got:\n{output}" + ); +} diff --git a/baml_language/crates/baml_tests/src/compiler2_tir/phase7.rs b/baml_language/crates/baml_tests/src/compiler2_tir/phase7.rs index e91ea17a7b..e659912e26 100644 --- a/baml_language/crates/baml_tests/src/compiler2_tir/phase7.rs +++ b/baml_language/crates/baml_tests/src/compiler2_tir/phase7.rs @@ -351,6 +351,99 @@ fn assign_method_result_in_null_branch_works() { "#); } +#[test] +fn assignment_before_shadow_survives_scope_restore() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f(x: int?) -> int { + { + x = 1; + let x: string = "shadow"; + x; + }; + return x; +}"#, + ); + let output = render_tir(&db, file); + assert!( + output.contains("return x : 1"), + "outer assignment before inner shadow should remain visible after block:\n{output}" + ); + assert!( + !output.contains("type mismatch"), + "assignment before shadow should satisfy the int return type:\n{output}" + ); +} + +#[test] +fn inner_declared_type_does_not_leak_after_shadow() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f() -> int { + let x: int = 1; + { + let x: string = "shadow"; + x; + }; + x = 2; + return x; +}"#, + ); + let output = render_tir(&db, file); + assert!( + output.contains("x = 2 : 2"), + "outer declared type should be restored after inner typed shadow:\n{output}" + ); + assert!( + !output.contains("type mismatch"), + "inner declared type metadata should not constrain outer assignment:\n{output}" + ); +} + +#[test] +fn unannotated_inner_shadow_masks_outer_declared_type() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f() -> int { + let x: int = 1; + { + let x = "shadow"; + x = "updated"; + }; + return x; +}"#, + ); + let output = render_tir(&db, file); + assert!( + !output.contains("type mismatch"), + "unannotated inner shadow should not be checked against outer annotation:\n{output}" + ); +} + +#[test] +fn early_return_narrowing_inside_nested_block_does_not_leak() { + let mut db = make_db(); + let file = db.add_file( + "test.baml", + r#"function f(x: int?) -> int? { + { + if (x == null) { + return 0; + } + }; + return x; +}"#, + ); + let output = render_tir(&db, file); + assert!( + output.contains("return x : int?"), + "early-return narrowing should be scoped to the nested block:\n{output}" + ); +} + // ── String type narrowing ───────────────────────────────────────────────────── #[test] diff --git a/baml_language/crates/baml_tests/tests/arrays.rs b/baml_language/crates/baml_tests/tests/arrays.rs index c9fe656777..e83094b2a6 100644 --- a/baml_language/crates/baml_tests/tests/arrays.rs +++ b/baml_language/crates/baml_tests/tests/arrays.rs @@ -230,6 +230,35 @@ async fn array_map_callback_throws() { assert_eq!(output.result, Ok(BexExternalValue::String("caught".into()))); } +/// array.map callback inside a catch base still resolves its own parameter, +/// even though the catch handler introduces a later lexical scope. +#[tokio::test] +async fn array_map_callback_in_catch_base_keeps_parameter_scope() { + let output = baml_test!( + r#" + function main() -> int[] { + let items: int[] = [1, 2, 3] + items.map((x: int) -> int { + if (x == 2) { x + 10 } else { x } + }) catch (e) { + _ => [0] + } + } + "# + ); + assert_eq!( + output.result, + Ok(BexExternalValue::Array { + element_type: Ty::int(), + items: vec![ + BexExternalValue::Int(1), + BexExternalValue::Int(12), + BexExternalValue::Int(3), + ], + }) + ); +} + /// array.map over string[] — exercises heap-object paths in MapContinuation /// (gc_roots, apply_forwarding) that int[] tests don't cover. #[tokio::test] diff --git a/baml_language/crates/baml_tests/tests/functions.rs b/baml_language/crates/baml_tests/tests/functions.rs index e0108c45e0..536c7c831f 100644 --- a/baml_language/crates/baml_tests/tests/functions.rs +++ b/baml_language/crates/baml_tests/tests/functions.rs @@ -261,7 +261,6 @@ async fn early_return() { } #[tokio::test] -#[ignore = "compiler2: duplicate binding error for let-in-nested-scope variables"] async fn early_return_from_nested_scopes() { let output = baml_test!( r#" @@ -316,20 +315,22 @@ async fn early_return_from_nested_scopes() { L2: load_const 7 - return + jump L6 L3: load_const true pop_jump_if_false L1 load_const 0 - return + jump L6 L4: load_const 0 - return + jump L6 L5: load_const 0 + + L6: return } "); diff --git a/baml_language/crates/baml_tests/tests/if_else.rs b/baml_language/crates/baml_tests/tests/if_else.rs index 63351e8000..a086695a16 100644 --- a/baml_language/crates/baml_tests/tests/if_else.rs +++ b/baml_language/crates/baml_tests/tests/if_else.rs @@ -237,7 +237,6 @@ async fn if_else_with_parameter() { } #[tokio::test] -#[ignore = "compiler2: duplicate binding error for let-in-if-branch variables"] async fn if_else_return_expr_with_locals() { let output = baml_test! { baml: " @@ -309,7 +308,6 @@ async fn if_else_assignment_with_param() { } #[tokio::test] -#[ignore = "compiler2: duplicate binding error for let-in-if-branch variables"] async fn if_else_assignment_with_locals() { let output = baml_test! { baml: " @@ -1200,7 +1198,6 @@ async fn block_expression() { // ============================================================================ #[tokio::test] -#[ignore = "compiler2: duplicate binding error for let-in-if-branch variables"] async fn if_else_statement() { let output = baml_test! { baml: " @@ -1239,18 +1236,14 @@ async fn if_else_statement() { jump L1 L0: - load_const 3 - store_var x load_const 4 - call identity + call user.identity pop 1 jump L2 L1: - load_const 2 - store_var y load_const 1 - call identity + call user.identity pop 1 L2: diff --git a/baml_language/crates/baml_tests/tests/lambdas.rs b/baml_language/crates/baml_tests/tests/lambdas.rs index b0ddc8d5e8..b56461d7be 100644 --- a/baml_language/crates/baml_tests/tests/lambdas.rs +++ b/baml_language/crates/baml_tests/tests/lambdas.rs @@ -287,6 +287,25 @@ async fn explicit_throwing_lambda_catches_error() { assert_eq!(output.result, Ok(BexExternalValue::Int(-1))); } +#[tokio::test] +async fn lambda_inside_catch_base_keeps_parameter_scope() { + let output = baml_test!( + r#" + function main() -> int { + { + let f = (x: int) -> int { + if (x == 7) { x } else { 0 } + } + f(7) + } catch (x) { + _ => x + } + } + "# + ); + assert_eq!(output.result, Ok(BexExternalValue::Int(7))); +} + /// Deep nesting (3 levels) with transitive captures at each level. /// a in main, b param of f, c param of g, d param of h. /// a + b + c + d = 1 + 10 + 100 + 1000 = 1111 @@ -359,7 +378,6 @@ async fn issue_e_method_resolution_different_types() { /// let x = 1; let g captures x (=1); let x = "shadow"; let f captures x (="shadow") /// Both lambdas should capture the correct x for their position. #[tokio::test] -#[ignore = "BAML disallows variable shadowing; test kept for when shadowing is added"] async fn issue_f_shadowing_capture_correct_binding() { let output = baml_test!( " @@ -379,7 +397,6 @@ async fn issue_f_shadowing_capture_correct_binding() { /// Issue F (variant): shadowed capture with mutation. /// The first x should be independently cell-wrapped from the second x. #[tokio::test] -#[ignore = "BAML disallows variable shadowing; test kept for when shadowing is added"] async fn issue_f_shadowing_capture_independent_cells() { let output = baml_test!( " @@ -527,3 +544,31 @@ async fn issue_b_watch_let_mutated_by_parent_and_lambda() { // counter: 0 → 1 (parent) → 11 (lambda) assert_eq!(output.result, Ok(BexExternalValue::Int(11))); } + +/// Lambda parameter shadowing an annotated outer let. The lambda param's +/// declared type must replace any outer entry in `declared_types` so that +/// assignments to the param inside the body type-check against the param's +/// type, not the shadowed outer's. Previously `infer_lambda_body` seeded +/// params via `add_local`, which used `or_insert_with` for `declared_types` +/// and therefore preserved the outer entry — causing a phantom TypeMismatch +/// when the param is reassigned to a value of its declared type. With the +/// bug present, `compile_source` would panic via `assert_no_diagnostic_errors` +/// before reaching execution. +#[tokio::test] +async fn lambda_param_shadows_annotated_outer_local() { + let output = baml_test!( + r#" + function main() -> int { + let x: int = 7; + let f = (x: string) -> int { + x = "world"; + x.length() + }; + f("hi") + x + } + "# + ); + // Inside f, x is reassigned to "world" (length 5). Outer x is unchanged + // (the lambda param shadows the outer binding entirely). 5 + 7 = 12. + assert_eq!(output.result, Ok(BexExternalValue::Int(12))); +} diff --git a/baml_language/crates/baml_tests/tests/lexical_scoping.rs b/baml_language/crates/baml_tests/tests/lexical_scoping.rs new file mode 100644 index 0000000000..015feaf6ca --- /dev/null +++ b/baml_language/crates/baml_tests/tests/lexical_scoping.rs @@ -0,0 +1,316 @@ +//! Runtime regressions for lexical block scope and local shadowing. + +use baml_tests::baml_test; +use bex_engine::BexExternalValue; + +const LEXICAL_SCOPE_RUNTIME_REGRESSIONS: &str = r#" +function same_scope_shadow() -> int { + let x = 1 + let x = 2 + x +} + +function outer_restored() -> int { + let x = 1 + { + let x = 2 + } + x +} + +function initializer_uses_previous() -> int { + let x = 1 + let x = x + 1 + x +} + +function shadow_param(x: int) -> int { + let x = x + 1 + x +} + +function for_loop_restores_outer() -> int { + let x = 1 + for (let x in [2, 3]) { + x + } + x +} + +function nested_outer_restored() -> int { + let x = 1 + { + let x = 2 + { + let x = 3 + x + } + x + } + x +} + +function capture_before_after_shadow() -> int { + let x = 1 + let g = () -> int { x } + let x = 2 + let f = () -> int { x } + g() * 10 + f() +} + +// Rule 1: a `let` declared inside a while body must not leak past the +// loop. After the loop, `x` resolves to the outer binding. +function rule1_while_no_leakage() -> int { + let x = 1 + let once = true + while (once) { + let x = 99 + once = false + } + x +} + +// Observe the inner shadow's value to rule out optimizer-induced false +// positives in `rule1_while_no_leakage`. If outer x and inner x were +// conflated (a shadowing bug), `observed` would still see 99 but the +// return would be `99 + 99 = 198`, not `1 + 99 = 100`. +function rule1_while_observed_inner_shadow() -> int { + let x = 1 + let once = true + let observed = 0 + while (once) { + let x = 99 + observed = observed + x + once = false + } + x + observed +} + +// Rule 2: outer-binding mutation in an inner block escapes. +function rule2_block_outer_mutation_escapes() -> int { + let x = 1 + { + x = 2 + } + x +} + +// Rule 2: outer-binding mutation in a `for` body escapes. +function rule2_for_outer_mutation_escapes() -> int { + let x = 1 + for (let _ in [1]) { + x = 2 + } + x +} + +// Rule 3: a block-local shadow's mutation must NOT escape. Without +// binding-identity-keyed assignment tracking, the inner `x = 3` would +// conflate with an outer `x` mutation and propagate. +function rule3_block_shadow_then_assign_inner_does_not_escape() -> int { + let x = 1 + { + let x = 2 + x = 3 + } + x +} + +// Composite test: a pre-shadow outer mutation escapes; a post-shadow +// inner mutation does not. A name-keyed assignment tracker cannot +// distinguish these — the binding-identity-keyed tracker can. +function rule2_pre_shadow_mutation_escapes_post_shadow_does_not() -> int { + let x = 1 + { + x = 2 + let x = 3 + x = 4 + } + x +} +"#; + +async fn assert_lexical_scope_result(entry: &str, expected: i64) { + let source = format!( + r#" + {LEXICAL_SCOPE_RUNTIME_REGRESSIONS} + + function main() -> int {{ + {entry} + }} + "# + ); + let output = baml_test!(&source); + + assert_eq!( + output.result, + Ok(BexExternalValue::Int(expected)), + "unexpected lexical scoping result for `{entry}`" + ); +} + +#[tokio::test] +async fn lexical_scoping_runtime_regressions() { + assert_lexical_scope_result("same_scope_shadow()", 2).await; + assert_lexical_scope_result("outer_restored()", 1).await; + assert_lexical_scope_result("initializer_uses_previous()", 2).await; + assert_lexical_scope_result("shadow_param(4)", 5).await; + assert_lexical_scope_result("for_loop_restores_outer()", 1).await; + assert_lexical_scope_result("nested_outer_restored()", 1).await; + assert_lexical_scope_result("capture_before_after_shadow()", 12).await; + assert_lexical_scope_result("rule1_while_no_leakage()", 1).await; + assert_lexical_scope_result("rule1_while_observed_inner_shadow()", 100).await; + assert_lexical_scope_result("rule2_block_outer_mutation_escapes()", 2).await; + assert_lexical_scope_result("rule2_for_outer_mutation_escapes()", 2).await; + assert_lexical_scope_result("rule3_block_shadow_then_assign_inner_does_not_escape()", 1).await; + assert_lexical_scope_result( + "rule2_pre_shadow_mutation_escapes_post_shadow_does_not()", + 2, + ) + .await; +} + +#[tokio::test] +async fn declared_type_restored_across_scope() { + // Verifies that a typed outer binding is restored after an inner scope + // shadows it with a different declared type. The inner `let x: int = 1` + // exists only inside the block; after the block, `x` resolves back to + // the outer `string` binding. + // + // (The previous `let _ = ...; _` half of this test was removed: under + // the new patterns backend (PR BoundaryML/baml#3417, "implement new + // patterns backend without parser support"), `_` is canonicalized to + // a wildcard at AST construction and is no longer a referenceable name.) + let string_output = baml_test!( + r#" + function main() -> string { + let x: string = "outer" + { + let x: int = 1 + } + x + } + "# + ); + + assert_eq!( + string_output.result, + Ok(BexExternalValue::String("outer".to_string())) + ); +} + +#[tokio::test] +async fn lambdas_capture_match_and_catch_pattern_bindings() { + let output = baml_test!( + r#" + function throw_string() -> string { + throw "caught" + } + + function capture_match_arm() -> string { + match ("matched") { + s: string => { + let f = () -> string { s } + f() + } + } + } + + function capture_catch_clause_binding() -> string { + throw_string() catch (e) { + _: string => { + let f = () -> string { e } + f() + } + } + } + + function capture_catch_arm_binding() -> string { + throw_string() catch (e) { + s: string => { + let f = () -> string { s } + f() + } + } + } + + function main() -> string { + capture_match_arm() + ":" + capture_catch_clause_binding() + ":" + capture_catch_arm_binding() + } + "# + ); + + assert_eq!( + output.result, + Ok(BexExternalValue::String( + "matched:caught:caught".to_string() + )) + ); +} + +#[tokio::test] +async fn match_and_catch_pattern_bindings_restore_outer_locals() { + let output = baml_test!( + r#" + function throw_string(s: string) -> string { + throw s + } + + function match_post_match_restores_outer() -> int { + let x = 10 + let _matched = match (1) { + x: int => x + } + x + } + + function match_later_arm_uses_outer() -> int { + let x = 20 + match ("value") { + x: int => 0, + _: string => x + } + } + + function catch_base_uses_outer_lambda() -> int { + let e = () -> string { "base" } + let caught = throw_string(e()) catch (e) { + _: string => e + } + match (caught) { + "base" => 3, + _ => 0 + } + } + + function main() -> int { + match_post_match_restores_outer() + + match_later_arm_uses_outer() * 100 + + catch_base_uses_outer_lambda() * 10000 + } + "# + ); + + assert_eq!(output.result, Ok(BexExternalValue::Int(32_010))); +} + +#[tokio::test] +async fn multi_clause_catch_uses_clause_local_binding() { + let output = baml_test!( + r#" + function fail() -> int { + throw 7 + } + + function main() -> int { + fail() catch (first) { + _: string => 1 + } catch (second) { + _: int => second + } + } + "# + ); + + assert_eq!(output.result, Ok(BexExternalValue::Int(7))); +} diff --git a/baml_language/crates/baml_tests/tests/watch.rs b/baml_language/crates/baml_tests/tests/watch.rs index 8ee1be3cf2..a9fe1ca705 100644 --- a/baml_language/crates/baml_tests/tests/watch.rs +++ b/baml_language/crates/baml_tests/tests/watch.rs @@ -35,6 +35,7 @@ async fn watch_primitive() { load_const 1 store_var value load_var value + unwatch value return } "#); @@ -71,6 +72,7 @@ async fn watch_primitive_nested_scope() { L0: load_var value + unwatch value return } "#); @@ -105,6 +107,7 @@ async fn watch_default_filter() { load_const 6 store_var value load_var value + unwatch value return } "#); @@ -245,6 +248,7 @@ async fn watch_alias() { store_field .x load_var point load_field .x + unwatch point return } "#); @@ -290,6 +294,7 @@ async fn watch_alias_nested_scope() { L0: load_var point load_field .x + unwatch point return } "#); @@ -333,6 +338,7 @@ async fn watch_scope_exit() { store_field .x load_var point store_var outter_point + unwatch point load_var outter_point load_const 2 store_field .x @@ -345,6 +351,921 @@ async fn watch_scope_exit() { assert_eq!(output.result, Ok(BexExternalValue::Int(2))); } +// ============================================================================ +// Watch teardown across abnormal exits +// ============================================================================ +// +// These tests pin the `unwatch` emission for `break`, `continue`, and early +// `return` so the helper that consolidates them (folding the inline +// loops into `emit_unwatch_to_depth`) cannot regress behavior. They also +// document the per-iteration semantic for `continue`: the watch is re-issued +// at the top of the next iteration, not held for the whole loop. + +#[tokio::test] +async fn watch_break_unwatches() { + // Expected notifications: [["x"]] + // (iter 1 assigns x = 10 → notify; iter 2 hits break before any assignment.) + // + // unwatch x must precede the goto to the loop exit so the watcher is + // torn down before iteration ends. + let output = baml_test!( + r#" + function main() -> int { + let total = 0; + for (let i in [1, 2, 3]) { + watch let x = i; + if (x > 1) { + break; + } + x = x + 9; + total = total + x; + } + total + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function main() -> int { + load_const 0 + store_var total + load_const 1 + load_const 2 + load_const 3 + alloc_array 3 + store_var _2 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _2 + call baml.Array.length + cmp_op < + pop_jump_if_false L3 + load_var _2 + load_var __for_idx + load_array_element + store_var x + load_const "x" + load_const null + watch x + load_var x + load_const 1 + cmp_op > + pop_jump_if_false L1 + jump L2 + + L1: + load_var x + load_const 9 + bin_op + + store_var x + load_var total + load_var x + bin_op + + store_var total + unwatch x + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 + + L2: + unwatch x + + L3: + load_var total + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(10))); +} + +#[tokio::test] +async fn watch_continue_unwatches() { + // Expected notifications: [["x"], ["x"]] + // (iter 1 assigns x = 11; iter 2 hits continue before assigning; iter 3 + // assigns x = 13. Each `watch let x = i` re-issues the watcher at the top + // of its iteration, so unwatch on continue is per-iteration, not + // permanent for the loop.) + // + // unwatch x must precede the goto to the continue target (the increment + // step), AND must also fire on normal fallthrough at end of body. + let output = baml_test!( + r#" + function main() -> int { + let total = 0; + for (let i in [1, 2, 3]) { + watch let x = i; + if (x == 2) { + continue; + } + x = x + 10; + total = total + x; + } + total + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function main() -> int { + load_const 0 + store_var total + load_const 1 + load_const 2 + load_const 3 + alloc_array 3 + store_var _2 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _2 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_var total + return + + L2: + load_var _2 + load_var __for_idx + load_array_element + store_var x + load_const "x" + load_const null + watch x + load_var x + load_const 2 + cmp_op == + pop_jump_if_false L3 + jump L4 + + L3: + load_var x + load_const 10 + bin_op + + store_var x + load_var total + load_var x + bin_op + + store_var total + unwatch x + jump L5 + + L4: + unwatch x + + L5: + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(24))); +} + +#[tokio::test] +async fn watch_early_return_unwatches() { + // Expected notifications: [["x"]] + // (x = 42 notifies; the return path then unwatches before exiting.) + // + // unwatch x must precede the goto to the function's exit block. + let output = baml_test!( + r#" + function main() -> int { + watch let x = 0; + x = 42; + if (true) { + return x; + } + x = 99; + x + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function main() -> int { + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 42 + store_var x + load_const true + pop_jump_if_false L0 + jump L1 + + L0: + load_const 99 + store_var x + load_var x + unwatch x + jump L2 + + L1: + load_var x + unwatch x + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(42))); +} + +// ============================================================================ +// Watch teardown across throw and arm-body fallthrough +// ============================================================================ +// +// `throw` and match/catch arm-body fallthrough are exit paths that previously +// did not emit `unwatch` ops: +// - `Stmt::Throw` in MIR went straight to a dead block (lower.rs:3884-3889). +// - Match arm bodies (lower.rs:4769-4775, 4785-4797) and catch arm bodies +// (lower.rs:5343-5354) only restored locals; they did not unwatch +// arm-declared `watch let`s before the goto-to-join. +// +// These tests pin the corrected behavior via bytecode snapshots and verify +// the function still produces the expected runtime result. + +#[tokio::test] +async fn watch_throw_unwatches() { + // Expected notifications: [["x"]] + // (x = 5 notifies before throw; the unwatch then runs before the throw + // terminator so the watcher is torn down on the divergent path.) + let output = baml_test!( + r#" + function fails() -> int { + watch let x = 0; + x = 5; + throw "boom"; + } + + function main() -> int { + fails() catch (e) { + "boom" => 99, + _ => -1, + } + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails() -> int { + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + unwatch x + load_const "boom" + throw + } + + function main() -> int { + call user.fails + jump L2 + load_var e + load_const "boom" + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw_if_panic + load_const 1 + unary_op - + jump L2 + + L1: + load_const 99 + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(99))); +} + +#[tokio::test] +async fn watch_for_throw_unwatches() { + // Expected notifications: [["x"]] + // (iter 1 assigns x = 10 → notify; iter 2 throws — the throw must unwatch + // x before the throw terminator. The watch is also re-issued each iteration.) + let output = baml_test!( + r#" + function fails() -> int { + for (let i in [1, 2, 3]) { + watch let x = i; + if (x == 2) { + throw "boom"; + } + x = x + 9; + } + 0 + } + + function main() -> int { + fails() catch (e) { + "boom" => 99, + _ => -1, + } + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails() -> int { + load_const 1 + load_const 2 + load_const 3 + alloc_array 3 + store_var _1 + load_const 0 + store_var __for_idx + + L0: + load_var __for_idx + load_var _1 + call baml.Array.length + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const 0 + return + + L2: + load_var _1 + load_var __for_idx + load_array_element + store_var x + load_const "x" + load_const null + watch x + load_var x + load_const 2 + cmp_op == + pop_jump_if_false L3 + jump L4 + + L3: + load_var x + load_const 9 + bin_op + + store_var x + unwatch x + load_var __for_idx + load_const 1 + bin_op + + store_var __for_idx + jump L0 + + L4: + unwatch x + load_const "boom" + throw + } + + function main() -> int { + call user.fails + jump L2 + load_var e + load_const "boom" + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw_if_panic + load_const 1 + unary_op - + jump L2 + + L1: + load_const 99 + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(99))); +} + +#[tokio::test] +async fn watch_while_throw_unwatches() { + // Expected notifications: [["x"]] + // (Same shape as `watch_for_throw_unwatches` but with a `while` loop. + // The MIR's while-body lowering pushes a Block expression for the body, + // so the `watch let x` snapshot/teardown is anchored at the MIR layer.) + let output = baml_test!( + r#" + function fails() -> int { + let i = 0; + while (i < 3) { + watch let x = i; + if (x == 1) { + throw "boom"; + } + x = x + 10; + i = i + 1; + } + 0 + } + + function main() -> int { + fails() catch (e) { + "boom" => 99, + _ => -1, + } + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails() -> int { + load_const 0 + store_var i + + L0: + load_var i + load_const 3 + cmp_op < + pop_jump_if_false L1 + jump L2 + + L1: + load_const 0 + return + + L2: + load_var i + store_var x + load_const "x" + load_const null + watch x + load_var x + load_const 1 + cmp_op == + pop_jump_if_false L3 + jump L4 + + L3: + load_var x + load_const 10 + bin_op + + store_var x + load_var i + load_const 1 + bin_op + + store_var i + unwatch x + jump L0 + + L4: + unwatch x + load_const "boom" + throw + } + + function main() -> int { + call user.fails + jump L2 + load_var e + load_const "boom" + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw_if_panic + load_const 1 + unary_op - + jump L2 + + L1: + load_const 99 + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(99))); +} + +#[tokio::test] +async fn watch_match_arm_throw_unwatches() { + // Expected notifications: [["x"]] + // (The match arm declares a watch and assigns to it before throwing. + // The throw path must unwatch x before the throw terminator.) + let output = baml_test!( + r#" + function fails(input: int) -> int { + match (input) { + 1 => { + watch let x = 0; + x = 5; + throw "boom" + } + _ => 0 + } + } + + function main() -> int { + fails(1) catch (e) { + "boom" => 99, + _ => -1, + } + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails(input: int) -> int { + load_var input + load_const 1 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_const 0 + return + + L1: + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + unwatch x + load_const "boom" + throw + } + + function main() -> int { + load_const 1 + call user.fails + jump L2 + load_var e + load_const "boom" + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw_if_panic + load_const 1 + unary_op - + jump L2 + + L1: + load_const 99 + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(99))); +} + +#[tokio::test] +async fn watch_catch_arm_throw_unwatches() { + // Expected notifications: [["x"]] + // (The catch arm body declares a watch, assigns to it, then re-throws. + // The throw path must unwatch x before re-throwing — otherwise the + // arm-scoped watch leaks past the function.) + let output = baml_test!( + r#" + function inner() -> int { + throw "first"; + } + + function fails() -> int { + inner() catch (e) { + _ => { + watch let x = 0; + x = 5; + throw "boom" + } + } + } + + function main() -> int { + fails() catch (e) { + "boom" => 99, + _ => -1, + } + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails() -> int { + call user.inner + jump L0 + load_var e + throw_if_panic + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + unwatch x + load_const "boom" + throw + + L0: + return + } + + function inner() -> int { + load_const "first" + throw + } + + function main() -> int { + call user.fails + jump L2 + load_var e + load_const "boom" + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_var e + throw_if_panic + load_const 1 + unary_op - + jump L2 + + L1: + load_const 99 + + L2: + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(99))); +} + +#[tokio::test] +async fn watch_match_arm_fallthrough_unwatches() { + // Expected notifications: [["x"]] + // (The match arm declares a watch let, assigns to it, and falls through + // to the join. `unwatch x` must precede the goto to the join, otherwise + // the arm-scoped watch leaks for the rest of the function.) + // + // After the match expression returns, a subsequent assignment to a + // distinct outer var must NOT notify on channel "x". + let output = baml_test!( + r#" + function entry(input: int) -> int { + let result = match (input) { + 1 => { + watch let x = 0; + x = 5; + x + } + _ => 0 + }; + // If the arm-scoped watch leaked, this assignment would be + // observed by an `x` watcher. After the arm, x must already be + // unwatched. + let result2 = result + 1; + result2 + } + + function main() -> int { + entry(1) + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function entry(input: int) -> int { + load_var input + load_const 1 + cmp_op == + pop_jump_if_false L0 + jump L1 + + L0: + load_const 0 + store_var result + jump L2 + + L1: + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + load_var x + store_var result + unwatch x + + L2: + load_var result + load_const 1 + bin_op + + return + } + + function main() -> int { + load_const 1 + call user.entry + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(6))); +} + +#[tokio::test] +async fn watch_switch_arm_fallthrough_unwatches() { + // Expected notifications: [["x"]] + // + // Four dense int arms drive `try_lower_as_switch` (lower.rs:4351), which + // emits a Switch terminator and lowers the matching arm body in + // `try_lower_as_switch` itself rather than `lower_match_chain`. The arm + // body declares a `watch let x` and falls through to the join — the + // watcher must be torn down before the goto-to-join so it does not leak + // past the arm. After the match returns, an assignment to a distinct + // outer variable must NOT be observed by an `x` watcher. + let output = baml_test!( + r#" + function entry(input: int) -> int { + let result = match (input) { + 0 => 100, + 1 => { + watch let x = 0; + x = 5; + x + } + 2 => 102, + 3 => 103, + _ => 999 + }; + // If the arm-scoped watch leaked past the arm, this assignment + // would be observed by an `x` watcher. After the arm, x must + // already be unwatched. + let result2 = result + 1; + result2 + } + + function main() -> int { + entry(1) + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function entry(input: int) -> int { + load_var input + jump_table [L4, L3, L2, L1], default L0 + + L0: + load_const 999 + store_var result + jump L5 + + L1: 3 + load_const 103 + store_var result + jump L5 + + L2: 2 + load_const 102 + store_var result + jump L5 + + L3: 1 + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + load_var x + store_var result + unwatch x + jump L5 + + L4: 0 + load_const 100 + store_var result + + L5: + load_var result + load_const 1 + bin_op + + return + } + + function main() -> int { + load_const 1 + call user.entry + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(6))); +} + +#[tokio::test] +async fn watch_catch_arm_fallthrough_unwatches() { + // Expected notifications: [["x"]] + // (Same shape as watch_match_arm_fallthrough_unwatches, but the watch + // is declared inside a catch arm body that falls through to the join.) + let output = baml_test!( + r#" + function fails() -> int { + throw "boom"; + } + + function main() -> int { + let result = fails() catch (e) { + _ => { + watch let x = 0; + x = 5; + x + } + }; + let result2 = result + 1; + result2 + } + "# + ); + + insta::assert_snapshot!(output.bytecode, @r#" + function fails() -> int { + load_const "boom" + throw + } + + function main() -> int { + call user.fails + store_var result + jump L0 + load_var e + throw_if_panic + load_const 0 + store_var x + load_const "x" + load_const null + watch x + load_const 5 + store_var x + load_var x + store_var result + unwatch x + + L0: + load_var result + load_const 1 + bin_op + + return + } + "#); + + assert_eq!(output.result, Ok(BexExternalValue::Int(6))); +} + // ============================================================================ // Watch with function calls and nested objects // ============================================================================ @@ -406,6 +1327,7 @@ async fn watch_function_call_modifications() { load_var point load_field .y bin_op + + unwatch point return } "#); @@ -484,6 +1406,7 @@ async fn watch_nested_object_added() { load_field .p load_field .x load_field .value + unwatch vec return } "#); @@ -564,6 +1487,7 @@ async fn watch_nested_object_removed() { load_field .p load_field .x load_field .value + unwatch vec return } "#); @@ -668,6 +1592,8 @@ async fn watch_cyclic_graph() { load_var v3 load_const 30 store_field .value + unwatch v4 + unwatch v2 load_const 0 return } diff --git a/baml_language/crates/tools_onionskin/src/compiler.rs b/baml_language/crates/tools_onionskin/src/compiler.rs index 9c9f1686c5..6df10044de 100644 --- a/baml_language/crates/tools_onionskin/src/compiler.rs +++ b/baml_language/crates/tools_onionskin/src/compiler.rs @@ -1543,11 +1543,12 @@ impl CompilerRunner { for (name, idx) in &bindings.params { file_detail.push(format!("{indent} param[{idx}]: {name}")); } - for (name, _site, range) in &bindings.bindings { + for binding in &bindings.bindings { file_detail.push(format!( - "{indent} let {name} {}..{}", - u32::from(range.start()), - u32::from(range.end()), + "{indent} let {} {}..{}", + binding.name, + u32::from(binding.name_range.start()), + u32::from(binding.name_range.end()), )); } }